【PyTorch】転移学習やってみた【マラリア細胞】

こんにちは、のっくんです。

今日は深層学習フレームワークPyTorchで転移学習のやり方をご紹介します。

使用するデータセットは、マラリアの細胞データを使います。

細胞データには2パターンあり、感染しているものと、感染していないものがデータセットに含まれています。

このデータはkaggleに公開されているものですので、誰でも入手可能です。

データセットの詳しい内容は以下の記事に記載してあります。

マラリアに感染しているかディープラーニングで判別してみた

転移学習って何?

そもそも転移学習って何かと言うと、すでに学習済みのモデルとパラメータを使って深層学習する方法です。

普通の学習では、学習することで重みなどのパラメータが調整されていきますが、転移学習では事前に調整されたパラメータをそのまま(凍結させて)使います。

事前に調整されたモデルやパラメータをそのまま使用することを、転移学習と呼ぶわけですね。

データの読み込み

データの読み込みには、ImageFolderを使うと便利です。

ImageFolderを使用するためにディレクトリ構成を整える必要があります。

以下のように訓練用と検証用にフォルダを分けて、その中に感染している細胞、感染していない細胞のフォルダを作成する感じです。

元のデータでは2万枚ほどありましたが、その中から訓練用に1000枚、テスト用に500枚をコピーします。

コピーする際にはリスト内包表記を使って、ファイル名をリストに格納しました。

データ拡張するために、transformsというパッケージを使用します

訓練画像の場合には、

  • RandomResizedCrop, ランダムな位置をランダムなサイズで切り取って、指定したサイズにリサイズします。
  • RandomHorizontalFlip,50%の確率で水平方向反転します。
  • Normalize, 正規化します。

テスト画像の場合には、

  • Resize, 指定したサイズにリサイズします。
  • CenterCrop, 指定したサイズ分中心を切り抜きます。
  • Normalize, 正規化します。

上記コードではカレントディレクトリにtrainとvalというフォルダがあることが前提で、ImageFolderを使って画像を読み込みます。

転移学習

PyTorchでは、Alexnet、VGG、ResNet、SqueezeNet、Inception v3などの代表的なネットワークが使えます。

ここではAlexnetを使って転移学習をしてみます。pretrained=Trueにすることで学習済みの重みを使用します。

classifierの6番目(最終層)を見ると、out_featuresが1000になっています。

これは最終的に1000個のクラスに分類するように学習されているということです。

今回は2クラス分類なのでこれを2に変更し、最終層以外のパラメータを凍結させます。

個人的な感想ですが、kerasよりもPyTorchの方がこの辺りのパラメータの変更が直感的で分かりやすいと思います。

これだけで転移学習の準備が完了です。

学習

細かい部分は省略しますが、学習時に学習率を変更するようにlr_schedulerを使用します。

学習結果をプロットしてみました。

検証精度を見ると、最初の方は不安定ですが最後の方は安定して9割を超えました。

参考

ABOUTこの記事をかいた人

のっくん

理系院卒で大企業の研究所に就職。 趣味はプログラミング、レアジョブ英会話、筋トレ、旅行。 Twitter:@yamagablog