【Keras】転移学習とファインチューニング【犬猫判別4】

こんにちは。のっくんです。

前回の記事では特徴量抽出の方法を学びました。

【Keras】特徴量の抽出【犬猫判別3】

今回はモデルの拡張とファインチューニングをしていきたいと思います。

スポンサーリンク

やりたいこと

 

今回は畳み込みベース(VGG16)にオリジナルの全結合分類器を接続して新しいカスタムネットワークを作り学習するようにします。(左の図)

この時に畳み込みベースは凍結、つまり学習によってパラメータ(重み)の更新を行わないようにします。

VGG16が壊れないようにして出力だけを利用する感じですね。

 

 

さらに畳み込みベースの一部を解凍して訓練することで精度の向上を狙います。これはファインチューニングと呼ばれる方法です。(右の図)

モデルを全部訓練するのはよろしくないので、一部だけ解凍します。

理由としては、

  • パラメータ数が多いので学習するのは大変
  • モデルの下の方がより具体的な特徴を含んでいて、そこをチューニングした方が精度が高まりやすい

とのこと。

モデルの拡張(VGG16+全結合分類器)

 

VGG16モデルをダウンロードして構造を出力します。

 

 

19個のレイヤがあり、訓練可能なパラメータは1470万ほどあることが確認できます。

 

次に、全結合層を追加します。

 

 

出力を見ると、vgg16を含む全てのパラメータが訓練可能になっています。

畳み込みベースであるvgg16を凍結します。

 

 

vgg16のパラメータが訓練不可能になりました。

 

ジェネレータでデータを作っていきます。バリデーションデータは水増ししないこと、バッチサイズは20であることに注意です。

 

学習します。エポックは30です。

 

学習過程を可視化します。

 

バリデーションの精度は90.50%でした。なかなか良い精度。

モデルのファインチューニング

 

畳み込みベースの一部を解凍して学習するファインチューニングをします。

再度VGG16のネットワークを表示してみます。

 

ここで19層あるうちのblock5_conv1からblock5_conv3までを訓練可能にします。

具体的には以下の通り。

  • 一度全ての畳み込みベースを解凍し、訓練可能にする。
  • 畳み込み層の5ブロック目より上を訓練不可能な状態にする。
  • リストのスライスでは、0から始まり終点(15)を含まないことに注意。
  • 学習率は低めに設定。

 

こちらも可視化します。

 

バリデーションの精度は94.40%で前回より向上しました!

おわり。

続きは以下の記事です。

【keras】InceptionResNetV2の転移学習【犬猫判別5】

参考

ABOUTこの記事をかいた人

のっくん

理系院卒で大企業の研究所に就職。 趣味はプログラミング、レアジョブ英会話、筋トレ、旅行。 Twitter:@yamagablog