六本木で働くデータサイエンティストのブログ

元祖「銀座で働くデータサイエンティスト」です / 道玄坂→銀座→東京→六本木

生TensorFlow七転八倒記(6):TensorFlow Hubのtext embeddingsを使って日本語テキストを分類してみた

だいぶ久しぶりの生TensorFlow七転八倒記です。今回もただの備忘録につき、何一つ新しいことも参考になることも書いておりませんし、クソコードの羅列でしかありませんので、何か調べ物でたどり着かれた方はこの記事のリンク先などなどをご覧ください。


今回やろうと思ったのはテキスト分類です。というのは、従前はテキスト分類と言えば特徴量(=単語)がスパースゆえ潜在的な意味の類似性とかを勘案してモデリングしようと思ったらトピックモデルでやるしかないと思っていたのでした。トピックモデルについてはこちらの解説が今でも分かりやすいと思います。

ところが、TensorFlow Hubで学習済みのword embeddingsモデルが提供されるようになり、トピックモデルを使わなくてもword2vecと同じ理屈で、個々のドキュメントの内容をある決まった次元の特徴空間に射影したベクトルを使って、機械学習分類器をモデリングできるようになったと知りました。これはやるしかないでしょう。


ということで、実際にTensorFlow Hubの学習済みtext embeddingsモデルを使ってやってみます。基本的にはTensorFlow Hubの公式チュートリアルをなぞっているだけですが、分かりやすくまとめてくださった方のブログ記事があるのでそちらを参照します。

ちなみに某所でコードの不具合を聞きまくったせいで、わざわざ追記していただいてしまったようで。。。有難うございますm(_ _)m


青空文庫のデータセットで2クラス分類をやってみる


何度かこのブログでもお世話になっている青空文庫のデータを使うことにします。今回は簡単のため、夏目漱石『こころ』と島崎藤村『破戒』の2編だけを使い、それぞれの本編1行ずつをデータセットの1行とみなしてMeCab分かち書きした上で'souseki', 'touson'とラベル付けしておきます。この2編は微妙に全体の長さが違うので、適当にランダムに並び替えた上で短い方の『こころ』の長さに合わせて『破戒』をdownsamplingして揃え、さらに前処理として括弧類を全て削除しておきます。


前処理済みのデータをGitHubに置いてありますので、試してみたい方は以下からDLしてください。学習データ'train_aozora.csv'は1000行、テストデータ'test_aozora.csv'は200行余りです。

やることは完全に公式チュートリアルと同じで、text embeddingsモデルを使って特徴量を作り、これを高レベルAPIでDNNにかけて分類するだけです。ただし、オプティマイザだけAdagradではなく最近良く推奨されるAdamに替えてあります。

import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
import seaborn as sns
import matplotlib.pyplot as plt

if __name__ == "__main__":
    df_train = pd.read_csv("train_aozora.csv", sep='\t')
    df_train['category_id'] = df_train.type.factorize()[0]

    train_input_fn = tf.estimator.inputs.pandas_input_fn(
        df_train, df_train["category_id"], num_epochs=None, shuffle=True)

    df_test = pd.read_csv("test_aozora.csv", sep='\t')
    df_test['category_id'] = df_test.type.factorize()[0]

    predict_test_input_fn = tf.estimator.inputs.pandas_input_fn(
        df_test, df_test["category_id"], shuffle=False)
 
    embedded_text_feature_column = hub.text_embedding_column(
        key="text", 
        module_spec="https://tfhub.dev/google/nnlm-ja-dim128/1")

    estimator = tf.estimator.DNNClassifier(
        hidden_units=[512, 128],
        feature_columns=[embedded_text_feature_column],
        n_classes=2,
        optimizer=tf.train.AdamOptimizer(learning_rate=0.003))

    estimator.train(input_fn=train_input_fn, steps=1000);
    test_eval_result = estimator.evaluate(input_fn=predict_test_input_fn)
    print("Test set accuracy: {accuracy}".format(**test_eval_result))
# 前略 #
INFO:tensorflow:loss = 89.431, step = 1
INFO:tensorflow:global_step/sec: 129.804
INFO:tensorflow:loss = 2.4034119, step = 101 (0.772 sec)
INFO:tensorflow:global_step/sec: 154.871
INFO:tensorflow:loss = 0.502995, step = 201 (0.646 sec)
INFO:tensorflow:global_step/sec: 155.802
INFO:tensorflow:loss = 0.21461238, step = 301 (0.642 sec)
INFO:tensorflow:global_step/sec: 160.36
INFO:tensorflow:loss = 0.03812225, step = 401 (0.623 sec)
INFO:tensorflow:global_step/sec: 165.029
INFO:tensorflow:loss = 0.023185248, step = 501 (0.606 sec)
INFO:tensorflow:global_step/sec: 165.909
INFO:tensorflow:loss = 0.00883118, step = 601 (0.603 sec)
INFO:tensorflow:global_step/sec: 164.816
INFO:tensorflow:loss = 0.005985881, step = 701 (0.607 sec)
INFO:tensorflow:global_step/sec: 164.004
INFO:tensorflow:loss = 0.0021691, step = 801 (0.609 sec)
INFO:tensorflow:global_step/sec: 164.929
INFO:tensorflow:loss = 0.0032610907, step = 901 (0.607 sec)
INFO:tensorflow:Saving checkpoints for 1000 into /var/folders/yn/9h_42l352g739rm66y_3yznh0000gn/T/tmpZmW6ql/model.ckpt.
INFO:tensorflow:Loss for final step: 0.0012356994.
# 中略 #
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.9724771, accuracy_baseline = 0.5, auc = 0.98556525, auc_precision_recall = 0.9865266, average_loss = 0.15995193, global_step = 1000, label/mean = 0.5, loss = 17.43476, precision = 0.9557522, prediction/mean = 0.5152738, recall = 0.9908257
Test set accuracy: 0.972477078438

あっさりACC 0.97が出ました。ついでなのでチュートリアルに従ってconfusion matrixも出してみましょう。

def get_predictions(estimator, input_fn):
  return [x["class_ids"][0] for x in estimator.predict(input_fn=input_fn)]

LABELS = [
    0, 1
]

# Create a confusion matrix on training data.
with tf.Graph().as_default():
  cm = tf.confusion_matrix(df_test["category_id"], 
                           get_predictions(estimator, predict_test_input_fn))
  with tf.Session() as session:
    cm_out = session.run(cm)

# Normalize the confusion matrix so that each row sums to 1.
cm_out = cm_out.astype(float) / cm_out.sum(axis=1)[:, np.newaxis]

sns.heatmap(cm_out, annot=True, xticklabels=LABELS, yticklabels=LABELS);
plt.xlabel("Predicted");
plt.ylabel("True");

f:id:TJO:20180626212257p:plain

なかなか悪くない結果になりました。とりあえず、夏目漱石島崎藤村を1行単位のデータで比べる限りは、このやり方で十分に分類できるということが分かりました。


国交省のデータセットを使って多クラス分類してみる


調子に乗って、今度は以下のGitHubで公開されている国交省の車両に関するらしきデータセットを使って多クラス分類をやってみます。

前処理として、丸数字や半角英数字などを削除してあります(TF-Hubの日本語学習済みモデルは何と半角英数字に対してエラーを吐くので)。著作権の都合もありますので、このデータセットは僕の方からは公開しません。皆さんご自身でお試しくださいm(_ _)m なお学習データは37000行余り、テストデータは4700行余りあります。

import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
import seaborn as sns
import matplotlib.pyplot as plt

if __name__ == "__main__":
    df_train = pd.read_csv("kk_train.tsv", encoding="utf-8", sep='\t')
    df_train['category_id'] = df_train.type.factorize()[0]

    train_input_fn = tf.estimator.inputs.pandas_input_fn(
        df_train, df_train["category_id"], num_epochs=None, shuffle=True)

    df_test = pd.read_csv("kk_test.tsv", encoding="utf-8", sep='\t')
    df_test['category_id'] = df_test.type.factorize()[0]

    predict_test_input_fn = tf.estimator.inputs.pandas_input_fn(
        df_test, df_test["category_id"], shuffle=False)

    embedded_text_feature_column = hub.text_embedding_column(
        key="text", 
        module_spec="https://tfhub.dev/google/nnlm-ja-dim128/1")

    estimator = tf.estimator.DNNClassifier(
        hidden_units=[512, 128],
        feature_columns=[embedded_text_feature_column],
        n_classes=16,
        optimizer=tf.train.AdamOptimizer(learning_rate=0.003))

    estimator.train(input_fn=train_input_fn, steps=2000);
    test_eval_result = estimator.evaluate(input_fn=predict_test_input_fn)
    print("Test set accuracy: {accuracy}".format(**test_eval_result))
# 前後略 #
INFO:tensorflow:Saving dict for global step 2000: accuracy = 0.0724299, average_loss = 7.0073934, global_step = 2000, loss = 891.6435
Test set accuracy: 0.0724299028516

f:id:TJO:20180626215345p:plain

全然ダメじゃんorz 16クラス分類でACC 0.072なのでほぼchance levelです。。。やっぱりこれだけの多クラス分類だともっと学習データの行数が必要なんですかね? どうしたものかなぁと思って公式チュートリアルの下の方を見たら、こんなことが書いてありました。

Further improvements

  1. Regression on sentiment: we used a classifier to assign each example into a polarity class. But we actually have another categorical feature at our disposal - sentiment. Here classes actually represent a scale and the underlying value (positive/negative) could be well mapped into a continuous range. We could make use of this property by computing a regression (DNN Regressor) instead of a classification (DNN Classifier).
  2. Larger module: for the purposes of this tutorial we used a small module to restrict the memory use. There are modules with larger vocabularies and larger embedding space that could give additional accuracy points.
  3. Parameter tuning: we can improve the accuracy by tuning the meta-parameters like the learning rate or the number of steps, especially if we use a different module. A validation set is very important if we want to get any reasonable results, because it is very easy to set-up a model that learns to predict the training data without generalizing well to the test set.
  4. More complex model: we used a module that computes a sentence embedding by embedding each individual word and then combining them with average. One could also use a sequential module (e.g. Universal Sentence Encoder module) to better capture the nature of sentences. Or an ensemble of two or more TF-Hub modules.
  5. Regularization: to prevent overfitting, we could try to use an optimizer that does some sort of regularization, for example Proximal Adagrad Optimizer.

色々言っているようで、何も言っていないような。。。ともあれ必要があればこのガイドラインに沿ってもうちょっとあれこれ試してみようかと思います。


Text embeddingsは裏では何をやっているのか


ところで、text embeddingsは裏では何をやっているんでしょうか? ちょっと気になったので、TensorFlow Hubの公式ページに倣って以下のようにやってみました。

import tensorflow as tf
import tensorflow_hub as hub

with tf.Graph().as_default():
  embed = hub.Module("https://tfhub.dev/google/nnlm-ja-dim128-with-normalization/1")
  embeddings = embed([u"私 は 猫 で ある", u"猫 は 動物 で ある"])

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())

    print(sess.run(embeddings))
INFO:tensorflow:Downloading TF-Hub Module 'https://tfhub.dev/google/nnlm-ja-dim128-with-normalization/1'.
INFO:tensorflow:Downloaded TF-Hub Module 'https://tfhub.dev/google/nnlm-ja-dim128-with-normalization/1'.
INFO:tensorflow:Initialize variable module/embeddings/part_0:0 from checkpoint /var/folders/yn/9h_42l352g739rm66y_3yznh0000gn/T/tfhub_modules/f2587ea6ed3c5c0de24d63525ca949545fc6de9e/variables/variables with embeddings
[[ 8.52455199e-02  2.94013321e-03 -8.22397135e-03 -1.46733671e-02
  -1.72823593e-02  2.90279258e-02 -7.07929730e-02 -6.85681477e-02
  -7.48756155e-03 -8.18520486e-02 -5.90340532e-02  7.91012049e-02
  -2.91991364e-02 -2.60195076e-01  4.98240739e-02 -1.73879832e-01
  -1.70718566e-01 -1.53188765e-01  2.90644839e-02 -7.03767464e-02
  -5.78003265e-02  7.17197284e-02  2.06977129e-05  4.11862656e-02
  -1.12049118e-01 -1.78419292e-01 -1.26710907e-01  2.67293863e-03
  -5.20219430e-02 -1.88256279e-02 -9.61715803e-02  7.35456496e-02
   1.22188404e-03 -1.28141582e-01  1.56235471e-02 -9.11903158e-02
  -7.63232112e-02  5.42782322e-02 -5.15313074e-02 -1.90471262e-01
   3.20132673e-02 -1.35457635e-01  1.68082714e-02  7.00900927e-02
   5.35486042e-02 -5.86200505e-03  1.06252551e-01 -5.24865612e-02
  -3.08582243e-02  2.08250955e-02  7.91827664e-02 -7.52752945e-02
  -1.73314512e-02 -2.62607113e-02 -1.87900215e-02 -5.49748540e-03
  -1.24954190e-02 -7.70226270e-02  4.65861335e-02 -5.75305261e-02
   4.26497571e-02 -4.30508032e-02 -8.71970505e-03  5.88850975e-02
  -1.04124591e-01  2.32928365e-01  2.08118260e-01  6.19726107e-02
  -5.32182157e-02 -5.08430637e-02 -1.09306917e-01  3.90085429e-02
   1.46081582e-01  2.04974785e-02  1.49110425e-02  2.28865296e-02
   1.63206011e-02  4.29741219e-02  3.39215100e-02  1.43105552e-01
   2.04084441e-01 -5.62003814e-03 -5.52381091e-02  5.23940623e-02
  -6.61633834e-02 -6.30271435e-02 -4.47466299e-02  8.09627324e-02
  -5.67415208e-02 -7.18336776e-02 -6.73303530e-02  9.41584632e-02
  -2.17595883e-02 -3.60757336e-02 -2.00801771e-02  9.24135968e-02
  -3.96264270e-02  9.19071734e-02  9.62694883e-02  9.98672470e-03
   6.10636920e-03 -2.84189656e-02 -6.34012520e-02  1.05268247e-01
  -1.38996735e-01  1.48895934e-01 -1.67204142e-02  8.91915336e-03
   1.61656179e-02  1.39357075e-02 -2.07343027e-01 -7.84583390e-02
   9.34032127e-02 -8.10054541e-02 -5.95692061e-02  1.12060737e-02
  -4.03481647e-02 -3.83854657e-02 -2.76221670e-02 -8.31059217e-02
  -7.44923502e-02 -1.17676266e-01  8.66378769e-02  5.51073402e-02
   1.39137879e-02  7.48097897e-02 -9.77614522e-02  2.04898883e-02]
 [ 5.68491995e-01 -8.65925755e-03  1.18343875e-01  8.14386606e-02
  -1.37463808e-01 -6.91957995e-02 -8.17279145e-02 -1.96645781e-01
  -1.83535237e-02 -9.37304944e-02 -6.18879721e-02  1.47726787e-02
  -1.36250108e-01  1.64986018e-03  2.77892426e-02 -3.61956507e-02
  -1.24522433e-01  3.62478308e-02  2.21148804e-01 -2.34831264e-03
  -5.89277968e-02 -1.03735467e-02 -3.10243331e-02  1.34466946e-01
  -1.98412389e-02 -1.09978944e-01 -2.22971320e-01  1.66467745e-02
   1.33619845e-01  6.92416281e-02 -1.21783525e-01  5.64023107e-03
   1.39355600e-01 -1.08985685e-01 -1.22326287e-02 -6.02420233e-02
  -4.33986969e-02  1.51147824e-02 -1.13851940e-02 -1.40483841e-01
   5.33875115e-02 -9.06368867e-02  3.14746834e-02 -1.19504504e-01
   1.78923011e-01  1.55183775e-02  7.42514133e-02 -2.68397510e-01
   1.45406020e-03  1.38821200e-01  1.43817425e-01  7.08675906e-02
   1.26253348e-03  1.43299818e-01  4.36784141e-02  4.62467931e-02
   6.21854179e-02 -8.35543722e-02  6.95923939e-02 -3.35774794e-02
  -1.89198609e-02 -1.78088583e-02 -1.10436425e-01  1.60233956e-02
  -1.64298892e-01  2.13777870e-02  3.06044016e-02 -3.06188818e-02
  -3.38123776e-02 -4.75936979e-02  4.24533561e-02 -1.55665884e-02
   5.64597808e-02 -3.59726213e-02  5.57259880e-02 -8.41505975e-02
  -5.01365103e-02 -9.59435627e-02 -6.92752451e-02 -2.76028588e-02
   2.17243910e-01  3.59871201e-02 -1.76482916e-01 -4.18751799e-02
  -1.78482339e-01  1.20470814e-01 -9.16808769e-02  8.27920809e-02
   1.11592263e-02 -8.79351720e-02 -4.39643525e-02 -1.10055231e-01
  -4.60435152e-02  3.79438661e-02 -1.72184203e-02 -1.36174649e-01
  -5.88618498e-03  5.77106960e-02  8.20119604e-02 -2.57646404e-02
  -8.07605013e-02 -7.74711296e-02  4.67604958e-02  5.37206829e-02
  -1.35630861e-01  9.78991538e-02 -1.20537942e-02  9.40130875e-02
  -3.98719162e-02  1.01150006e-01  3.09805982e-02  4.26106937e-02
  -4.82343277e-03 -5.86888045e-02 -6.03201799e-02 -2.15122607e-02
  -1.08524218e-01 -7.65707344e-02  2.84913797e-02 -1.53922603e-01
  -1.63088903e-01 -4.40106466e-02 -2.51780581e-02  8.10946226e-02
   1.14474252e-01  1.06071301e-01  4.96090427e-02  4.83132862e-02]]

こういう感じで、それぞれの文章が(このモデルだと)128次元の特徴量に変換されるということが分かります。ということで、今回もお後がよろしいようで。。。いやあまりよろしくないのかorz