渋谷駅前で働くデータサイエンティストのブログ

元祖「六本木で働くデータサイエンティスト」です / 道玄坂→銀座→東京→六本木→渋谷駅前

Fashion-MNIST: 簡単になり過ぎたMNISTに代わる初心者向け画像認識ベンチマーク

f:id:TJO:20200108100655p:plain
(MNIST database - Wikipedia)

僕は画像認識分野は門外漢なのですが、ここ最近初心者向けにCNNのトレーニングを行うことを企画していて、その目的に適した画像認識のオープンデータセットを探していたのでした。


というと誰しも思いつくのがMNISTではないかと思うのですが、Kaggleのベンチマークにも出ているように、実はMNISTはチューニングなしのデフォルトのランダムフォレストで回しても97%以上のACCが出てしまいます。そしてちょっとチューニングしたCNNなら99.7%を叩き出せてしまう上に、そういったノウハウがネットのあちこちにHello World並みのイージーハウツーコンテンツとして溢れ返っていて、初心者向け教材という意味では全く参考になりません。そこで、ちょっとサーベイして探してみることにしました。


MNIST以外のMNIST的なデータセットを探す


実は、以前からあるそういう声に応えるためか、様々なフォントのアルファベットを収録したnotMNISTというデータベースが9年も前からあるのですが、データ形式に難があったり、初心者が手を付けるには若干難しいという側面もあるせいか、あまり普及していないようです。


そこで、まずググって見つけたこちらの一覧をチェックしてみたのですが、初心者向けにCIFARとかImageNetとか薦めてもヘビーなだけなのでどうしようかと思い、そこからさらにリンクを踏んで別の一覧を探しに行ったのでした。


で、その次にこちらの一覧を拝見した結果、とあるMNISTの改良版にたどり着きました。


f:id:TJO:20200108145835p:plain
(Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms. Han Xiao, Kashif Rasul, Roland Vollgraf. arXiv:1708.07747)

それがこちらのFashion-MNIST。名前の通りで、0: T-shirt/top, 1: Trouser, 2: Pullover, 3: Dress, 4: Coat, 5: Sandal, 6: Shirt, 7: Sneaker, 8: Bag, 9: Ankle bootという10種類のファッションアイテムのグレースケール画像7万点(学習データ6万&テストデータ1万)が収録されたデータセットです。


MNISTと全く同じフォーマット、全く同じ行数&列数、全く同じ分類クラス数で構成されていて、MNISTで用いられる前処理などのスクリプト一式がそのまま使えます(ちなみに日本語版のREADME.ja.mdがあります)。ただ、このデータは結構人気なので大抵の機械学習ライブラリに同梱されていて、よく見たらTensorFlowでも同梱されているんですねorz 己の不明さを恥じるばかりです。。。


Rで試してみる


Pythonでやってしまうと数あるHello World的なベンチマークの一つに新たになってしまうだけなので、ここではあえて誰もやらなそうなRでやってみます(笑)。まずデータ形式をRで読み込めるように揃えましょう。


Fashion-MNISTデータセットを保存したフォルダを確認した上で、こちらのスクリプトで前処理すればデータセット自体は簡単に出来上がります。後は適当に以下のように回すだけです。

load_mnist()
d_train <- matrix(0, 60000, 785)
d_train[, 1] <- train$y
d_train[, -1] <- train$x
d_test <- matrix(0, 10000, 785)
d_test[, 1] <- test$y
d_test[, -1] <- test$x
d_train <- as.data.frame(d_train)
d_test <- as.data.frame(d_test)
names(d_train)[1] <- "label"
names(d_test)[1] <- "label"
for (i in 1:784){
  names(d_train)[i + 1] <- paste0("V", i)
}
for (i in 1:784){
  names(d_test)[i + 1] <- paste0("V", i)
}
t <- proc.time()
fit <- randomForest(as.factor(label)~., d_train)
proc.time() - t
pred <- predict(fit, d_test)
table(d_test$label, pred)
sum(diag(table(d_test$label, pred)))/nrow(d_test)
# 100GB over RAMインスタンスでの結果
> t <- proc.time()
> fit <- randomForest(as.factor(label)~., d_train)
> proc.time() - t
    user   system  elapsed 
2964.437    2.818 2966.887 

# ... #

> table(d_test$label, pred)
   pred
      0   1   2   3   4   5   6   7   8   9
  0 855   0  12  32   3   1  85   0  12   0
  1   3 961   2  23   3   0   6   0   2   0
  2  11   0 800   9 118   0  57   0   5   0
  3  17   2   7 917  25   0  30   0   2   0
  4   0   0  90  36 822   0  49   0   3   0
  5   0   0   0   1   0 958   0  28   1  12
  6 149   1 123  31  87   0 591   0  18   0
  7   0   0   0   0   0  11   0 950   0  39
  8   1   2   5   2   5   2   6   5 972   0
  9   0   0   0   0   0   7   1  43   2 947
> sum(diag(table(d_test$label, pred)))/nrow(d_test)
[1] 0.8773

ACC 87.7%という結果になりました。ランダムフォレストでも97%が当たり前の元祖MNISTに比べると、工夫の余地があるように見えますね。RでCNNをやるのは面倒なので今回は割愛します。ということで、今年も皆様よろしくお願いいたします。