【Keras】cifar10の画像分類をやってみた

 

ディープラーニング用の画像データセットcifar10を使って、画像の分類にチャレンジしてみました。

 

 

cifar10とは

 

飛行機や車、カエルなどを含む10種類の画像データセットです。kerasには、データセットをダウンロードする関数があります。

 

from keras.datasets import cifar10
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

 

コードを実行すると、トロント大学からデータセットをダウンロードしてきます。私の環境では2分ほどかかりました。

 

%matplotlib inline
import matplotlib.pyplot as plt
from PIL import Image

print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

plt.figure(figsize=(10,10))
labels = ["airplane","automobile","bird","cat","deer","dog","frog","horse",
         "ship","truck"]

for i in range(0,40):
    im = Image.fromarray(X_train[i])
    plt.subplot(5,8,i+1)
    plt.title(labels[y_train[i][0]])
    # x軸をオフ
    plt.tick_params(labelbottom="off",bottom="off")
    # y軸をオフ
    plt.tick_params(labelleft="off",left="off")
    plt.imshow(im)
    
plt.show()

 

 

解像度は32*32のカラー画像で、学習用に5万枚、テスト用に1万枚あることがわかりました。

 

MLPを使って分類

 

ディープラーニングの1つである多層パーセプトロン(MLP)のアルゴリズムを使って分類してみます。

 

import matplotlib.pyplot as plt
import keras
from keras.datasets import cifar10
from keras.models import Sequential
from keras.layers import Dense, Dropout

num_classes = 10
im_rows = 32
im_cols = 32
im_size = im_rows * im_cols * 3

# データを読み込む
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

# データを一次元配列に変換 
X_train = X_train.reshape(-1, im_size).astype('float32') / 255
X_test = X_test.reshape(-1, im_size).astype('float32') / 255
# ラベルデータをOne-Hot形式に変換
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# モデルを定義
model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(im_size,)))
model.add(Dense(num_classes, activation='softmax'))

# モデルをコンパイル
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'])

# 学習を実行
hist = model.fit(X_train, y_train,
    batch_size=32, epochs=50,
    verbose=1,
    validation_data=(X_test, y_test))

# モデルを評価
score = model.evaluate(X_test, y_test, verbose=1)
print('正解率=', score[1], 'loss=', score[0])

 

コードを実行すると20分程度かかりました。

 

正解率= 0.4853 loss= 1.4944498865127565

 

適当に分類したとすると10%(1/10)なのでそれよりはソコソコ分類できているようです。

 

CNNを使って分類

 

ディープラーニングの1つであるCNN(畳み込みニューラルネットワーク)を使って分類してみます。

 

import keras
from keras.datasets import cifar10
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D

num_classes = 10
im_rows = 32
im_cols = 32
in_shape = (im_rows, im_cols, 3)

(X_train, y_train), (X_test, y_test) = cifar10.load_data()

X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255

y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=in_shape))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'])

hist = model.fit(X_train, y_train,
    batch_size=32, epochs=50,
    verbose=1,
    validation_data=(X_test, y_test))

score = model.evaluate(X_test, y_test, verbose=1)
print('正解率=', score[1], 'loss=', score[0])

 

正解率= 0.7952 loss= 0.6938282356500626

 

MLPに比べてかなり精度が上がっています。

 

ワンポイントアドバイス
CNNを動かす場合には、かなり時間がかかります。そういう場合は、Jupyter Notebook上ではなく、SSHでUbuntuにログインしコマンドで実行するのをオススメします。Ubuntuで実行する場合、screenコマンドでセッションを作りそこでPythonファイルを実行しておけば好きなときに結果を確認できます。

 

2019/2/20追記:

CPUを使った場合の実行時間は250分、Geforce GTX1060を使って動かした時の実行時間は、12.5分でした。速度が約20倍違いますので、GPUを使って動かすのをオススメします。

参考

ABOUTこの記事をかいた人

個人アプリ開発者。Python、Swift、Unityのことを発信します。月間2.5万PVブログ運営。 Twitter:@yamagablog