Kerasの処理テンプレート

メモ

  • モデルのcompile時に与えるmetricsは、学習の各エポック毎に計算する学習の指標を表すもの。損失関数は何もやらなくても計算しているので、損失関数以外を指定する。自分で関数を作っても良いが、大概は用意されている。良くあるサンプルではaccuracyが指定されているが、これは分類問題では損失関数がクロスエントロピーなのに対して、実際の正解率を計算してくるもの。
  • kerasの終了時に、「tensorflowがNoneはdelは無い」みたいなエラーを出すときは、バックエンドのセッションのクリアをちゃんと呼ぶようにする。tfが別スレッドで動いていて、tfが終了する前に親プロセスのkerasが終了してしまう、みたいな状況なのかな?
  • scikit-learnとの融合で複雑な交差検証は出来るけれど、シンプルにやるだけならfitにvalidation_splitを指定する。
  • fitの返り値はエポック毎のlossとmetricsの値を保存している。返り値をretとすると、ret.history['loss']などでアクセスできる。validation_splitを指定していればret.history['val_loss']も保存されている。
  • 結果を保存するとき、モデルはmodel.saveで良い。モデルをjsonで、重みは別途model.save_weightsで保存する例がドキュメントに書いてあるけれど、何でだろう?モデルの構成だけ保存したい(重みは大きいから不要)とかいう状況あるのかな?。また、fitの返り値はpythonオブジェクトなのでpickleで保存する。(その際、返り値自体を保存しようとするとなぜかエラーが起きる。なんでだろう?。とりあえず、historyだけなら保存出来た。)
import numpy
import matplotlib.pyplot as plt
import pickle
import keras.backend as K
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping

def gen_model():
    model = Sequential()
    model.add(Dense(12, input_dim=8, activation='relu'))
    model.add(Dense(6, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    return model

def plot_results(history):
    plt.title('learning history')
    plt.xlabel('epochs')
    plt.ylabel('loss or accuracy')
    plt.plot(history.history['loss'], label='loss')
    plt.plot(history.history['val_loss'], label='val_loss')
    plt.grid()
    plt.legend()

# Data preparetion
dataset = numpy.loadtxt("pima_indians_diabetes.csv", delimiter=",")
X = dataset[:, 0:8]
Y = dataset[:, 8]

# Model generation
model = gen_model()
model.summary()

# Model fitting with validation
early_stopping = EarlyStopping(patience=3)
history = model.fit(X, Y, batch_size=16, epochs=100, callbacks=[early_stopping],validation_split=0.1)

# Save results
model.save("model.hd5")
with open('learning_history.pkl', 'wb') as f:
    pickle.dump(history.history, f)

# plot loss and validation loss
plot_results(history)

K.clear_session()