渋谷駅前で働くデータサイエンティストのブログ

元祖「六本木で働くデータサイエンティスト」です / 道玄坂→銀座→東京→六本木→渋谷駅前

生TensorFlow七転八倒記(7):TensorFlow Hubの通常の英語コーパスではなくWikipedia英語版コーパスのtext embeddingを使ってみた

これは前回の記事の続きです。

小ネタにしてただの備忘録ですので、予めご了承ください。


前回の記事で元々参考にさせていただいた以下のブログ記事なんですが、これは基本的に英語NNLMの128次元embeddingで試したものなんですね。そのままやるとACC 0.965ぐらい出ます。


で、ボサーッとTensorFlow Hubのサイトを眺めていたら、NNLMの中に英語版Wikipedia記事をコーパスにして500次元のembeddingにまとめるモデルがあるなと気付きまして。ということで、面白そうなのでこのモデルを使ってやり直してみました。以下にその詳細を書いておきます。


前回同様、単にTF-Hubのモデルで特徴量を作ってDNNを回すだけ


データセットは上記リンク先ブログ記事同様にBBCのニューステキストを使います。ただしファイル名があまりにも一般的過ぎるので、"bbc_dataset.csv"と変えてあります。


あとは前回同様にJuPyter Notebookで回すだけです。

import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
import seaborn as sns
import matplotlib.pyplot as plt

if __name__ == "__main__":
    df = pd.read_csv("bbc_dataset.csv")
    df['category_id'] = df.type.factorize()[0]
    df = shuffle(df)

    train_input_fn = tf.estimator.inputs.pandas_input_fn(
        df[:1500], df[:1500]["category_id"], num_epochs=None, shuffle=True)

    predict_test_input_fn = tf.estimator.inputs.pandas_input_fn(
        df[1500:], df[1500:]["category_id"], shuffle=False)

    embedded_text_feature_column = hub.text_embedding_column(
        key="news", 
        module_spec="https://tfhub.dev/google/Wiki-words-500/1")

    estimator = tf.estimator.DNNClassifier(
        hidden_units=[500, 100],
        feature_columns=[embedded_text_feature_column],
        n_classes=5,
        optimizer=tf.train.AdagradOptimizer(learning_rate=0.003))

    estimator.train(input_fn=train_input_fn, steps=1000);
    test_eval_result = estimator.evaluate(input_fn=predict_test_input_fn)
    print("Test set accuracy: {accuracy}".format(**test_eval_result))
# 前略 #
INFO:tensorflow:loss = 220.5558, step = 1
INFO:tensorflow:global_step/sec: 33.3464
INFO:tensorflow:loss = 20.075544, step = 101 (3.001 sec)
INFO:tensorflow:global_step/sec: 33.9003
INFO:tensorflow:loss = 22.160091, step = 201 (2.950 sec)
INFO:tensorflow:global_step/sec: 34.588
INFO:tensorflow:loss = 13.642236, step = 301 (2.890 sec)
INFO:tensorflow:global_step/sec: 34.5069
INFO:tensorflow:loss = 12.801352, step = 401 (2.898 sec)
INFO:tensorflow:global_step/sec: 34.4459
INFO:tensorflow:loss = 10.384497, step = 501 (2.903 sec)
INFO:tensorflow:global_step/sec: 33.7817
INFO:tensorflow:loss = 6.273137, step = 601 (2.961 sec)
INFO:tensorflow:global_step/sec: 33.7938
INFO:tensorflow:loss = 10.3562975, step = 701 (2.960 sec)
INFO:tensorflow:global_step/sec: 33.4937
INFO:tensorflow:loss = 11.7235155, step = 801 (2.985 sec)
INFO:tensorflow:global_step/sec: 32.5283
INFO:tensorflow:loss = 6.1770945, step = 901 (3.074 sec)
INFO:tensorflow:Saving checkpoints for 1000 into /var/folders/yn/9h_42l352g739rm66y_3yznh0000gn/T/tmp53LIpo/model.ckpt.
INFO:tensorflow:Loss for final step: 5.46763.
# 中略 #
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.96827585, average_loss = 0.09233589, global_step = 1000, loss = 11.157253
Test set accuracy: 0.968275845051

あっさりACC 0.968が出ました。ついでにconfusion matrixを書いてみます。

import seaborn as sns
import matplotlib.pyplot as plt

def get_predictions(estimator, input_fn):
  return [x["class_ids"][0] for x in estimator.predict(input_fn=input_fn)]

LABELS = [
    0, 1, 2, 3, 4
]

# Create a confusion matrix on training data.
with tf.Graph().as_default():
  cm = tf.confusion_matrix(df[1500:]["category_id"], 
                           get_predictions(estimator, predict_test_input_fn))
  with tf.Session() as session:
    cm_out = session.run(cm)

# Normalize the confusion matrix so that each row sums to 1.
cm_out = cm_out.astype(float) / cm_out.sum(axis=1)[:, np.newaxis]

sns.heatmap(cm_out, annot=True, xticklabels=LABELS, yticklabels=LABELS);
plt.xlabel("Predicted");
plt.ylabel("True");

f:id:TJO:20180708144019p:plain

非常に綺麗に5クラスとも分類されているのが分かります。


感想


ある意味予想通りと言えば予想通りなんですが、通常の英語コーパスのNNLMがGoogle News英語版の記事を使っていることを考えると、似たようなニュース記事の分類には今回使ったWikipedia英語版コーパスNNLMだと精度が落ちるかなと思ったんですが、思ったほど落ちなかったというか(乱数シードの差程度だとは思いますが)僅かにACCが上がったのがちょっと面白かったです。


これで誰かWikipedia日本語版コーパスのNNLM作ってくれたら便利だよなーと思ったんですが、そういうことを言ってるとお前がやれというツッコミが飛んでくるのでこの辺にしておきますm(_ _)m