In addition to sequential models and models created with the functional API, you can also define models by writing a custom call() (forward pass) operation.

To create a custom Keras model, you call the keras_model_custom() function, passing it an R function which in turn returns another R function that implements the custom call() (forward pass) operation. The R function you pass takes a single argument (the model under construction, named self in the example below), which provides access to the underlying Keras model object should you need it.

Typically, you'll wrap your call to keras_model_custom() in yet another function that enables callers to easily instantiate your custom model.
This example demonstrates a simple custom model: a multi-layer perceptron with optional dropout and batch normalization:
library(keras)

keras_model_simple_mlp <- function(num_classes,
                                   use_bn = FALSE, use_dp = FALSE,
                                   name = NULL) {

  # define and return a custom model
  keras_model_custom(name = name, function(self) {

    # create layers we'll need for the call (this code executes once)
    self$dense1 <- layer_dense(units = 32, activation = "relu")
    self$dense2 <- layer_dense(units = num_classes, activation = "softmax")
    if (use_dp)
      self$dp <- layer_dropout(rate = 0.5)
    if (use_bn)
      self$bn <- layer_batch_normalization(axis = -1)

    # implement call (this code executes during training & inference)
    function(inputs, mask = NULL, training = FALSE) {
      x <- self$dense1(inputs)
      if (use_dp)
        x <- self$dp(x)
      if (use_bn)
        x <- self$bn(x)
      self$dense2(x)
    }
  })
}
Note that we include a name parameter so that users can optionally provide a human-readable name for the model.

Note also that when we create layers to be used in our forward pass we assign them to the self object so they are tracked appropriately by Keras.
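For instance, because the layers are assigned to self, their weights are registered with the model once it has been built. The quick check below is an assumed usage sketch, not part of the original example (the dummy batch simply triggers layer construction):

model <- keras_model_simple_mlp(num_classes = 10)
model(k_constant(matrix(0, nrow = 1, ncol = 100)))  # build the model on a dummy batch
summary(model)  # dense1 and dense2 appear with their parameters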
In call(), you can specify custom losses by calling self$add_loss(). You can also access any other members of the Keras model you need (or even add fields to the model) by using self$.
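As a minimal sketch of this (the wrapper name keras_model_mlp_with_penalty is hypothetical and not part of the original text), a custom model could add an activity regularization penalty from within call() using self$add_loss() together with the backend helpers k_mean() and k_abs():

keras_model_mlp_with_penalty <- function(num_classes, name = NULL) {
  keras_model_custom(name = name, function(self) {
    self$dense1 <- layer_dense(units = 32, activation = "relu")
    self$dense2 <- layer_dense(units = num_classes, activation = "softmax")
    function(inputs, mask = NULL, training = FALSE) {
      x <- self$dense1(inputs)
      # penalize large hidden activations; the penalty is collected
      # automatically when the model is compiled and fit
      self$add_loss(0.01 * k_mean(k_abs(x)))
      self$dense2(x)
    }
  })
}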
To use a custom model, just call your model’s high-level wrapper function. For example:
library(keras)

# create the model
model <- keras_model_simple_mlp(num_classes = 10, use_dp = TRUE)

# compile graph
model %>% compile(
  loss = 'categorical_crossentropy',
  optimizer = optimizer_rmsprop(),
  metrics = c('accuracy')
)

# Generate dummy data
data <- matrix(runif(1000 * 100), nrow = 1000, ncol = 100)
labels <- matrix(round(runif(1000, min = 0, max = 9)), nrow = 1000, ncol = 1)

# Convert labels to categorical one-hot encoding
one_hot_labels <- to_categorical(labels, num_classes = 10)

# Train the model
model %>% fit(data, one_hot_labels, epochs = 10, batch_size = 32)
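Once trained, the custom model behaves like any other Keras model. As a brief follow-up sketch (assumed usage, not part of the original example), you can evaluate it and generate predictions:

# evaluate on the dummy data and predict for a few examples
model %>% evaluate(data, one_hot_labels, verbose = 0)
predictions <- model %>% predict(data[1:5, , drop = FALSE])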