機械学習

U-Netの構造を解説!Pythonで実装して使い方を見ていこう!

U-Net
記事内に商品プロモーションを含む場合があります
ウマたん
ウマたん
当サイト【スタビジ】の本記事では、セマンティックセグメンテーションの有用なアプローチの1つU-Netについて解説します。U-Netの特徴を理解しPythonでのモデル構築方法も一緒に見ていきましょう!

こんにちは!データサイエンティストのウマたん(@statistics1012)です!

この記事ではセマンティックセグメンテーションのアプローチとして有名なU-Netについて解説していきます。

ディープラーニングを使いこなす上で非常に重要なアーキテクチャですのでしっかり理解しておきましょう!

以下のYoutube動画でも解説していますので合わせてチェックしてみてください!

U-Netの特徴

まずは、U-Netの特徴を見ていきましょう!

U-Netは2015年に医療用画像のセグメンテーションアプローチとして提案されました。

論文は以下になります。

この論文に掲載されているU-Netのアーキテクチャは以下の通りです。

U-Net
ウマたん
ウマたん
このアーキテクチャがアルファベットのUに見えることからU-Netという名前がつけられているんだ!

このアーキテクチャを見ると、左側で徐々にダウンサンプリングが行われ右側で徐々にアップサンプリングが行われていることが分かります。

U-net

左側のアーキテクチャをエンコーダ、右側のアーキテクチャをデコーダと呼びます。

そして特徴的なのが左側のエンコーダから右側のデコーダへの灰色の矢印であるSkip接続。

U-Net

これは左側のエンコーダアーキテクチャでダウンサンプリングする中で抽出された特徴をそのまま後続のデコーダに直接渡す役割を担っています。

これにより、画像の細かい特徴と抽象的な特徴を捉えることができ、より正確なセグメンテーションを行えるようになっています。

元々医療用画像のセグメンテーションで利用されることを想定して提案されたU-Netですが、その汎用性から一般的な画像における画像変換や画像生成のアーキテクチャのベースとして利用されるようになりました。

U-NetのPython実装方法

それでは、U-NetのPython実装方法を見てきましょう!

ここでは、U-Netのアーキテクチャを作り学習させるためのコードを見ていきたいと思います。

Kerasを使ってU-Netのアーキテクチャを作っていきます。

コード全体は以下のようになります。

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate

def conv_block(input_tensor, num_filters):
    x = Conv2D(num_filters, (3, 3), padding='same')(input_tensor)
    x = tf.keras.layers.Activation('relu')(x)
    x = Conv2D(num_filters, (3, 3), padding='same')(x)
    x = tf.keras.layers.Activation('relu')(x)
    return x

def encoder_block(input_tensor, num_filters):
    x = conv_block(input_tensor, num_filters)
    p = MaxPooling2D((2, 2))(x)
    return x, p

def decoder_block(input_tensor, concat_tensor, num_filters):
    x = UpSampling2D((2, 2))(input_tensor)
    x = concatenate([x, concat_tensor], axis=-1)
    x = conv_block(x, num_filters)
    return x

def unet(input_shape, num_filters_start=32, num_classes=1):
    inputs = Input(input_shape)

    # Encoder
    x1, p1 = encoder_block(inputs, num_filters_start)
    x2, p2 = encoder_block(p1, num_filters_start * 2)
    x3, p3 = encoder_block(p2, num_filters_start * 4)
    x4, p4 = encoder_block(p3, num_filters_start * 8)

    # Bridge
    bridge = conv_block(p4, num_filters_start * 16)

    # Decoder
    d1 = decoder_block(bridge, x4, num_filters_start * 8)
    d2 = decoder_block(d1, x3, num_filters_start * 4)
    d3 = decoder_block(d2, x2, num_filters_start * 2)
    d4 = decoder_block(d3, x1, num_filters_start)

    # Output
    outputs = Conv2D(num_classes, (1, 1), activation='sigmoid')(d4)

    model = tf.keras.Model(inputs, outputs)
    return model

 

これがU-Netのアーキテクチャになります。

関数conv_blockで畳み込み処理を定義していて、encoder_blockでは畳み込み処理をおこなっており、decoder_blockではアップサンプリング処理をおこなっているのが分かると思います。

そしてメインのunet関数の中では、encoder_blockを通して4回畳み込み処理でダウンサンプリングを行い、その後に4回decoder_blockを通してアップサンプリングをしています。

以下の関数内でSkip接続を実現しています。

def encoder_block(input_tensor, num_filters):
    x = conv_block(input_tensor, num_filters)
    p = MaxPooling2D((2, 2))(x)
    return x, p

def decoder_block(input_tensor, concat_tensor, num_filters):
    x = UpSampling2D((2, 2))(input_tensor)
    x = concatenate([x, concat_tensor], axis=-1)
    x = conv_block(x, num_filters)
    return x

それぞれの関数の引き数を見てみると、p1はp2、p3、p4と順番に次の層にわたっていきブリッジ関数を通して、デコーダーのd1,d2,d3と連携されています。

一方で、x1は最後のdecoder_blockの引き数になっていて、ちゃんと先ほどお話ししたアーキテクチャ通りスキップ接続されていることが分かります。

encoder_blockでエンコーダ部分からxを出力して、Skip接続しデコーダ部分でconcatenate関数を使って、エンコーダ部分からの出力(concat_tensor)とデコーダ部分でアップサンプリングされた特徴マップ(input_tensor)を結合しています!

このコードはU-Netのモデルアーキテクチャを作っただけなので、手元の画像データで学習するためには別途コードを書く必要があります。

例えば以下のように書くことでU-Netの学習を行うことができます。

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# データジェネレータの設定
train_datagen = ImageDataGenerator(rescale=1./255)
mask_datagen = ImageDataGenerator(rescale=1./255)

# トレーニング用の画像とマスクのデータジェネレータを作成
train_generator = train_datagen.flow_from_directory(
    'path_to_train_images',
    target_size=(128, 128),
    batch_size=32,
    class_mode=None)

mask_generator = mask_datagen.flow_from_directory(
    'path_to_train_masks',
    target_size=(128, 128),
    batch_size=32,
    class_mode=None,
    color_mode='grayscale')  # セグメンテーションマスクはグレースケール

# 画像とマスクを組み合わせる
train_generator = zip(train_generator, mask_generator)

# モデルのコンパイル
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# モデルのトレーニング
model.fit(train_generator, steps_per_epoch=100, epochs=10)

path_to_train_imagesとpath_to_train_masksにはそれぞれ画像データとセグメンテーション後の各ピクセルのラベリングデータが入ったディレクトリを指定します。

そして学習が完了したらモデルに対して新しい画像をインプットしてセグメンテーションを実施していくことになります。

ちなみにU-NetではないのですがResNetをベースに学習したセマンティックセグメンテーション用の事前学習モデルはPyTorchに存在します。

そちらを使ったセグメンテーション実装方法は以下の記事で解説していますので興味のある方はチェックしてみてください。

セマンティックセグメンテーション
【入門】セマンティックセグメンテーションをPythonで実装してみよう!当サイト【スタビジ】の本記事では、セマンティックセグメンテーションについて解説していきます。セマンティックセグメンテーションの特徴を見ていき、最終的にPythonで実装し画像にセマンティックセグメンテーションをかけていきます。...

U-Net まとめ

今回はU-Netについて解説してきました。

U-NetはAIの進化の歴史の中でも非常に重要な技術です。ぜひ理解しておきましょう!

ディープラーニングの様々なモデルを知りたい方は以下の記事を参考にしてみてください。

また、さらに詳しくAIやデータサイエンスの勉強がしたい!という方は当サイト「スタビジ」が提供するスタビジアカデミーというサービスで体系的に学ぶことが可能ですので是非参考にしてみてください!

AIデータサイエンス特化スクール「スタアカ」

スタアカトップ
【価格】ライトプラン:1280円/月
プレミアムプラン:149,800円
【オススメ度】
【サポート体制】
【受講形式】オンライン形式
【学習範囲】データサイエンスを網羅的に学ぶ
実践的なビジネスフレームワークを学ぶ
SQLとPythonを組み合わせて実データを使った様々なワークを行う
マーケティングの実行プラン策定
マーケティングとデータ分析の掛け合わせで集客マネタイズ

データサイエンティストとしての自分の経験をふまえてエッセンスを詰め込んだのがこちらのスタビジアカデミー、略して「スタアカ」!!

当メディアが運営するスクールです。

24時間以内の質問対応と現役データサイエンティストによる複数回のメンタリングを実施します!

カリキュラム自体は、他のスクールと比較して圧倒的に良い自信があるのでぜひ受講してみてください!

他のスクールのカリキュラムはPythonでの機械学習実装だけに焦点が当たっているものが多く、実務に即した内容になっていないものが多いです。

そんな課題感に対して、実務で使うことの多いSQLや機械学習のビジネス導入プロセスの理解なども合わせて学べるボリューム満点のコースになっています!

Pythonが初めての人でも学べるようなカリキュラムしておりますので是非チェックしてみてください!

ウォルマートのデータを使って商品の予測分析をしたり、実務で使うことの多いGoogleプロダクトのBigQueryを使って投球分析をしたり、データサイエンティストに必要なビジネス・マーケティングの基礎を学んでマーケティングプランを作ってもらったり・Webサイト構築してデータ基盤構築してWebマーケ×データ分析実践してもらったりする盛りだくさんの内容になってます!

・BigQuery上でSQL、Google Colab上でPythonを使い野球の投球分析
・世界最大手小売企業のウォルマートの実データを用いた需要予測
・ビジネス・マーケティングの基礎を学んで実際の企業を題材にしたマーケティングプランの策定
・Webサイト構築してデータ基盤構築してWebマーケ×データ分析実践して稼ぐ

スタビジアカデミーでデータサイエンスをさらに深く学ぼう!

スタアカサービスバナースタビジのコンテンツをさらに深堀りしたコンテンツが動画と一緒に学べるスクールです。

プレミアムプランでは私がマンツーマンで伴走させていただきます!ご受講お待ちしております!

スタビジアカデミーはこちら