機械学習

【入門】セマンティックセグメンテーションをPythonで実装してみよう!

セマンティックセグメンテーション
記事内に商品プロモーションを含む場合があります
ウマたん
ウマたん
当サイト【スタビジ】の本記事では、セマンティックセグメンテーションについて解説していきます。セマンティックセグメンテーションの特徴を見ていき、最終的にPythonで実装し画像にセマンティックセグメンテーションをかけていきます。

こんにちは!データサイエンティストのウマたん(@statistics1012)です!

この記事ではセマンティックセグメンテーション(SemanticSegmentation)について解説していきます。

ディープラーニングブームのはじまりとともに盛んに研究されるようになったセマンティックセグメンテーションを理解してPythonで実装できるようになっておきましょう!

以下の動画でセマンティックセグメンテーションについて解説していますのであわせてチェックしてみてください!

セマンティックセグメンテーションとは

まずは、セマンティックセグメンテーションの特徴を見ていきましょう!

セマンティックセグメンテーションは画像の対象物を特定のカテゴリごとに分類するアプローチです。

以下のように画像を対象物ごとに色分けすることができます。

semanticsegmentation

(出典:“SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation”

様々な画像を車なら車、建物なら建物、イスならイスというように同じカテゴリの対象物は同じ色で分類して色分けされているのが分かります。

segnet

セマンティックセグメンテーションは、AlexNetが2012年に登場しディープラーニングブームに突入し自動運転AIの研究が盛んになりはじめた2015年ごろからホットな領域になりました。

2015年にセマンティックセグメンテーションの手法として登場したSegNetは以下のようなアーキテクチャになっており、畳み込みニューラルネットワークをベースにしたエンコード・デコード構造になっています。

segnet

(出典:“SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation”

また、同じく2015年に登場したU-Netは以下のようなアーキテクチャになっています。

U-Net

(出典:”U-Net: Convolutional Networks for Biomedical Image Segmentation“)

ウマたん
ウマたん
ちなみに車や建物など同じカテゴリであったとしても別対象物として別色で色分けするインスタントセグメンテーションという手法もあるよ!

セマンティックセグメンテーションをPythonで実装してみよう!

「モノは試し!」ということで、早速セマンティックセグメンテーションを実装していきましょう!

今回はディープラーニングを実装するためのライブラリPyTorchを使っていきます。

PyTorchには多くのモデルが用意されているのですが、その中からResNetをベースに学習されたdeeplabv3_resnet50という学習済みモデルを使っていきます。

まずはコード全体を確認し、細かく順を追ってみていきます。

コード全体は以下のようになります。

import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt

# モデルをロード
model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True)
model.eval()

# 画像をロードして前処理
def preprocess_image(img_path):
    input_image = Image.open(img_path)
    preprocess = T.Compose([
        T.ToTensor()
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # バッチ次元を追加
    return input_batch

# セグメンテーションの実行
def segment_image(img_path):
    input_batch = preprocess_image(img_path)
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        output = model(input_batch)['out'][0]
    output_predictions = output.argmax(0)
    return output_predictions

# 結果を表示
def display_segmentation(img_path, output_predictions):
    input_image = Image.open(img_path)
    plt.imshow(input_image)
    # テンソルをCPUに移動させる
    output_predictions_cpu = output_predictions.cpu()
    plt.imshow(output_predictions_cpu, alpha=0.7)
    plt.axis('off')
    plt.show()

# 画像パス
img_path = '/content/mypic.jpg'  # ここに画像のパスを入力

# セグメンテーションの実行と表示
output_predictions = segment_image(img_path)
display_segmentation(img_path, output_predictions)

 

どんな処理をしているのか順を追って見ていきましょう!

必要なライブラリと学習済みモデルのインポート

まずは、必要なライブラリとモデルをインポートしていきます。

ここで、学習済みモデルのdeeplabv3_resnet50をロードしています。

import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt

# モデルをロード
model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True)
model.eval()

 

前処理の関数を定義

続いて、画像に前処理を施すための関数を定義します。

# 画像をロードして前処理
def preprocess_image(img_path):
    input_image = Image.open(img_path)
    preprocess = T.Compose([
        T.ToTensor()
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # バッチ次元を追加
    return input_batch

ここでは、PyTorchで扱うために通常の画像をテンソル型に変換しています(T.ToTensor)。

その上でバッチの次元を追加しています。

このバッチ次元どういうことでしょう?

画像は通常、縦×横×チャネル(RGB)の3次元で表現されますが、ディープラーニングは複数の画像で処理するようにバッチの次元が追加されて処理されるのが一般的です。

今回は画像一枚一枚に対してセマンティックセグメンテーションをかける処理なのですが、次元をあわせるためにバッチ用の次元を追加して4次元にしているのです!

セマンティックセグメンテーションの実行

続いてセマンティックセグメンテーションを実行する部分です。

# セグメンテーションの実行
def segment_image(img_path):
    input_batch = preprocess_image(img_path)
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        output = model(input_batch)['out'][0]
    output_predictions = output.argmax(0)
    return output_predictions

ここがまさにセマンティックセグメンテーションを実装している部分。

まず最初に先ほど定義した関数を使って画像に前処理をかけています。

その後、NvidiaのGPU環境であるcudaを使うように設定し、モデルに画像を投入してアウトプットを得ています。

torch.no_grad()では、ディープラーニングで最適パラメータを求める際に必要な勾配を初期化しておりPyTorchでこのように表記します。

また、output = model(input_batch)[‘out’][0]で得られるアウトプットは、[どのクラスに属するかを判断する確率のようなもの、画像の縦ピクセル、画像の横ピクセル]になっています。

以下のようなイメージ。

セマンティックセグメンテーション

そのためoutput_predictions = output.argmax(0)によって各ピクセルにおいてどのクラスに属するかを算出して2次元配列にしています。

元画像とセマンティックセグメンテーション後の結果表示

ここまでくれば後は結果を表示するだけ!

# 結果を表示
def display_segmentation(img_path, output_predictions):
    input_image = Image.open(img_path)
    plt.imshow(input_image)
    # テンソルをCPUに移動させる
    output_predictions_cpu = output_predictions.cpu()
    plt.imshow(output_predictions_cpu, alpha=0.7)
    plt.axis('off')
    plt.show()

# 画像パス
img_path = '/content/mypic.jpg'  # ここに画像のパスを入力

# セグメンテーションの実行と表示
output_predictions = segment_image(img_path)
display_segmentation(img_path, output_predictions)

最終的に元の画像にセマンティックセグメンテーション後のクラス判別した結果をあわせて表示しています。

実際に画像をインプットして出力結果を見ていこう!

実際に特定の画像にセマンティックセグメンテーションをかけた結果を見ていきましょう!

まずAI時代の天才経営者OpenAIのサムアルトマン。

サムアルトマン

以下のようになりました。

サムアルトマン

ちゃんと人を分類できているのが分かります。

複数の人が移っている画像だとどうでしょうか?

若干女性の腕が背景と同化してしまっていますが、以下のように人とそれ以外を分類することができました。

こんな感じで対象物を特定のクラスごとに分類できるのです!

セマンティックセグメンテーション まとめ

今回はセマンティックセグメンテーションについて解説してきました。

セマンティックセグメンテーションはAI時代に重要な技術です。

ディープラーニングの様々なモデルを知りたい方は以下の記事を参考にしてみてください。

また、さらに詳しくAIやデータサイエンスの勉強がしたい!という方は当サイト「スタビジ」が提供するスタビジアカデミーというサービスで体系的に学ぶことが可能ですので是非参考にしてみてください!

AIデータサイエンス特化スクール「スタアカ」

スタアカトップ
【価格】ライトプラン:1280円/月
プレミアムプラン:149,800円
【オススメ度】
【サポート体制】
【受講形式】オンライン形式
【学習範囲】データサイエンスを網羅的に学ぶ
実践的なビジネスフレームワークを学ぶ
SQLとPythonを組み合わせて実データを使った様々なワークを行う
マーケティングの実行プラン策定
マーケティングとデータ分析の掛け合わせで集客マネタイズ

データサイエンティストとしての自分の経験をふまえてエッセンスを詰め込んだのがこちらのスタビジアカデミー、略して「スタアカ」!!

当メディアが運営するスクールです。

24時間以内の質問対応と現役データサイエンティストによる複数回のメンタリングを実施します!

カリキュラム自体は、他のスクールと比較して圧倒的に良い自信があるのでぜひ受講してみてください!

他のスクールのカリキュラムはPythonでの機械学習実装だけに焦点が当たっているものが多く、実務に即した内容になっていないものが多いです。

そんな課題感に対して、実務で使うことの多いSQLや機械学習のビジネス導入プロセスの理解なども合わせて学べるボリューム満点のコースになっています!

Pythonが初めての人でも学べるようなカリキュラムしておりますので是非チェックしてみてください!

ウォルマートのデータを使って商品の予測分析をしたり、実務で使うことの多いGoogleプロダクトのBigQueryを使って投球分析をしたり、データサイエンティストに必要なビジネス・マーケティングの基礎を学んでマーケティングプランを作ってもらったり・Webサイト構築してデータ基盤構築してWebマーケ×データ分析実践してもらったりする盛りだくさんの内容になってます!

・BigQuery上でSQL、Google Colab上でPythonを使い野球の投球分析
・世界最大手小売企業のウォルマートの実データを用いた需要予測
・ビジネス・マーケティングの基礎を学んで実際の企業を題材にしたマーケティングプランの策定
・Webサイト構築してデータ基盤構築してWebマーケ×データ分析実践して稼ぐ

スタビジアカデミーでデータサイエンスをさらに深く学ぼう!

スタアカサービスバナースタビジのコンテンツをさらに深堀りしたコンテンツが動画と一緒に学べるスクールです。

プレミアムプランでは私がマンツーマンで伴走させていただきます!ご受講お待ちしております!

スタビジアカデミーはこちら