データ解析

Catboostとは?XgboostやLightGBMとの違いとPythonでの実装方法を見ていこうー!!

Catboost
記事内に商品プロモーションを含む場合があります
ウマたん
ウマたん
当サイト【スタビジ】の本記事では、XgboostやLightGBMに代わる新たな勾配ブースティング手法「Catboost」について徹底的に解説していき最終的にPythonにてMnistの分類モデルを構築していきます。LightGBMやディープラーニングとの精度差はいかに!?

こんにちは!

消費財メーカーでデジタルマーケター・データサイエンティストをやっているウマたん(@statistics1012)です!

Xgboostに代わる手法としてLightGBMが登場し、さらにCatboostという手法が2017年に登場いたしました。

これらは弱学習器である決定木を勾配ブースティングによりアンサンブル学習した非常に強力な機械学習手法群。

計算負荷もそれほど重くなく非常に高い精度が期待できるため、Kaggleなどのデータ分析コンペや実務シーンなど様々な場面で頻繁に使用されているのです。

ロボたん
ロボたん
最新のアルゴリズムがどんどん登場するけど、勾配ブースティング×決定木の組み合わせであることは変わらないんだね!
ウマたん
ウマたん
そうなんだよー!それだけ勾配ブースティング×決定木の組み合わせが強いということだね!

この記事では、そんなCatboostの概要について見ていき、PythonでMnistの画像データを分類するモデルを作っていきます。

LightGBMディープラーニングとの精度比較も行いますよー!

機械学習・ディープラーニングをまず勉強したい方は当メディアが運営する「スタアカ(スタビジアカデミー)」の以下のコースをチェックしてみて下さい!

ウマたん
ウマたん
Pythonの勉強は以下の記事をチェック!
Python独学勉強法
【Python独学勉強法】Python入門を3ヶ月で習得できる学習ロードマップ当サイト【スタビジ】の本記事では、過去僕自身がPythonを独学を駆使しながら習得した経験をもとにPythonを効率よく勉強する方法を具体的なコード付き実装例と合わせてまとめていきます。Pythonはできることが幅広いので自分のやりたいことを明確にして勉強法を選ぶことが大事です。...

Catboostとは?XgboostやLightGBMとの違い

勾配ブースティング

Catboostは、「Category Boosting」の略であり2017年にYandex社から発表された機械学習ライブラリ。

発表時期としてはLightGBMよりも若干後になっています。

Yandex社はロシアのGoogle。

ロシアの検索エンジン市場の過半数を占有しており、自動車産業にも事業展開している大手IT企業なんです。

ロボたん
ロボたん
へー!中国以外にも検索市場でGoogleが天下を取っていない国があるんだねー!
ウマたん
ウマたん
韓国はNaverが圧倒的だし、国産会社が検索市場でトップシェアを取っている国はまだまだ多いね!

さて、そんなYandex社が発表した「Category Boosting」ですが他の勾配ブースティング手法と比較してどのような特徴があるのでしょうか?

「Category Boosting」を発表した論文を見てみると以下のように書いてあります。

In this paper we present a new gradient boosting algorithm that successfully handles categorical features and takes advantage of dealing with them during training as opposed to preprocessing time.
Another advantage of the algorithm is that it uses a new schema for calculating leaf values when selecting the tree structure, which helps to reduce overfitting.
(引用元:”CatBoost: gradient boosting with categorical features
support”)

ここで言っているのは2つ。

・カテゴリカル変数(質的変数)の扱い方が上手いよ
・決定木のツリー構造を最適にして過学習を防ぐよ

まあ、要はデータセットによるけどXgboostやLightGBMよりも精度が高くなる可能性があるよってことですね。

実際に論文の中でもいくつかのデータセットに各手法を適用させてLogloss値で比較をしています。

Catboost performance
(引用元:”CatBoost: gradient boosting with categorical features support”)

実際にCatboostが最も良い精度をたたき出しているのが分かると思います。

ちなみにCatboostの公式HPに色々とまとめてありますが、大規模データセットのモデル構築も予測も比較的早い速度で出来るよ、とも書いてあります。

CatboostをPythonで実装してみよう!

実際にそんなCatboostをPythonで実装していきましょう!

今回は、Mnistというよく画像認識に使われるデータセットを用いていきます。

このデータセットには0~9の文字が様々なカタチで書かれたサンプルが入っています。

LightGBMディープラーニング(CNN)と比較していきますよー!

ちなみに機械学習モデルを個人で構築するならGoogle colaboratoryがオススメです!

永久無料でGPUが使えるので!

Google Colaboratory
Google Colaboratoryのメリットと使い方!GPU環境でPython回すならこれだ!当サイト【スタビジ】の本記事では、Googleが無償で提供する機械学習のプラットフォーム「Google Colaboratory」をメリット・デメリット・使い方について見ていきます!実際にPythonを実行していきGPUの威力を見ていきます。...

CatboostについてもGoogle colaboratoryを使ってモデル構築していきます!

CatboostでMnistを分類

Catboostは、その名の通りcatboostという名前のライブラリから使用することが可能なんです。

前準備として必要なライブラリを読み込んでいきましょう!

import pandas as pd
import numpy as np
from tensorflow.keras.datasets import mnist
from sklearn.model_selection import train_test_split

続いて、今回分類を行うMnistデータを加工。

# Kerasに付属の手書き数字画像データをダウンロード
np.random.seed(0)
(X_train_base, labels_train_base), (X_test, labels_test) = mnist.load_data()

# Training set を学習データ(X_train, labels_train)と検証データ(X_validation, labels_validation)に8:2で分割する
X_train,X_validation,labels_train,labels_validation = train_test_split(X_train_base,labels_train_base,test_size = 0.2)

# 各画像は行列なので1次元に変換→X_train,X_validation,X_testを上書き
X_train = X_train.reshape(-1,784)
X_validation = X_validation.reshape(-1,784)
X_test = X_test.reshape(-1,784)

#正規化
X_train = X_train.astype('float32')
X_validation = X_validation.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_validation /= 255
X_test /= 255

Catboostにインプットするためのデータを用意します。

# 訓練・テストデータの設定
from catboost import Pool

train_pool = Pool(X_train, labels_train)
validate_pool = Pool(X_validation, labels_validation)

パラメータを設定して・・

#経過時間計測
import time
start = time.time()

from catboost import CatBoostClassifier
params = {
    'early_stopping_rounds' : 10,
    'iterations' : 100, 
    'custom_loss' :['Accuracy'], 
    'random_seed' :42
}

モデル構築と予測を行っていきます。

この時処理速度も同時に測っていきます。

# パラメータを指定した場合は、以下のようにインスタンスに適用させる
model = CatBoostClassifier(**params)
cab = model.fit(train_pool, eval_set=validate_pool)
preds = cab.predict(X_test)
from sklearn.metrics import accuracy_score
print('accuracy_score:{}'.format(accuracy_score(labels_test, preds)))

#経過時間
print('elapsed_timetime:{}'.format(time.time()-start))

推定精度は・・・

Catboost Mnist

推定精度は0.9565!

非常に高い精度を得ることができました。

最新の勾配ブースティング手法「Catboost」がこんなにもカンタンに実装できちゃうんです!

全コードをまとめておきましょう!

LightGBMとディープラーニング(CNN)でMnistを分類

同様にLightGBMを使用してMnistを分類した結果は・・・0.972!

Light gbm Mnist

結果はLight gbmに軍配があがりました。

処理速度もCatboostが負けている・・・

やはりLight gbmつよし!ですねー

LightGBMの全コードは以下です。

詳しくは以下の記事でまとめています!

Light GBM
【図解で解説】LightGBMの仕組みとPythonでの実装を見ていこう!当サイト【スタビジ】の本記事では、最強の機械学習手法「LightGBM」についてまとめていきます。LightGBM の特徴とPythonにおける回帰タスクと分類タスクの実装をしていきます。LightGBMは決定木と勾配ブースティングを組み合わせた手法で、Xgboostよりも計算負荷が軽い手法であり非常によく使われています。...

ちなみにディープラーニング(CNN)で回した時の推定精度は、0.9696!

以下がディープラーニングのコードになります。

詳しくは以下の記事でまとめています!

【入門】ディープラーニング(深層学習)の仕組みとPython実装のやり方!当サイト【スタビジ】の本記事では、ディープラーニングの仕組みやPythonでの実装方法について解説していきます。ディープラーニングってなんとなくブラックボックスなイメージがあるかもしれませんが、実はシンプルなアルゴリズムなんですよー!...

もう少し真面目にパラメータチューニングを行えば順序が逆転することは多いにあります。

ぜひ色々と試してみてください!

Catboost まとめ

Catboostについて見てきましたー!

最後にカンタンにCatboostの特徴をまとめておきましょう!

・カテゴリカル変数(質的変数)の扱い方が上手いよ
・決定木のツリー構造を最適にして過学習を防ぐよ
・計算負荷が低いよ

Pythonで容易に実装できるので、ぜひ使ってみてください。

XgboostやLightGBMより高い精度を得られるかもしれませんよー!

ウマたん
ウマたん
興味がある人は、カテゴリカル変数を含むデータセットで試してみよう!

ちなみにPythonの学習方法については以下の記事でまとめていますのでよければチェックしてみてください!

Python独学勉強法
【Python独学勉強法】Python入門を3ヶ月で習得できる学習ロードマップ当サイト【スタビジ】の本記事では、過去僕自身がPythonを独学を駆使しながら習得した経験をもとにPythonを効率よく勉強する方法を具体的なコード付き実装例と合わせてまとめていきます。Pythonはできることが幅広いので自分のやりたいことを明確にして勉強法を選ぶことが大事です。...

よりさらに深く統計学や機械学習を学びたい人のために勉強法を以下にまとめていますのでぜひチェックしてみてください!

統計学入門に必要な知識と独学勉強方法を簡単に学ぼう!当ブログ【スタビジ】の本記事では、統計学入門に必要な知識をカンタンにまとめ、それらをどのように効率的に独学で勉強していけばよいかをお話ししていきます。統計学は難しいイメージが少しありますが、学び方をしっかり考えれば大丈夫!...
機械学習
【入門】機械学習のアルゴリズム・手法をPythonとRの実装と一緒に5分で解説!当サイト【スタビジ】の本記事では、入門者向けに機械学習についてカンタンにまとめていきます。最終的にはどのように機械学習を学んでいけばよいかも見ていきます。細かい手法の実装もPython/Rを用いておこなっていくので適宜参考にしてみてください。...

また、当メディアではデータサイエンティストになるための分野を体系的に学ぶスクール「スタアカ(スタビジアカデミー)」を運営しておりますので、データサイエンス全般に興味のある方は是非チェックしてみてください!

AIデータサイエンス特化スクール「スタアカ」

スタアカトップ
【価格】ライトプラン:980円/月
プレミアムプラン:98,000円
【オススメ度】
【サポート体制】
【受講形式】オンライン形式
【学習範囲】データサイエンスを網羅的に学ぶ
実践的なビジネスフレームワークを学ぶ
SQLとPythonを組み合わせて実データを使った様々なワークを行う
マーケティングの実行プラン策定
マーケティングとデータ分析の掛け合わせで集客マネタイズ

データサイエンティストとしての自分の経験をふまえてエッセンスを詰め込んだのがこちらのスタビジアカデミー、略して「スタアカ」!!

24時間以内の質問対応と現役データサイエンティストによる複数回のメンタリングを実施します!

カリキュラム自体は、他のスクールと比較して圧倒的に良い自信があるのでぜひ受講してみてください!

他のスクールのカリキュラムはPythonでの機械学習実装だけに焦点が当たっているものが多く、実務に即した内容になっていないものが多いです。

そんな課題感に対して、実務で使うことの多いSQLや機械学習のビジネス導入プロセスの理解なども合わせて学べるボリューム満点のコースになっています!

Pythonが初めての人でも学べるようなカリキュラムしておりますので是非チェックしてみてください!

ウォルマートのデータを使って商品の予測分析をしたり、実務で使うことの多いGoogleプロダクトのBigQueryを使って投球分析をしたり、データサイエンティストに必要なビジネス・マーケティングの基礎を学んでマーケティングプランを作ってもらったり・Webサイト構築してデータ基盤構築してWebマーケ×データ分析実践してもらったりする盛りだくさんの内容になってます!

・BigQuery上でSQL、Google Colab上でPythonを使い野球の投球分析
・世界最大手小売企業のウォルマートの実データを用いた需要予測
・ビジネス・マーケティングの基礎を学んで実際の企業を題材にしたマーケティングプランの策定
・Webサイト構築してデータ基盤構築してWebマーケ×データ分析実践して稼ぐ

スタビジアカデミーでデータサイエンスをさらに深く学ぼう!

スタアカサービスバナースタビジのコンテンツをさらに深堀りしたコンテンツが動画と一緒に学べるスクールです。

プレミアムプランでは私がマンツーマンで伴走させていただきます!ご受講お待ちしております!

スタビジアカデミーはこちら