ほぼ週刊ハマりどころメモ

筆者が画像認識(CV)と自然言語処理(NLP)を研究する中でハマった点を共有することで、世の研究者から余計な時間が奪われることを防げたらいいなぁ...

Kerasで花画像の名前を当てる~検出編~[Keras][CNN][深層学習][Python]

vastee.hatenablog.com

の続き

今回は学習したモデルを使って花の名前を当てるプログラムを作成する.

import numpy as np

from keras.preprocessing.image import load_img, img_to_array
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Dropout, Activation, Flatten
from keras.optimizers import SGD

classes = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
nb_classes = len(classes)
img_rows, img_cols = 224, 224

def build_model() :
    input_tensor = Input(shape=(img_rows, img_cols, 3))
    vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)

    _model = Sequential()

    _model.add(Flatten(input_shape=vgg16.output_shape[1:]))
    _model.add(Dense(256, activation='relu'))
    _model.add(Dropout(0.5))
    _model.add(Dense(nb_classes, activation='softmax'))

    model = Model(inputs=vgg16.input, outputs=_model(vgg16.output))
    
    for layer in model.layers[:15]:
        layer.trainable = False

    model.compile(loss='categorical_crossentropy',
                  optimizer=SGD(lr=1e-4, momentum=0.9), metrics=['accuracy'])
    return model

if __name__ == "__main__":
    model = build_model()
    # 学習時に使ったモデルを使用
    model.load_weights("./flower-model.hdf5")

    filename = "./daisy.jpg"

    # kerasに画像を読み込む関数がある
    img = load_img(filename, target_size=(img_rows, img_cols))
    x = img_to_array(img)
    x = np.expand_dims(x, axis=0)
    
    predict = model.predict(preprocess_input(x))

    for pre in predict:
        y = pre.argmax()
        print("Name: ", classes[y])

f:id:Vastee:20180924224034j:plain

daisy.jpg

>> Name: daisy

ちゃんと合ってた.