Thanks @fmamberti-wandb, your change indeed works. I simulated it by killing a first execution with Ctrl+C, and a second execution can then resume training from the last epoch.
I suggest removing the part that restores the model from a local folder if it exists and only restoring from a model saved in W&B. This is closer to the example that uses WandbCallback, and it lets multiple experiments run from the same working directory without getting mixed up by sharing the same "model_checkpoint" folder.
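For comparison, here is a minimal sketch of the WandbCallback-style restore I am referring to (it assumes the callback was created with save_model=True, so "model-best.h5" is uploaded to the run's files; this is not part of my final code):

    if wandb.run.resumed:
        # download "model-best.h5" from the run's files and load it
        model = keras.models.load_model(wandb.restore("model-best.h5").name)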
So my final code is below:
import os
import keras
import numpy as np
import tensorflow
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint
import wandb
wandb.init(project="preemptible", resume=True)
ent_id, proj_id, run_id = wandb.run.entity, wandb.run.project, wandb.run.id
model_path = f"{wandb.run.dir}/model_checkpoint"
if wandb.run.resumed:
    # Download the latest model checkpoint logged to W&B and restore the model from it
    api = wandb.Api()
    artifact = api.artifact(f"{ent_id}/{proj_id}/run_{run_id}_model:latest")
    artifact_dir = artifact.download()
    model = keras.models.load_model(artifact_dir)
else:
    # initialize a new model
    a = keras.layers.Input(shape=(32,))
    b = keras.layers.Dense(10)(a)
    model = keras.models.Model(inputs=a, outputs=b)
# no indentation below, so the model is compiled and trained whether the run is new or resumed
model.compile("adam", loss="mse")
model.fit(
    np.random.rand(100, 32),
    np.random.rand(100, 10),
    # resume from the epoch recorded in the W&B run
    initial_epoch=wandb.run.step,
    epochs=300,
    callbacks=[
        WandbMetricsLogger(log_freq=10),
        # save a model checkpoint at the end of each epoch
        WandbModelCheckpoint(filepath=model_path),
    ],
)
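One note on how the second execution reattaches to the interrupted run: with resume=True, wandb picks the run id up from the local wandb directory of the first execution. If the job restarts on a different machine, the run id can be pinned explicitly instead; a minimal sketch, assuming the id of the interrupted run is known (the value below is a placeholder):

    import os
    os.environ["WANDB_RUN_ID"] = "<run-id-of-first-execution>"  # placeholder
    # "must" fails loudly instead of silently starting a fresh run
    wandb.init(project="preemptible", resume="must")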