Python

MnistデータセットをPythonで分類して使い方を理解していこう!

Mnist
ウマたん
ウマたん
当サイト【スタビジ】の本記事では、Mnistという手書き文字のデータセットをPythonで分類してどのように扱っていけばよいか見ていきます。Mnistはディープラーニングとはじめとした手法の分類精度を比較するのによく使われるんです。

こんにちは!

消費財メーカーでデータサイエンティストをしているウマたん(@statistics1012)です。

データサイエンスの世界では最終的には実データを使っていくことになりますが、手法を簡易的に実装してみたり手法の精度を比較する上で既存のデータセットが存在します。

そんな既存のデータセットの中でもディープラーニングの画像分類精度を測る上で非常によく使われるデータセットがMnist。

ロボたん
ロボたん
既存のデータセットを使って意味あるのー?
ウマたん
ウマたん
実データはしっかりデータ整形されていないことが多いから、実装してみるだけだったら既存のデータセットを使った方がよいよ

データセットは色んなタイプがありますが、Mnistの扱い方を知っておくことは非常に重要です。

ということで、この記事ではMnistというデータセットについて見ていきたいと思います。

Mnistデータとは?

まずは、Mnistデータとは何か見ていきましょう!

MnistはMixed National Institute of Standards and Technology databaseの略で、手書き数字画像60,000枚とテスト画像10,000枚を集めた、画像データセット。

0~9の手書き数字が教師ラベルとして各画像に与えられています。

つまりデータセットの構造は以下のようになっています。

・学習用画像データ
・学習用教師ラベルデータ
・予測用画像データ
・予測用教師ラベルデータ

ロボたん
ロボたん
1つの画像はどのようなデータとして入っているの?
ウマたん
ウマたん
1つ1つの画像は28×28のピクセル単位で格納されているんだ!

1つ1つの画像は、文字画像をタテヨコ28×28のピクセルに分け、1つのピクセルあたり0~255の数値で白黒のスケールを表します。

非常に扱いやすくて、データセットとしての完成度が高いので多くの論文やカリキュラムで取り上げられるデータセットです。

画像分類問題に取り組むのであればまずはMnistを使ってディープラーニングを実装してみるのが良いと思います。

PythonでMnistデータを使って分類精度を比較

そんなMnistデータを使ってPythonで分類を行っていきましょう!

比較する手法は以下の4つ!

・CNN(畳み込みニューラルネットワーク)
・Xgboost
・LightGBM
・Catboost

CNN(畳み込みニューラルネットワーク)

ディープラーニングに畳み込み層を組みあわせた画像分類には定番の手法です。

まずは、必要なライブラリをインストールしていきます!

tensorflowなどのライブラリはあらかじめpip installしておいてくださいね!

続いてMnistのデータを学習データとテストデータに分けます。そしてさらに学習データからパラメータチューニングのための検証データを取り出します。

この時画像データは、描画がしやすいように28×28の行列になっているのですが、学習するために1×784に直しましょう!さらに0~255のスケールを正規化しましょう!

続いてラベルをダミー変数化します。

ここでデータの成型が終了したので、ディープラーニングのネットワーク構築に入ります。

隠れ層では、RELU(ランプ)関数を用いて出力層ではソフトマックス関数を用いています。

Model.addを使うことで隠れ層をいくつも積み重ねることが可能です。

ネットワークの構築が終了した後は、最適な重みを見つけていきます。AdamOptimizerは最近よく使われている最適化手法です。

Early stoppingとはもう精度が改善しないようなら学習を止めてしまう条件です。これによってムダな学習を省くことが可能です。最後のvalidation_dataで過学習が起こらないように検証を行っています。

最後にテストデータで予測を実行して実測値と予測値の正解率を求めます!

最終的な結果は・・・・96.96%!!

そこそこな精度をたたき出すことができました。パラメータをいじることで精度を99%まで伸ばしてみてください!

最後にまとめてコードを載せておきます。

以下の記事で詳しくまとめていますのでチェックしてみてください!

【入門】ディープラーニングとは?仕組みとPythonでの実装を見ていこう!当サイト【スタビジ】の本記事では、ディープラーニングの仕組みやPythonでの実装方法について解説していきます。ディープラーニングってなんとなくブラックボックスなイメージがあるかもしれませんが、実はシンプルなアルゴリズムなんですよー!...

Xgboost

続いて勾配ブースティング手法の定番「Xgboost

アンサンブル学習決定木を組み合わせた手法で非常に高い汎化能力を誇ります。

Xgboostでは、アンサンブル学習の中でもブースティングを用いています。

バギングは並列学習なのですが、ブースティングは直列で学習していくイメージです。

前期に上手く学習できなかったら誤差を目的変数にして次の学習を行います。

コードをみていきましょう!

精度・・0.976!!非常に高い!

以下の記事で詳しくまとめています!

XGboostとは?理論とPythonとRでの実践方法!当ブログ【スタビジ】の本記事では、機械学習手法の中でも非常に有用で様々なコンペで良く用いられるXgboostについてまとめていきたいと思います。最後にはRで他の機械学習手法と精度比較を行っているのでぜひ参考にしてみてください。...

LightGBM

続いてXgboostを改良した手法として発表されたLightGBM

データコンペでもよく使われる手法です。

Xgboostを含む通常の決定木モデルは以下のように階層を合わせて学習していきます。

それをLevel-wiseと呼びます。

level-wise学習法
(引用元:Light GBM公式リファレンス

一方Light GBMは以下のように葉ごとの学習を行います。これをleaf-wise法と呼びます。

leaf-wise学習法
(引用元:Light GBM公式リファレンス

これにより、ムダな学習をしなくても済むためより効率的に学習を進めることができます。

Light GBMの解到達スピードが速いゆえんはここにあるのです。

では実装していきましょう!

結果的に0.972という推定精度が得られました!

結果はXgboostに負けていますが、処理時間ではXgboostの8分の1ほどで終了しています。

LightGBMに関しては以下の記事で詳しくまとめています!

Light GBM
Light GBMの仕組みとPythonでの実装を見ていこう!当ブログ【スタビジ】の本記事では、最強の機械学習手法「Light GBM」についてまとめていきます。Light GBMは決定木と勾配ブースティングを組み合わせた手法で、Xgboostよりも計算負荷が軽い手法として注目を集めています。...

Catboost

Catboostは、「Category Boosting」の略であり2017年にYandex社から発表された機械学習ライブラリ。

発表時期としてはLightGBMよりも若干後になっています。

ポイントは2つ。

・カテゴリカル変数(質的変数)の扱い方が上手いよ
・決定木のツリー構造を最適にして過学習を防ぐよ

まあ、要はデータセットによるけどXgboostやLightGBMよりも精度が高くなる可能性があるよってことですね。

では、実装していきましょう!

推定精度は0.9565!

Catboostについては以下の記事でまとめています!

Catboost
Catboostとは?XgboostやLightGBMとの違いとPythonでの実装方法を見ていこうー!!当サイト【スタビジ】の本記事では、XgboostやLightGBMに代わる新たな勾配ブースティング手法「Catboost」について徹底的に解説していき最終的にPythonにてMnistの分類モデルを構築していきます。LightGBMやディープラーニングとの精度差はいかに!?...

結果は・・・以下のようになりました!

CNNXgboostLightGBMCatboost
96.96%97.60%97.20%95.65%

絶妙な差ですが、Xgboostが最も高くCatboostが最も低いという結果に!

ぜひ色んな手法でMnistを分類して精度比較してみてください!

Pythonについて学びたい方はぜひ以下の記事を参考にしてみてくださいね!

Python 勉強
【入門】初心者が3か月でPythonを習得できるようになる勉強法!当ブログ【スタビジ】の本記事では、Pythonを効率よく独学で習得する勉強法を具体的なコード付き実装例と合わせてまとめていきます。Pythonはできることが幅広いので自分のやりたいことを明確にして勉強法を選ぶことが大事です。...

Mnistデータまとめ

Mnistデータの特徴とPythonでの分類実装を見てきました。

ロボたん
ロボたん
Mnistデータってこんなに簡単に使えるんだね!
ウマたん
ウマたん
このレベルのデータセットがフリーで使えるのは本当にありがたいよねー!

Mnistの特徴をもう一度まとめておきましょう!

・0~9の手書き数字がまとめられたデータセット
・6万枚の訓練データ用(画像とラベル)
・1万枚のテストデータ用(画像とラベル)
・白「0」~黒「255」の256段階
・幅28×高さ28フィールド

今回精度比較に用いたディープラーニングをはじめとする機械学習手法については以下の記事でまとめています。

【初心者向け】ディープラーニングの学習ロードマップまとめ当サイト【スタビジ】本記事では、ディープラーニングの学習方法について詳しくまとめていきます!ディープラーニングは難しいと思われがちですが、アルゴリズムは意外とシンプルで実装自体も非常に簡単なんです!Pythonでの実装もおこなっていきますよー!...
機械学習
機械学習入門に必要な知識と勉強方法をPythonとRの実装と一緒に見ていこう!当サイト【スタビジ】の本記事では、入門者向けに機械学習についてカンタンにまとめていきます。最終的にはどのように機械学習を学んでいけばよいかも見ていきます。細かい手法の実装もPython/Rを用いておこなっていくので適宜参考にしてみてください。...

ちなみにディープラーニングは機械学習手法の1つだということを覚えておきましょう!

Pythonを初学者が最短で習得する勉強法

Pythonを使うと様々なことができます。しかしどんなことをやりたいかという明確な目的がないと勉強は捗りません。

Pythonを習得するためのロードマップをまとめましたのでぜひチェックしてみてくださいね!