Text Classification with a CNN


 

This post is a write-up of what I studied on my own, so some of it may be wrong. Please keep that in mind.

If you find anything that needs correcting, I would really appreciate it if you let me know!

 

The full code is available at the GitHub link below.

https://github.com/chaehyun29/MLDL/blob/main/rnn_text_classification.ipynb

 


 

 

Dataset

Before anything else, set the size of the vocabulary and the embedding dimension.

VOCA_SIZE = 10000 # size of the vocabulary
EMBEDDING_SIZE = 64 # dimension of the word embedding vectors

 

 

Load the data with keras datasets: tf.keras.datasets.imdb.load_data() downloads the IMDB movie review data.

The num_words argument specifies how many of the most frequent words (by frequency rank) to keep.

import tensorflow as tf

print('Loading data...')
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.imdb.load_data(num_words=VOCA_SIZE)

Since VOCA_SIZE was set to 10,000 above, only the words ranked 1 through 10,000 by frequency are kept, i.e., a vocabulary of 10,000 words is used.
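As a quick sanity check (my own addition, not in the original post): every word outside that top-10,000 vocabulary is replaced by the out-of-vocabulary index 2, so the largest index appearing in the loaded sequences should stay below VOCA_SIZE.

# Words ranked beyond num_words are replaced by the OOV index 2,
# so no index in the loaded sequences should reach VOCA_SIZE.
print(max(max(seq) for seq in train_x))  # expected to print a value below 10000 (e.g. 9999)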

 

print(train_x.shape)
print(train_y.shape)
print(test_x.shape)
print(test_y.shape)
(25000,)
(25000,)
(25000,)
(25000,)

Both the training set and the test set contain 25,000 reviews.

Let's check what the data looks like:

print(train_x[:5])
print(train_y[:5])
[list([1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 2, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 2, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 2, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 2, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 2, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 2, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 2, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 2, 113, 103, 32, 15, 16, 2, 19, 178, 32])
 list([1, 194, 1153, 194, 2, 78, 228, 5, 6, 1463, 2, 2, 134, 26, 4, 715, 8, 118, 1634, 14, 394, 20, 13, 119, 954, 189, 102, 5, 207, 110, 3103, 21, 14, 69, 188, 8, 30, 23, 7, 4, 249, 126, 93, 4, 114, 9, 2300, 1523, 5, 647, 4, 116, 9, 35, 2, 4, 229, 9, 340, 1322, 4, 118, 9, 4, 130, 2, 19, 4, 1002, 5, 89, 29, 952, 46, 37, 4, 455, 9, 45, 43, 38, 1543, 1905, 398, 4, 1649, 26, 2, 5, 163, 11, 3215, 2, 4, 1153, 9, 194, 775, 7, 2, 2, 349, 2637, 148, 605, 2, 2, 15, 123, 125, 68, 2, 2, 15, 349, 165, 2, 98, 5, 4, 228, 9, 43, 2, 1157, 15, 299, 120, 5, 120, 174, 11, 220, 175, 136, 50, 9, 2, 228, 2, 5, 2, 656, 245, 2350, 5, 4, 2, 131, 152, 491, 18, 2, 32, 2, 1212, 14, 9, 6, 371, 78, 22, 625, 64, 1382, 9, 8, 168, 145, 23, 4, 1690, 15, 16, 4, 1355, 5, 28, 6, 52, 154, 462, 33, 89, 78, 285, 16, 145, 95])
 list([1, 14, 47, 8, 30, 31, 7, 4, 249, 108, 7, 4, 2, 54, 61, 369, 13, 71, 149, 14, 22, 112, 4, 2401, 311, 12, 16, 3711, 33, 75, 43, 1829, 296, 4, 86, 320, 35, 534, 19, 263, 2, 1301, 4, 1873, 33, 89, 78, 12, 66, 16, 4, 360, 7, 4, 58, 316, 334, 11, 4, 1716, 43, 645, 662, 8, 257, 85, 1200, 42, 1228, 2578, 83, 68, 3912, 15, 36, 165, 1539, 278, 36, 69, 2, 780, 8, 106, 14, 2, 1338, 18, 6, 22, 12, 215, 28, 610, 40, 6, 87, 326, 23, 2300, 21, 23, 22, 12, 272, 40, 57, 31, 11, 4, 22, 47, 6, 2307, 51, 9, 170, 23, 595, 116, 595, 1352, 13, 191, 79, 638, 89, 2, 14, 9, 8, 106, 607, 624, 35, 534, 6, 227, 7, 129, 113])
 list([1, 4, 2, 2, 33, 2804, 4, 2040, 432, 111, 153, 103, 4, 1494, 13, 70, 131, 67, 11, 61, 2, 744, 35, 3715, 761, 61, 2, 452, 2, 4, 985, 7, 2, 59, 166, 4, 105, 216, 1239, 41, 1797, 9, 15, 7, 35, 744, 2413, 31, 8, 4, 687, 23, 4, 2, 2, 6, 3693, 42, 38, 39, 121, 59, 456, 10, 10, 7, 265, 12, 575, 111, 153, 159, 59, 16, 1447, 21, 25, 586, 482, 39, 4, 96, 59, 716, 12, 4, 172, 65, 9, 579, 11, 2, 4, 1615, 5, 2, 7, 2, 17, 13, 2, 12, 19, 6, 464, 31, 314, 11, 2, 6, 719, 605, 11, 8, 202, 27, 310, 4, 3772, 3501, 8, 2722, 58, 10, 10, 537, 2116, 180, 40, 14, 413, 173, 7, 263, 112, 37, 152, 377, 4, 537, 263, 846, 579, 178, 54, 75, 71, 476, 36, 413, 263, 2504, 182, 5, 17, 75, 2306, 922, 36, 279, 131, 2895, 17, 2867, 42, 17, 35, 921, 2, 192, 5, 1219, 3890, 19, 2, 217, 2, 1710, 537, 2, 1236, 5, 736, 10, 10, 61, 403, 9, 2, 40, 61, 2, 5, 27, 2, 159, 90, 263, 2311, 2, 309, 8, 178, 5, 82, 2, 4, 65, 15, 2, 145, 143, 2, 12, 2, 537, 746, 537, 537, 15, 2, 4, 2, 594, 7, 2, 94, 2, 3987, 2, 11, 2, 4, 538, 7, 1795, 246, 2, 9, 2, 11, 635, 14, 9, 51, 408, 12, 94, 318, 1382, 12, 47, 6, 2683, 936, 5, 2, 2, 19, 49, 7, 4, 1885, 2, 1118, 25, 80, 126, 842, 10, 10, 2, 2, 2, 27, 2, 11, 1550, 3633, 159, 27, 341, 29, 2733, 19, 2, 173, 7, 90, 2, 8, 30, 11, 4, 1784, 86, 1117, 8, 3261, 46, 11, 2, 21, 29, 9, 2841, 23, 4, 1010, 2, 793, 6, 2, 1386, 1830, 10, 10, 246, 50, 9, 6, 2750, 1944, 746, 90, 29, 2, 8, 124, 4, 882, 4, 882, 496, 27, 2, 2213, 537, 121, 127, 1219, 130, 5, 29, 494, 8, 124, 4, 882, 496, 4, 341, 7, 27, 846, 10, 10, 29, 9, 1906, 8, 97, 6, 236, 2, 1311, 8, 4, 2, 7, 31, 7, 2, 91, 2, 3987, 70, 4, 882, 30, 579, 42, 9, 12, 32, 11, 537, 10, 10, 11, 14, 65, 44, 537, 75, 2, 1775, 3353, 2, 1846, 4, 2, 7, 154, 5, 4, 518, 53, 2, 2, 7, 3211, 882, 11, 399, 38, 75, 257, 3807, 19, 2, 17, 29, 456, 4, 65, 7, 27, 205, 113, 10, 10, 2, 4, 2, 2, 9, 242, 4, 91, 1202, 2, 5, 2070, 307, 22, 7, 2, 126, 93, 40, 2, 13, 188, 1076, 3222, 19, 4, 2, 7, 2348, 537, 23, 53, 537, 21, 82, 40, 2, 13, 2, 14, 280, 13, 219, 4, 2, 431, 758, 859, 4, 953, 1052, 2, 7, 2, 5, 94, 40, 25, 238, 60, 2, 4, 2, 804, 2, 7, 4, 2, 132, 8, 67, 6, 22, 15, 9, 283, 8, 2, 14, 31, 9, 242, 955, 48, 25, 279, 2, 23, 12, 1685, 195, 25, 238, 60, 796, 2, 4, 671, 7, 2804, 5, 4, 559, 154, 888, 7, 726, 50, 26, 49, 2, 15, 566, 30, 579, 21, 64, 2574])
 list([1, 249, 1323, 7, 61, 113, 10, 10, 13, 1637, 14, 20, 56, 33, 2401, 18, 457, 88, 13, 2626, 1400, 45, 3171, 13, 70, 79, 49, 706, 919, 13, 16, 355, 340, 355, 1696, 96, 143, 4, 22, 32, 289, 7, 61, 369, 71, 2359, 5, 13, 16, 131, 2073, 249, 114, 249, 229, 249, 20, 13, 28, 126, 110, 13, 473, 8, 569, 61, 419, 56, 429, 6, 1513, 18, 35, 534, 95, 474, 570, 5, 25, 124, 138, 88, 12, 421, 1543, 52, 725, 2, 61, 419, 11, 13, 1571, 15, 1543, 20, 11, 4, 2, 5, 296, 12, 3524, 5, 15, 421, 128, 74, 233, 334, 207, 126, 224, 12, 562, 298, 2167, 1272, 7, 2601, 5, 516, 988, 43, 8, 79, 120, 15, 595, 13, 784, 25, 3171, 18, 165, 170, 143, 19, 14, 5, 2, 6, 226, 251, 7, 61, 113])]
[1 0 0 1 0]

So each sample is a list of word indices. The labels fall into two categories, 0 and 1:

len(set(train_y))

The expression above returns 2.

Each label marks whether the review is positive or negative; 1 means positive.
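For reference, the IMDB training split is balanced between the two classes. A quick way to confirm this (a small check I added, not in the original code):

import numpy as np

# Count how many training samples fall into each class
print(np.unique(train_y, return_counts=True))  # expected: (array([0, 1]), array([12500, 12500]))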

 

 

Data Preprocessing

imdb.get_word_index() returns a dictionary that maps each word to an integer index.

word_index = tf.keras.datasets.imdb.get_word_index()
print(word_index)
{'fawn': 34701, 'tsukino': 52006, 'nunnery': 52007, ...

To look words up by index, build a reversed dictionary. Note that load_data() shifts every index by 3, reserving 0-3 for the special tokens <PAD>, <START>, <UNK>, and <UNUSED>, so the same shift is needed when decoding.

reverse_word_index = {value + 3: key for key, value in word_index.items()}
reverse_word_index.update({0: '<PAD>', 1: '<START>', 2: '<UNK>', 3: '<UNUSED>'})

Now the index sequences in train_x can be turned back into words.

def decode_review(text):
    return ' '.join([reverse_word_index.get(i, '?') for i in text])

print(train_x[0])
print(decode_review(train_x[0]))
[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 2, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 2, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 2, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 2, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 2, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 2, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 2, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 2, 113, 103, 32, 15, 16, 2, 19, 178, 32]
<START> this film was just brilliant casting location scenery story direction <UNK> really suited the part they played and you could just imagine being there robert <UNK> is an amazing actor and now the same being director <UNK> father came from the same <UNK> island as myself so i loved the fact there was a real connection with this film the witty <UNK> throughout the film were great it was just brilliant so much that i bought the film as soon as it was released for <UNK> and would recommend it to everyone to watch and the fly <UNK> was amazing really cried at the end it was so sad and you know what they say if you cry at a film it must have been good and this definitely was also <UNK> to the two little <UNK> that played the <UNK> of norman and paul they were just brilliant children are often left out of the <UNK> list i think because the stars that play them all grown up are such a big <UNK> for the whole film but these children are amazing and should be <UNK> for what they have done don't you think the whole story was so lovely because it was true and was <UNK> life after all that was <UNK> with us all

 

Training will not work properly if the samples have different lengths, since they cannot be stacked into a single tensor. Let's check the lengths.

print(len(train_x[0]))
print(len(train_x[1]))
print(len(train_x[2]))
print(len(train_x[3]))
print(len(train_x[4]))
218
189
141
550
147

As expected, the lengths are all different. Let's pad them using sequence from keras.preprocessing.

from tensorflow.keras.preprocessing import sequence

train_x = sequence.pad_sequences(train_x, maxlen=400, padding='post')
test_x = sequence.pad_sequences(test_x, maxlen=400, padding='post')
print(train_x.shape)
print(test_x.shape)

sequence.pad_sequences(x, maxlen=None, padding='pre', value=0.0, truncating='pre')

maxlen specifies the length every sequence is padded (or truncated) to.

padding='post' fills at the end of each sequence; the default is 'pre', which fills at the front.

value specifies what value to pad with; the default is 0.

truncating decides which side to cut when a sequence exceeds maxlen; the default 'pre' cuts from the front, and 'post' cuts from the end. A toy example below shows how these options interact.
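Here is a tiny illustration on toy sequences (my own example, not from the original post):

# Two toy sequences of different lengths
toy = [[1, 2, 3], [4, 5, 6, 7, 8, 9]]

# padding='post' pads the short one at the end; the default truncating='pre'
# drops the front of the long one.
print(sequence.pad_sequences(toy, maxlen=5, padding='post'))
# [[1 2 3 0 0]
#  [5 6 7 8 9]]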

 

print(train_x.shape)
print(test_x.shape)
(25000, 400)
(25000, 400)

Padding completed successfully.

 

 

Using a CNN Model

from tensorflow.keras.preprocessing import sequence
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Input, Dense, Dropout, Activation
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import Conv1D, GlobalMaxPooling1D
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

 

Since the padded sequence length is 400, the input length is 400. Because this is binary classification, the final output layer has a single sigmoid unit and the loss function is binary cross-entropy.

model = Sequential()
model.add(Embedding(VOCA_SIZE, EMBEDDING_SIZE))
model.add(Dropout(0.2))
# filter count and kernel size below are typical example values
model.add(Conv1D(64, 5, padding='valid', activation='relu'))
model.add(GlobalMaxPooling1D())
model.add(Dense(1, activation='sigmoid'))

es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=4)
mc = ModelCheckpoint('CNN_test_model.h5', monitor='val_acc', mode='max', verbose=1, save_best_only=True)

model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
history = model.fit(train_x, train_y, epochs=15, callbacks=[es, mc], batch_size=64, validation_split=0.2)

Training runs for up to 15 epochs with EarlyStopping and ModelCheckpoint as callbacks. ModelCheckpoint saves the model with the best validation accuracy (val_acc), and EarlyStopping stops training if val_loss has not improved for 4 epochs in a row.
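To see how training progressed, the history object returned by fit() can be plotted (an optional sketch, assuming matplotlib is installed):

import matplotlib.pyplot as plt

# Plot training vs. validation accuracy recorded by model.fit()
plt.plot(history.history['acc'], label='train acc')
plt.plot(history.history['val_acc'], label='val acc')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()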

 

 

Now load the saved model and check its accuracy on the test set.

from tensorflow.keras.models import load_model

model = load_model('CNN_test_model.h5')
loss, acc = model.evaluate(test_x, test_y)
print("loss =", loss)
print("acc =", acc)
782/782 [==============================] - 2s 3ms/step - loss: 0.2519 - acc: 0.8960
loss = 0.2518715560436249
acc = 0.8960400223731995

The accuracy comes out to about 0.896.

 

 

 

Now let's take some real reviews and see how the model classifies them.

# adapted from: https://wikidocs.net/24586
import re

def sentiment_predict(new_sentence):
  # remove everything except alphanumerics and spaces, and lowercase
  new_sentence = re.sub('[^0-9a-zA-Z ]', '', new_sentence).lower()
  encoded = []

  # tokenize on whitespace, then integer-encode
  for word in new_sentence.split():
    try:
      # keep only words that fit inside the 10,000-word vocabulary
      # (+3 because load_data() reserves indices 0-3 for special tokens)
      if word_index[word] + 3 < VOCA_SIZE:
        encoded.append(word_index[word] + 3)
      else:
        # words outside the vocabulary become the <UNK> token
        encoded.append(2)
    # words not in the dictionary at all also become the <UNK> token
    except KeyError as e:
      print(e)
      encoded.append(2)

  pad_sequence = sequence.pad_sequences([encoded], maxlen=400, padding='post')  # pad the same way as the training data
  score = float(model.predict(pad_sequence))  # prediction

  if score > 0.5:
    print("This is a positive review with {:.2f}% probability.".format(score * 100))
  else:
    print("This is a negative review with {:.2f}% probability.".format((1 - score) * 100))

Since the input is raw text, each word is looked up in the word_index dictionary (word to index) to get its integer index, shifted by 3 because load_data() reserves indices 0-3 for special tokens. A word whose shifted index falls outside the VOCA_SIZE vocabulary, or a word not in the dictionary at all, is replaced by the <UNK> token (index 2). The resulting index list is padded and then fed to the model to get a score.
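As a quick illustration of that offset (my own check, using 'great' as an example word that is certainly in the vocabulary):

# the raw rank from get_word_index() vs. the index load_data() actually uses
print(word_index['great'])       # raw frequency rank
print(word_index['great'] + 3)   # shifted index, as it appears in train_x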

 

I took a five-star review of 'Top Gun: Maverick', currently in theaters, from Rotten Tomatoes and ran it through the model.

test_input = "Somehow Cruises foray back into the danger zone will be remembered more than the original, setting a new standard in the era of reboots."
sentiment_predict(test_input)
This is a positive review with 97.68% probability.

The model seems to have learned well.

 

Next, I took a half-star review.

test_input = "I have nothing to say other than this being yet another sequel made 30 plus years after the original with painful dialogue and improves on nothing! This is just there for nostalgia purposes and trying to get your ass in the seats. Tom Cruise isn't even that good of an actor. Nobody asked for this lazy excuse of a movie!"
sentiment_predict(test_input)
This is a negative review with 93.23% probability.

Again, the model appears to classify it correctly.

 
