Saving and Loading Models

There are three ways of saving and reloading models:

  1. saving the model weights manually

  2. saving the model weights automatically during training using a callback function

  3. saving the full model

At any point, you can save the model weights manually with

model.save_weights('filename.h5')

You can then later restore the weights by loading them from the file with

model.load_weights('filename.h5')

Note that the file stores only the weights, so reloading them requires an instance of the model with the same architecture.
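The round trip can be sketched as follows. This is a minimal illustration, not a prescribed workflow: the `build_model` helper, the toy architecture, and the temporary file name are assumptions made for the example. The `.weights.h5` suffix is used because newer Keras releases require it for weight files.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical toy architecture, used only for illustration.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1),
    ])


model = build_model()

# Assumption: a temporary directory and the ".weights.h5" suffix.
weights_path = os.path.join(tempfile.mkdtemp(), "demo.weights.h5")
model.save_weights(weights_path)

# A fresh instance starts from its own random initialization ...
restored = build_model()
# ... until the saved weights are loaded into it.
restored.load_weights(weights_path)

# Both models should now produce the same predictions.
x = np.random.rand(3, 4).astype("float32")
outputs_match = np.allclose(
    model.predict(x, verbose=0), restored.predict(x, verbose=0)
)
print(outputs_match)
```

Because the architecture is rebuilt by the same function, the weights in the file line up with the layers of the new instance.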

This can also be done automatically during the training loop using a callback function:

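checkpoint_path = 'training/model-e{epoch:04d}.ckpt'
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,  # store only the weights, not the full model
)
model.fit(..., callbacks=[checkpoint_callback])

With save_weights_only=True, the callback saves the model weights after every epoch; without this argument, ModelCheckpoint saves the full model by default. The path of the latest checkpoint can then be retrieved with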

tf.train.latest_checkpoint('training')

This mechanism is handy for keeping persistent copies of the model so that an interrupted training run can be resumed later.
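Resuming an interrupted run might look like the sketch below. It assumes weights-only checkpoints written as HDF5 files with the `.weights.h5` suffix, which newer Keras releases require. Such files have no checkpoint index for `tf.train.latest_checkpoint` to read, so the latest file is found here by sorting the zero-padded file names. The `build_model` helper, the random data, and the epoch counts are all hypothetical.

```python
import glob
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical toy architecture, used only for illustration.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1),
    ])


# Random stand-in data; a real run would use the actual training set.
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

checkpoint_dir = tempfile.mkdtemp()
checkpoint_path = os.path.join(checkpoint_dir, "model-e{epoch:04d}.weights.h5")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, save_weights_only=True
)

# First run: pretend training stops after 3 epochs.
model = build_model()
model.compile(optimizer="adam", loss="mse")
model.fit(x, y, epochs=3, batch_size=8, verbose=0,
          callbacks=[checkpoint_callback])

# Zero-padded epoch numbers sort correctly as plain strings.
saved = sorted(glob.glob(os.path.join(checkpoint_dir, "model-e*.weights.h5")))
latest = saved[-1]

# Second run: rebuild the model, restore the weights, continue at epoch 4.
resumed = build_model()
resumed.load_weights(latest)
resumed.compile(optimizer="adam", loss="mse")
resumed.fit(x, y, epochs=5, initial_epoch=3, batch_size=8, verbose=0,
            callbacks=[checkpoint_callback])

all_saved = sorted(glob.glob(os.path.join(checkpoint_dir, "model-e*.weights.h5")))
print(os.path.basename(latest), len(all_saved))
```

Passing `initial_epoch=3` keeps the epoch numbering in the file names continuous across the two runs. In this sketch only the layer weights are restored; the optimizer starts fresh in the second run.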

Finally, it is possible to save the full model, including its architecture and weights (and, for a compiled model, the optimizer state), with

model.save('model.h5')

The model can then be reloaded without the code that originally defined its architecture:

model = tf.keras.models.load_model('model.h5')
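As a small end-to-end check, the sketch below saves a full model and reloads it without calling any model-building code. The toy architecture and the temporary path are assumptions for illustration. The model is left uncompiled to keep the example minimal, and newer Keras releases may warn that HDF5 is a legacy format for full models.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical toy architecture, used only for illustration.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])

# Assumption: a temporary directory stands in for a real output location.
model_path = os.path.join(tempfile.mkdtemp(), "model.h5")
model.save(model_path)

# No architecture code is needed here: the file describes the layers itself.
reloaded = tf.keras.models.load_model(model_path)

x = np.random.rand(3, 4).astype("float32")
outputs_match = np.allclose(
    model.predict(x, verbose=0), reloaded.predict(x, verbose=0)
)
print(len(reloaded.layers), outputs_match)
```

Models that use custom layers or functions generally need those definitions made available when loading, for example through the `custom_objects` argument of `load_model`.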