【TensorFlow】アヤメの分類をする

 

こんにちは。のっくん(@yamagablog)です。

 

今日の記事では、TensorFlowを使ったアヤメの分類にチャレンジしたいと思います。

 

「TensorFlowの使い方がよく分からない」

 

そんな方に読んでいただければと思います。

 

スポンサーリンク

 

使用するデータセット

 

以下のような花びらの長さとアヤメの種類が入ったcsv形式のデータを使用します。

 

 

表示は6個だけですが、全部で150個あります。

 

このデータは前回の記事でも使用していまして、ダウンロード方法は以下のページをみてください。

 

https://ymgsapo.com/classify-flower/

 

入力データとラベルの作成

 

入力データは花の種類を推測するのに使用するデータです。花びらやガクの長さを含めた4つです。

 

ラベルというのはアヤメの種類です。アヤメの種類は3種類でしたね。

 

pandasを使って、入力データとラベルの列を取り出してみます。

 

 

loc[:,”Name”]でアヤメの種類だけ取り出しています。第一引数に、「:」を指定することで全部の行を取り出しています。ちなみに、「:9」とかにすると9行だけ取り出せます。

 

ラベルをOne-Hotベクトルにする

 

One-Hotベクトルは1つだけ1であとは0のベクトルです。[0,0,1]とかですね。

 

TensorFlowを使う場合、ラベルをこの形式にしなければいけません。

 

変換用の辞書を作って、for文で回して変換します。

 

 

実行してみると、、

 

 

変換できました。

 

学習用とテスト用に分離

 

入力データ(x)とラベル(y)を学習用に8割、テスト用に2割になるようにそれぞれ分割します。

 

 

ちなみにこのコードを実行すると以下のようにワーニングが出ますが、x_trainをみてみるとちゃんと分割されてたので気にしないことにします。

 

 

おそらく、「(test_sizeとtrain_size)が両方記述されていないと、test_sizeからtrain_sizeを補完するぞ」と言っているようです。

学習とテストの実行

 

学習アルゴリズムを記載して、学習とテストを実行してみます。

 

 

いきなりコードが長くなってしまいました。

 

学習アルゴリズムは難しいので全て理解する必要はありません。

 

クロスエントロピーはクラス分類でよく使われる誤差関数の1つで、アダム方は確率的勾配降下法の1つです。

 

あまり深く考えずに、最初のx、y_、重み、バイアスの次元を設定できればそれで良いかと思います。

 

実行すると以下のような値が出ました。

 

 

ネコ
コードは複雑なのに、機械学習と精度ほとんど変わらないじゃん!

 

確かにそうですねw

 

ちなみに機械学習でやってみた記事は以下にありますので良かったらどうぞ。

 

https://ymgsapo.com/classify-flower/

 

プレースホルダーとか知りたい方は以下の記事もどうぞ。

 

https://ymgsapo.com/tensorflow-beginner/

 

以上です、お疲れ様でした。

 

参考