せっかくの週末にもかかわらず台風が来てしまい、テニスも出来なければ街歩きも出来ず暇を極めることになってしまったので、UCI ML repositoryを眺めていて見つけた適当なデータセットに対してTensorFlowで遊ぶということをしてみました。
基本的にはこのシリーズの前回の記事の続きです。
データセット
用いたデータセットはこちら。YouTubeの5種類の動画についたコメントに対して、スパムか否かのタグ付けがされたものです。一つ一つの動画に対するサンプルサイズだけでは小さ過ぎるので、手元で1つにまとめた上で、不具合のあった2行*1を削除したものを用意しました。
分類器
やることは簡単で、TF-Hubのpre-trained modelでコメント欄の内容を適当な特徴次元にembedし、これを使ってestimatorのDNNで分類するだけです。モデルはまずWikipediaコーパス由来のものにしてみました。
import tensorflow as tf import tensorflow_hub as hub import numpy as np import pandas as pd from sklearn.utils import shuffle if __name__ == "__main__": df = pd.read_csv("youtube_dataset.csv", delimiter='\t') df["category_id"] = df.CLASS.factorize()[0] df = shuffle(df) train_input_fn = tf.estimator.inputs.pandas_input_fn( df[:1500], df[:1500]["category_id"], num_epochs=None, shuffle=True) predict_test_input_fn = tf.estimator.inputs.pandas_input_fn( df[1500:], df[1500:]["category_id"], shuffle=False) embedded_text_feature_column = hub.text_embedding_column( key="CONTENT", module_spec="https://tfhub.dev/google/Wiki-words-500/1") estimator = tf.estimator.DNNClassifier( hidden_units=[500, 100], feature_columns=[embedded_text_feature_column], n_classes=2, optimizer=tf.train.AdagradOptimizer(learning_rate=0.005)) estimator.train(input_fn=train_input_fn, steps=800); test_eval_result = estimator.evaluate(input_fn=predict_test_input_fn) print("Test set accuracy: {accuracy}".format(**test_eval_result))
INFO:tensorflow:Saving dict for global step 800: accuracy = 0.9019608, accuracy_baseline = 0.503268, auc = 0.9525518, auc_precision_recall = 0.93581414, average_loss = 0.2916418, global_step = 800, label/mean = 0.49673203, loss = 33.465897, precision = 0.8645418, prediction/mean = 0.5046279, recall = 0.9517544 Test set accuracy: 0.901960790157
import seaborn as sns import matplotlib.pyplot as plt def get_predictions(estimator, input_fn): return [x["class_ids"][0] for x in estimator.predict(input_fn=input_fn)] LABELS = [ 0, 1 ] # Create a confusion matrix on training data. with tf.Graph().as_default(): cm = tf.confusion_matrix(df[1500:]["category_id"], get_predictions(estimator, predict_test_input_fn)) with tf.Session() as session: cm_out = session.run(cm) # Normalize the confusion matrix so that each row sums to 1. cm_out = cm_out.astype(float) / cm_out.sum(axis=1)[:, np.newaxis] sns.heatmap(cm_out, annot=True, xticklabels=LABELS, yticklabels=LABELS); plt.xlabel("Predicted"); plt.ylabel("True");
一応、ほとんどまともにチューニングもしていないDNNでもテストデータに対してACC 0.90というパフォーマンスを出すことが出来ました。TF-Hubのword embedding modelを使うと、ややこしいNLP周りの処理をしなくても結構がっさりと回せるので、当たりさえ付けば良いというレベルであればこれで十分ということで、楽で良いかなと。
今度は英語版Googleニュースのコーパスモデルを使います。一般にはこちらのモデルを使う方が多いと思われます。
import tensorflow as tf import tensorflow_hub as hub import numpy as np import pandas as pd from sklearn.utils import shuffle if __name__ == "__main__": df = pd.read_csv("youtube_dataset.csv", delimiter='\t') df["category_id"] = df.CLASS.factorize()[0] df = shuffle(df) train_input_fn = tf.estimator.inputs.pandas_input_fn( df[:1500], df[:1500]["category_id"], num_epochs=None, shuffle=True) predict_test_input_fn = tf.estimator.inputs.pandas_input_fn( df[1500:], df[1500:]["category_id"], shuffle=False) embedded_text_feature_column = hub.text_embedding_column( key="CONTENT", module_spec="https://tfhub.dev/google/nnlm-en-dim128/1") estimator = tf.estimator.DNNClassifier( hidden_units=[500, 100], feature_columns=[embedded_text_feature_column], n_classes=2, optimizer=tf.train.AdagradOptimizer(learning_rate=0.005)) estimator.train(input_fn=train_input_fn, steps=800); test_eval_result = estimator.evaluate(input_fn=predict_test_input_fn) print("Test set accuracy: {accuracy}".format(**test_eval_result))
INFO:tensorflow:Saving dict for global step 800: accuracy = 0.8714597, accuracy_baseline = 0.52505445, auc = 0.93307704, auc_precision_recall = 0.911756, average_loss = 0.4362399, global_step = 800, label/mean = 0.47494555, loss = 50.05853, precision = 0.832636, prediction/mean = 0.50560534, recall = 0.91284406 Test set accuracy: 0.871459722519
ACC 0.87とちょっと下がりました。チューニングで変動する範囲でもあるのでそこまで再現性が高いとも思えないのですが、あえて解釈するとすればスパムコメントのような「この単語が入っていたら大体はスパムだろう」的な情報はWikipediaコーパスの方が豊富なのかもしれません。ということで、週末の暇つぶしでした。
*1:Delimiterのミスが1行、タグ付け漏れでNAが1行