ユニファ Developer Blog

A blog written by members of the System Development Department at ユニファ株式会社.

Keras Functional API

By Matthew Millar, R&D Scientist at ユニファ

What is Keras functional API?

Most people are used to the Sequential model in Keras, since it is a straightforward way to create simple models. The Functional API is Keras's way of creating far more complex models. It allows models with multiple inputs and outputs, different types of inputs, merged inputs, multiple loss functions, and more.
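To make the "multiple outputs, multiple losses" point concrete, here is a minimal sketch of a single-input model with two output heads, each trained with its own loss. It assumes tf.keras (TensorFlow 2.x) rather than the standalone Keras imports used later in this post; the layer names, sizes, and random data are hypothetical placeholders.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical sizes: 32 input features, 3 classes
inputs = keras.Input(shape=(32,), name="features")
hidden = layers.Dense(64, activation="relu")(inputs)

# Two heads branching off the same hidden layer
class_out = layers.Dense(3, activation="softmax", name="class_out")(hidden)
value_out = layers.Dense(1, name="value_out")(hidden)

model = keras.Model(inputs=inputs, outputs=[class_out, value_out])

# One loss per output, listed in the same order as `outputs`;
# loss_weights controls how much each loss contributes to the total
model.compile(
    optimizer="adam",
    loss=["categorical_crossentropy", "mse"],
    loss_weights=[1.0, 0.5],
)

# Random placeholder data, only to show the expected shapes
rng = np.random.default_rng(0)
x = rng.random((8, 32), dtype=np.float32)
y_class = keras.utils.to_categorical(rng.integers(0, 3, size=8), 3)
y_value = rng.random((8, 1), dtype=np.float32)

history = model.fit(x, [y_class, y_value], epochs=1, batch_size=4, verbose=0)
class_pred, value_pred = model.predict(x, verbose=0)
```

Keras sums the two weighted losses into a single training objective, so one `fit` call trains both heads together.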

Code Comparison:

So, let’s look at the most basic model possible. The MNIST dataset is already included in Keras, is available to everyone, and should need no introduction, so I will skip the setup, data loading, and train/test split and go straight to the model. The code below is a basic Sequential model that learns to recognize handwritten digits. This code sample comes from the Keras team's GitHub [1].
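Since the setup is skipped, here is a minimal sketch of what it might look like, defining the names the model code relies on (`input_shape`, `num_classes`, `batch_size`, `epochs`, and the train/test arrays). It assumes tf.keras and a channels-last image layout; `batch_size = 128` is assumed from the Keras MNIST CNN example, while `epochs = 12` matches the training logs below. The helper names `preprocess` and `load_mnist` are hypothetical.

```python
import numpy as np
from tensorflow import keras

# Assumed settings (epochs=12 matches the training logs in this post)
num_classes = 10
batch_size = 128
epochs = 12
img_rows, img_cols = 28, 28
input_shape = (img_rows, img_cols, 1)  # channels-last layout

def preprocess(x, y):
    """Add a channel axis, scale pixels to [0, 1], and one-hot encode labels."""
    x = x.reshape(x.shape[0], img_rows, img_cols, 1).astype("float32") / 255.0
    y = keras.utils.to_categorical(y, num_classes)
    return x, y

def load_mnist():
    """Download MNIST (network access needed on first call) and preprocess it."""
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    x_train, y_train = preprocess(x_train, y_train)
    x_test, y_test = preprocess(x_test, y_test)
    return x_train, y_train, x_test, y_test
```

Usage: `x_train, y_train, x_test, y_test = load_mnist()` gives arrays of shape `(60000, 28, 28, 1)` and `(10000, 28, 28, 1)`, with one-hot labels ready for `categorical_crossentropy`.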

import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))

Easy, right? Now we can build the same model with the Keras Functional API. Compared side by side, the two are very similar, but you no longer need to create a Sequential model.
First, we will need to import a few more modules:

from keras.layers import Input, Dense
from keras.models import Model

Input and Model are the two extra pieces the Functional API needs.
Next, define the input. In the Sequential model, the input shape is passed to the first layer; in the Functional API, the input gets its own Input layer, and each layer is then called on the output of the layer before it:

# Sequential way
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
# Which is equivalent to
# Functional API
inputs = Input(shape=input_shape)
# Define the Conv2D layer and call it on the input tensor
x = Conv2D(32, kernel_size=(3, 3), activation='relu')(inputs)

The following layers are the same as in the Sequential model; only the calling style changes, with each layer called on the previous layer's output. The other difference is at the end, where you define the output layer and then create the Model from its inputs and outputs:

predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=inputs, outputs=predictions)

The last layer (predictions) is the same as the final fully connected layer in the Sequential model.
So you should end up with something that looks like this:

# Define the Input layer, reusing the original input_shape
inputs = Input(shape=input_shape)
# Each layer is called on the output of the previous one
x = Conv2D(32, kernel_size=(3, 3), activation='relu')(inputs)
x = Conv2D(64, kernel_size=(3, 3), activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Dropout(0.25)(x)
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.5)(x)
predictions = Dense(num_classes, activation='softmax')(x)

# This creates a model containing the Input layer, two Conv2D layers,
# max pooling, dropout, and two Dense layers (same as the Sequential model)
functional_model = Model(inputs=inputs, outputs=predictions)
functional_model.compile(loss=keras.losses.categorical_crossentropy,
                         optimizer=keras.optimizers.Adadelta(),
                         metrics=['accuracy'])
functional_model.fit(x_train, y_train,
                     batch_size=batch_size,
                     epochs=epochs,
                     verbose=1,
                     validation_data=(x_test, y_test))
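One quick way to confirm that the two definitions describe the same network is to build both and compare their parameter counts. This is a minimal sketch assuming tf.keras and a 28x28x1 input; `build_sequential` and `build_functional` are hypothetical helper names. By hand, the layers contribute 320 + 18,496 + 1,179,776 + 1,290 = 1,199,882 parameters.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

input_shape = (28, 28, 1)
num_classes = 10

def build_sequential():
    """The post's Sequential model."""
    return keras.Sequential([
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.25),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ])

def build_functional():
    """The same architecture written with the Functional API."""
    inputs = keras.Input(shape=input_shape)
    x = layers.Conv2D(32, kernel_size=(3, 3), activation="relu")(inputs)
    x = layers.Conv2D(64, (3, 3), activation="relu")(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Dropout(0.25)(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128, activation="relu")(x)
    x = layers.Dropout(0.5)(x)
    predictions = layers.Dense(num_classes, activation="softmax")(x)
    return keras.Model(inputs=inputs, outputs=predictions)

seq_model = build_sequential()
func_model = build_functional()
print(seq_model.count_params(), func_model.count_params())
```

Both models should report the same count, since they contain identical layers in the same order.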

Results:

As the training logs and test scores below show, the two methods produce essentially the same results. The added advantage of the Functional API model is that it is easier to extend and far more customizable. For more complex tasks, the Functional API may be mandatory: a Sequential model can only express a single linear stack of layers, so it cannot represent multiple inputs, multiple outputs, or branching architectures.
Now, what is the point, you may ask? One of the biggest benefits is that the model defined above can itself be used as a layer in another model, like so:

x = Input(shape=(input_shape))
pred = functional_model(x)

This produces classification results for any input passed in. It could be used, for example, to help classify frames from a video feed, or as one component of a more complex model that needs multiple types of input.
Training behaves the same for both models, and they yield similar results:
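A minimal sketch of a model used as a layer, assuming tf.keras. To keep it fast, a tiny hypothetical `classifier` stands in for `functional_model`. The same classifier instance is called on two separate inputs, so both branches share one set of weights, and their outputs are concatenated into a single tensor.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical stand-in for functional_model: a tiny 10-class classifier
clf_in = keras.Input(shape=(28, 28, 1))
clf_out = layers.Dense(10, activation="softmax")(layers.Flatten()(clf_in))
classifier = keras.Model(clf_in, clf_out, name="classifier")

# Call the whole model like a layer on two new inputs (weights are shared)
left = keras.Input(shape=(28, 28, 1), name="left")
right = keras.Input(shape=(28, 28, 1), name="right")
pair = layers.Concatenate()([classifier(left), classifier(right)])
pair_model = keras.Model(inputs=[left, right], outputs=pair)

rng = np.random.default_rng(0)
a = rng.random((2, 28, 28, 1), dtype=np.float32)
b = rng.random((2, 28, 28, 1), dtype=np.float32)
out = pair_model.predict([a, b], verbose=0)
```

Because the branches share weights, `pair_model` has exactly as many parameters as `classifier`, and the first ten columns of `out` match what `classifier` predicts on `a` alone.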

Sequential Model Training.
Train on 60000 samples, validate on 10000 samples
Epoch 1/12
60000/60000 [==============================] - 211s 4ms/step - loss: 0.2604 - acc: 0.9208 - val_loss: 0.0589 - val_acc: 0.9797
Epoch 2/12
60000/60000 [==============================] - 203s 3ms/step - loss: 0.0870 - acc: 0.9746 - val_loss: 0.0395 - val_acc: 0.9868
Epoch 3/12
60000/60000 [==============================] - 202s 3ms/step - loss: 0.0648 - acc: 0.9800 - val_loss: 0.0374 - val_acc: 0.9879
Epoch 4/12
60000/60000 [==============================] - 201s 3ms/step - loss: 0.0541 - acc: 0.9837 - val_loss: 0.0395 - val_acc: 0.9868
Epoch 5/12
60000/60000 [==============================] - 203s 3ms/step - loss: 0.0465 - acc: 0.9857 - val_loss: 0.0275 - val_acc: 0.9907
Epoch 6/12
60000/60000 [==============================] - 206s 3ms/step - loss: 0.0407 - acc: 0.9879 - val_loss: 0.0288 - val_acc: 0.9900
Epoch 7/12
60000/60000 [==============================] - 203s 3ms/step - loss: 0.0381 - acc: 0.9887 - val_loss: 0.0258 - val_acc: 0.9925
Epoch 8/12
60000/60000 [==============================] - 212s 4ms/step - loss: 0.0337 - acc: 0.9897 - val_loss: 0.0298 - val_acc: 0.9900
Epoch 9/12
60000/60000 [==============================] - 211s 4ms/step - loss: 0.0311 - acc: 0.9901 - val_loss: 0.0257 - val_acc: 0.9927
Epoch 10/12
60000/60000 [==============================] - 211s 4ms/step - loss: 0.0290 - acc: 0.9909 - val_loss: 0.0264 - val_acc: 0.9918
Epoch 11/12
60000/60000 [==============================] - 206s 3ms/step - loss: 0.0271 - acc: 0.9916 - val_loss: 0.0254 - val_acc: 0.9922
Epoch 12/12
60000/60000 [==============================] - 201s 3ms/step - loss: 0.0265 - acc: 0.9918 - val_loss: 0.0278 - val_acc: 0.9920
Functional API Training
Train on 60000 samples, validate on 10000 samples
Epoch 1/12
60000/60000 [==============================] - 213s 4ms/step - loss: 0.2768 - acc: 0.9142 - val_loss: 0.0583 - val_acc: 0.9812
Epoch 2/12
60000/60000 [==============================] - 205s 3ms/step - loss: 0.0947 - acc: 0.9721 - val_loss: 0.0477 - val_acc: 0.9842
Epoch 3/12
60000/60000 [==============================] - 202s 3ms/step - loss: 0.0696 - acc: 0.9802 - val_loss: 0.0363 - val_acc: 0.9883
Epoch 4/12
60000/60000 [==============================] - 203s 3ms/step - loss: 0.0566 - acc: 0.9831 - val_loss: 0.0319 - val_acc: 0.9893
Epoch 5/12
60000/60000 [==============================] - 201s 3ms/step - loss: 0.0495 - acc: 0.9854 - val_loss: 0.0331 - val_acc: 0.9892
Epoch 6/12
60000/60000 [==============================] - 202s 3ms/step - loss: 0.0432 - acc: 0.9864 - val_loss: 0.0293 - val_acc: 0.9904
Epoch 7/12
60000/60000 [==============================] - 205s 3ms/step - loss: 0.0393 - acc: 0.9879 - val_loss: 0.0284 - val_acc: 0.9903
Epoch 8/12
60000/60000 [==============================] - 196s 3ms/step - loss: 0.0341 - acc: 0.9893 - val_loss: 0.0273 - val_acc: 0.9916
Epoch 9/12
60000/60000 [==============================] - 202s 3ms/step - loss: 0.0319 - acc: 0.9900 - val_loss: 0.0249 - val_acc: 0.9919
Epoch 10/12
60000/60000 [==============================] - 210s 3ms/step - loss: 0.0297 - acc: 0.9904 - val_loss: 0.0324 - val_acc: 0.9898
Epoch 11/12
60000/60000 [==============================] - 212s 4ms/step - loss: 0.0285 - acc: 0.9911 - val_loss: 0.0248 - val_acc: 0.9922
Epoch 12/12
60000/60000 [==============================] - 209s 3ms/step - loss: 0.0272 - acc: 0.9915 - val_loss: 0.0283 - val_acc: 0.9921

And the final test results are essentially the same as well.

Sequential
Test loss: 0.027761173594164575
Test accuracy: 0.992
Functional
Test loss: 0.028270527327229955
Test accuracy: 0.9921
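Scores like these are presumably produced with `model.evaluate`, as in the Keras MNIST CNN example, which returns the loss followed by each compiled metric. The sketch below assumes tf.keras and uses a tiny stand-in model and random data (both hypothetical) purely so that it runs quickly; the printed numbers are therefore meaningless, only the mechanics matter.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Random placeholder test set shaped like MNIST
rng = np.random.default_rng(0)
x_test = rng.random((32, 28, 28, 1), dtype=np.float32)
y_test = keras.utils.to_categorical(rng.integers(0, 10, size=32), 10)

# Tiny stand-in model (hypothetical), compiled with an accuracy metric
model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    layers.Flatten(),
    layers.Dense(10, activation="softmax"),
])
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

# evaluate returns [loss, accuracy] because one metric was compiled
score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])
```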

A Better Example! Image Similarity:

This model uses a pre-trained ResNet50 model to create the feature vectors used for image comparison. Features are computed for each image and then merged into one input for the fully connected layers. Honestly, though, any CNN will work; you can even define your own CNN and use it to extract features. The final layer produces a probability that the two images are similar, and applying a threshold to that probability gives the similar/not-similar decision. The model is too simple for very complex comparisons, but for images of scenery it should give satisfactory results.
The basic model for image similarity can be built like this:

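from keras.applications import resnet50
from keras.layers import Input, Lambda, Dense, BatchNormalization, Dropout, Activation, Flatten
from keras.models import Model
from keras import backend as K

# absdiff and absdiff_output_shape are not shown in the original post;
# these are assumed definitions computing the element-wise absolute difference.
def absdiff(vectors):
    x, y = vectors
    return K.abs(x - y)

def absdiff_output_shape(shapes):
    shape1, shape2 = shapes
    return shape1

input_shape = (224, 224, 3)
base_network = resnet50.ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)

input_1 = Input(shape=input_shape)
input_2 = Input(shape=input_shape)

# The same base_network instance processes both images, so its weights are shared
vector_1 = base_network(input_1)
vector_2 = base_network(input_2)

# Get the distance between images
merged = Lambda(absdiff, output_shape=absdiff_output_shape)([vector_1, vector_2])

fc1 = Dense(1024)(merged)
fc1 = BatchNormalization()(fc1)
fc1 = Dropout(0.4)(fc1)
fc1 = Activation("relu")(fc1)

fc2 = Dense(2048)(fc1)
fc2 = BatchNormalization()(fc2)
fc2 = Dropout(0.4)(fc2)
fc2 = Activation("relu")(fc2)

fc3 = Dense(4096)(fc2)
fc3 = BatchNormalization()(fc3)
fc3 = Dropout(0.3)(fc3)
fc3 = Activation("relu")(fc3)

fc4 = Dense(4096)(fc3)
fc4 = Activation("relu")(fc4)

fc5 = Flatten()(fc4)
pred = Dense(2, kernel_initializer="glorot_uniform")(fc5)
pred = Activation("sigmoid", name="A_2")(pred)

model = Model(inputs=[input_1, input_2], outputs=pred)

model.compile(optimizer='adam', loss="binary_crossentropy", metrics=["accuracy"])
NUM_EPOCHS = 10
# train_gen, val_gen, num_train_steps and num_val_steps come from the data
# pipeline, which is not shown here
history = model.fit_generator(train_gen,
                              steps_per_epoch=num_train_steps,
                              epochs=NUM_EPOCHS,
                              validation_data=val_gen,
                              validation_steps=num_val_steps,
                              verbose=1)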
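As noted above, any CNN can serve as the feature extractor. The sketch below assumes tf.keras and swaps ResNet50 for a small hypothetical CNN on 64x64 images, so it runs without downloading weights. It keeps the post's structure (one shared encoder, an absolute-difference merge via `Lambda`) but, as a plainly stated substitution, ends in a single sigmoid unit rather than the post's two-unit head, and adds a hypothetical `is_similar` helper that applies the threshold.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

input_shape = (64, 64, 3)  # small images so the sketch runs quickly (assumption)

def build_encoder():
    """A small hypothetical CNN used instead of ResNet50 as the feature extractor."""
    inp = keras.Input(shape=input_shape)
    x = layers.Conv2D(16, 3, activation="relu")(inp)
    x = layers.MaxPooling2D()(x)
    x = layers.Conv2D(32, 3, activation="relu")(x)
    x = layers.GlobalAveragePooling2D()(x)  # one 32-dim vector per image
    return keras.Model(inp, x, name="encoder")

encoder = build_encoder()
input_1 = keras.Input(shape=input_shape)
input_2 = keras.Input(shape=input_shape)

# Same encoder instance on both inputs, so the weights are shared
vector_1 = encoder(input_1)
vector_2 = encoder(input_2)

# Element-wise absolute difference, playing the role of absdiff above
merged = layers.Lambda(lambda v: tf.abs(v[0] - v[1]))([vector_1, vector_2])

x = layers.Dense(64, activation="relu")(merged)
# Single sigmoid unit: probability that the pair is similar
similarity = layers.Dense(1, activation="sigmoid")(x)

siamese = keras.Model(inputs=[input_1, input_2], outputs=similarity)
siamese.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

def is_similar(img_a, img_b, threshold=0.5):
    """Threshold the predicted probability to get a yes/no decision per pair."""
    prob = siamese.predict([img_a, img_b], verbose=0)
    return prob[:, 0] > threshold
```

Because the absolute difference is symmetric, swapping the two images gives the same prediction, which is a useful property for a similarity model.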

Conclusion:

Now can you see the usefulness of the Functional API in Keras? This is just the tip of the iceberg of what can be accomplished with this API.
The API is not limited to images; it can be used to define any complex model with multiple inputs and outputs, for example in natural language processing, or even in complex analysis of the stock market, where numerical and non-numerical data are used in the same model.
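As a sketch of mixing non-numerical and numerical data, the model below takes integer token IDs (for example, a tokenized headline) and a vector of numerical features (for example, recent prices), processes each in its own branch, and merges them for a single regression output. It assumes tf.keras; the vocabulary size, sequence length, feature count, and layer sizes are all hypothetical.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

vocab_size, seq_len, num_numeric = 1000, 20, 5  # hypothetical sizes

# Input 1: integer token IDs (non-numerical data, e.g. a headline)
tokens_in = keras.Input(shape=(seq_len,), dtype="int32", name="tokens")
t = layers.Embedding(vocab_size, 16)(tokens_in)
t = layers.GlobalAveragePooling1D()(t)  # average the word vectors

# Input 2: numerical features (e.g. recent prices)
numeric_in = keras.Input(shape=(num_numeric,), name="numeric")
n = layers.Dense(16, activation="relu")(numeric_in)

# Merge the two branches and predict a single value
merged = layers.Concatenate()([t, n])
output = layers.Dense(1, name="prediction")(merged)

model = keras.Model(inputs=[tokens_in, numeric_in], outputs=output)
model.compile(optimizer="adam", loss="mse")

rng = np.random.default_rng(0)
tokens = rng.integers(0, vocab_size, size=(3, seq_len), dtype=np.int32)
numbers = rng.random((3, num_numeric), dtype=np.float32)
preds = model.predict([tokens, numbers], verbose=0)
```

Each branch can be made as deep as the data requires; the Functional API only needs the branches' output tensors to meet at the `Concatenate` layer.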