How to wandb.restore a keras model saved using WandbModelCheckpoint

Similar documentation exists for using the legacy WandbCallback.

But the documentation for WandbModelCheckpoint does not explain how to perform the restore.

Any suggestions?

Hi @alanlivio! Thank you for reaching out to W&B support.

As WandbModelCheckpoint saves the checkpoints as an artifact in W&B, you should be able to use the Artifact API to download the specific version of the model that was logged (see here):

import wandb

api = wandb.Api()
artifact = api.artifact("entity/project_name/artifact_name:artifact_alias")
artifact_dir = artifact.download()

and then use keras.models.load_model(), pointing it at the downloaded checkpoint inside artifact_dir, to load the model.
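
For example, a minimal sketch, assuming the downloaded artifact contains the saved Keras model at its root:

import keras

# artifact_dir is the local folder returned by artifact.download() above
model = keras.models.load_model(artifact_dir)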

Let me know if you have any further questions.

Kind regards,
Francesco

Hi @alanlivio,

We wanted to follow up with you regarding your support request as we have not heard back from you. Please let us know if we can be of further assistance or if your issue has been resolved.

Best,
Francesco

Hi @fmamberti-wandb.
Would it be something like the below (similar to the WandbCallback example, but using WandbModelCheckpoint)?

import keras
import numpy as np
import wandb
from wandb.keras import WandbModelCheckpoint

wandb.init(project="preemptible", resume=True)

if wandb.run.resumed:
    # restore the best model
    api = wandb.Api()
    artifact = api.artifact("alanlivio/preemptible/artifact_name:models")
    artifact_dir = artifact.download()
    model = keras.models.load_model(artifact_dir)
else:
    a = keras.layers.Input(shape=(32,))
    b = keras.layers.Dense(10)(a)
    model = keras.models.Model(inputs=a, outputs=b)

model.compile("adam", loss="mse")
model.fit(
    np.random.rand(100, 32),
    np.random.rand(100, 10),
    # set the resumed epoch
    initial_epoch=wandb.run.step,
    epochs=300,
    # save the best model if it improved each epoch
      WandbModelCheckpoint("models"),
)

Hi @alanlivio - that should work as long as alanlivio/preemptible/artifact_name:models is the full name of an existing Model artifact that only contains the .pb files for the Keras model.

You may need to replace models with v#, where # is the version number of the artifact you want to download, or with an alias that you have assigned to the artifact.

Each artifact version also has a Usage tab that you can use to get the correct path if needed.
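
For example, to pin an explicit version rather than an alias (v3 is a hypothetical version number):

artifact = api.artifact("entity/project_name/artifact_name:v3")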

Kind regards,
Francesco

Thank you again @fmamberti-wandb.

Hi @alanlivio - that should work as long as alanlivio/preemptible/artifact_name:models is the full name of an existing Model artifact that only contains the .pb files for the Keras model.

That is my main question: how can I make the name used in WandbModelCheckpoint("models") match alanlivio/preemptible/artifact_name:models?

It would be great if you could update the WandbModelCheckpoint documentation to clarify this. For the legacy WandbCallback, the documentation is very clear and worked for me.

Thanks and best regards,
Alan

Thank you for your patience, @alanlivio, and for the feedback regarding the documentation; I will raise this with the team to get the docs updated.

When using the WandbModelCheckpoint callback, the model will be uploaded as an artifact with the name run_runid_model as in the screenshot below:

You could set save_best_only=True to save the best model only and in that case, you will be able to use

artifact = api.artifact("alanlivio/preemptible/<run_RUNID_model>:latest")

to restore the best model.
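
For example, a minimal sketch of that callback configuration (the filepath and monitor values are illustrative):

# upload only the best checkpoint, so :latest always points at the best model
WandbModelCheckpoint(filepath="models/", monitor="val_loss", save_best_only=True)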

You can also rename the models in the UI by clicking on the meatball menu (three dots) next to the model name if you prefer having a more readable name.

I’ve also noticed that you may need to update the arguments for your model.fit() to have WandbModelCheckpoint passed as a callback, i.e.:

model.fit(
    ...,
    callbacks=[
        WandbMetricsLogger(log_freq=10),
        WandbModelCheckpoint(filepath="models/"),
    ],
)

You can also find a running example of how to upload models as artifacts with WandbModelCheckpoint on this Notebook, taken from this section.

Let me know if you have any further questions.

Thank you again @fmamberti-wandb. See some comments below.

When using the WandbModelCheckpoint callback, the model will be uploaded as an artifact with the name run_runid_model as in the screenshot below:

You could set save_best_only=True to save the best model only and in that case, you will be able to use

artifact = api.artifact("alanlivio/preemptible/<run_RUNID_model>:latest")

For me, this is the main difference from the legacy WandbCallback, which only needs a simple identifier, namely wandb.restore("model-best.h5").name. Moreover, the legacy WandbCallback better supports recovery when resuming a previously failed run. This cannot be done with WandbModelCheckpoint, because you need to know the RUNID of the previous failed run.
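
For comparison, a minimal sketch of that legacy flow (model-best.h5 is the filename used in the WandbCallback documentation):

# wandb.restore fetches the checkpoint file logged with the (resumed) run
best_model = wandb.restore("model-best.h5")
model = keras.models.load_model(best_model.name)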

I’ve also noticed that you may need to update the arguments for your model.fit() to have WandbModelCheckpoint passed as a callback, i.e.:

model.fit(
    ...,
    callbacks=[
        WandbMetricsLogger(log_freq=10),
        WandbModelCheckpoint(filepath="models/"),
    ],
)

This “models” parameter also lacks clear documentation. Moreover, it differs from the section name in the screenshot you shared, which says “model” instead of “models.”

You can also find a running example of how to upload models as artifacts with WandbModelCheckpoint on this Notebook, taken from this section.

I appreciate you sharing it, but the notebook does not exactly cover my case because it does not include a restore step.

Thanks and best regards,
Alan

Hi @alanlivio,

The WandbModelCheckpoint parameter filepath defines the local folder where the trained model checkpoint is saved, rather than the name given to the model artifact on the W&B platform.
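
In other words, the two names are independent. A minimal sketch of the distinction (entity/project and the RUNID placeholder are illustrative):

# filepath: the local folder the callback writes checkpoints to
WandbModelCheckpoint(filepath="model_checkpoint")

# the artifact name on W&B is derived from the run id, as run_<RUNID>_model
artifact = wandb.Api().artifact("entity/project/run_<RUNID>_model:latest")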

The following snippet of code should allow you to resume a run on the same device where it failed, using the latest integration, in a similar way to what restore allowed you to do previously:

import keras
import numpy as np
import wandb
import os
from wandb.keras import WandbModelCheckpoint, WandbMetricsLogger

model_path = "model_checkpoint" 

wandb.init(project="preemptible", resume=True)

if wandb.run.resumed:
    # restore the model from the local folder if it exists
    if os.path.exists(model_path):
        model = keras.models.load_model(model_path)
    # If local folder does not exist, download the latest model from W&B
    else:
        run_id = wandb.run.id
        api = wandb.Api()
        artifact = api.artifact(f"alanlivio/preemptible/run_{run_id}_model:latest")
        artifact_dir = artifact.download()
        model = keras.models.load_model(artifact_dir)
else:
    # initialize new model
    a = keras.layers.Input(shape=(32,))
    b = keras.layers.Dense(10)(a)
    model = keras.models.Model(inputs=a, outputs=b)
model.compile("adam", loss="mse")
model.fit(
    np.random.rand(100, 32),
    np.random.rand(100, 10),
    # set the resumed epoch
    initial_epoch=wandb.run.step,
    epochs=300,
    # save the best model if it improved each epoch
    callbacks=[
        WandbMetricsLogger(log_freq=10),
        WandbModelCheckpoint(filepath=model_path)
    ]
)

In your working directory, you should see a model_checkpoint folder being created.

This will be used to save the model during the training and to resume the training.

If this folder doesn’t exist when resuming, the model will be downloaded from W&B without your having to specify the run id, as it is filled in automatically using run_id = wandb.run.id.

Please let me know if you have any further questions on this.

Kind regards,
Francesco


Hi Francesco. Many thanks for your code. But it has two problems.

First, it raises an error in WandbModelCheckpoint (tensorflow==2.16.1, keras==3.0.5, wandb==0.16.4):

Traceback (most recent call last):
  File "/home/alan/src/tmp/test-wandb-model-checkpoint.py", line 39, in <module>
    WandbModelCheckpoint(filepath=model_path),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alan/bin/miniconda3/lib/python3.11/site-packages/wandb/integration/keras/callbacks/model_checkpoint.py", line 94, in __init__
    super().__init__(
TypeError: ModelCheckpoint.__init__() got an unexpected keyword argument 'options'

Second, it gets stuck in an error (see below) if it is resuming from a previous failure (in my case, the error above) and there is no artifact to be downloaded. Maybe we could check whether it exists with use_artifact or even use_model?

Traceback (most recent call last):
  File "/home/alan/bin/miniconda3/lib/python3.11/site-packages/wandb/apis/normalize.py", line 41, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/alan/bin/miniconda3/lib/python3.11/site-packages/wandb/apis/public/api.py", line 958, in artifact
    artifact = wandb.Artifact._from_name(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alan/bin/miniconda3/lib/python3.11/site-packages/wandb/sdk/artifacts/artifact.py", line 254, in _from_name
    raise ValueError(
ValueError: Unable to fetch artifact with name alanlivio/preemptible/run_4f5z6x9w_model:latest

Hi @alanlivio, I have checked with the team working on the integrations and WandbModelCheckpoint currently doesn’t support Keras 3, which is why you see the error.

We are currently updating the integration to support Keras 3; however, you would have to configure an environment with Keras 2 to use the integration for the time being. Apologies for any confusion on this.

Regarding your second point, you could wrap the API call to get the artifact:

artifact = api.artifact(f"alanlivio/preemptible/run_{run_id}_model:latest")

in a try/except statement to cover cases where no model checkpoint has been uploaded, and initialise a new model in that case.
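
For example, a minimal sketch that catches the ValueError from your traceback when no artifact exists yet:

api = wandb.Api()
try:
    artifact = api.artifact(f"alanlivio/preemptible/run_{run_id}_model:latest")
    artifact_dir = artifact.download()
    model = keras.models.load_model(artifact_dir)
except ValueError:
    # no checkpoint was uploaded before the failure: start from a new model
    a = keras.layers.Input(shape=(32,))
    b = keras.layers.Dense(10)(a)
    model = keras.models.Model(inputs=a, outputs=b)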

Hi @alanlivio, I wanted to follow up on this to check if you have any follow-up questions or if you are happy with us closing this ticket.


Thanks again @fmamberti-wandb.

I did the Keras fix (a Python 3.9 environment, pip install tensorflow keras==2.9), but it still does not work for me.

Specifically, I run once, kill the run with ctrl+c after some epochs (to simulate a failure), and run again, but it stops with the error below.

wandb: WARNING Ensure read and write access to run files dir: /home/alan/src/tmp/wandb/run-20240316_133153-4igkakhj/files, control this via the WANDB_DIR env var. See https://docs.wandb.ai/guides/track/environment-variables
wandb: - 0.002 MB of 0.002 MB uploaded

See my code below.

import os

import keras
import numpy as np
import tensorflow
import wandb
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint

model_path = "model_checkpoint"

wandb.init(project="preemptible", resume=True)

ent_id, proj_id, run_id = wandb.run.entity, wandb.run.project, wandb.run.id

if wandb.run.resumed:
    # restore the model from the local folder if it exists
    if os.path.exists(model_path):
        model = keras.models.load_model(model_path)
    # If local folder does not exist, download the latest model from W&B
    else:
        api = wandb.Api()
        artifact = api.artifact(f"{ent_id}/{proj_id}/run_{run_id}_model:latest")
        artifact_dir = artifact.download()
        model = keras.models.load_model(artifact_dir)
else:
    # initialize new model
    a = keras.layers.Input(shape=(32,))
    b = keras.layers.Dense(10)(a)
    model = keras.models.Model(inputs=a, outputs=b)
    model.compile("adam", loss="mse")
    model.fit(
        np.random.rand(100, 32),
        np.random.rand(100, 10),
        # set the resumed epoch
        initial_epoch=wandb.run.step,
        epochs=300,
        # save the best model if it improved each epoch
        callbacks=[
            WandbMetricsLogger(log_freq=10),
            WandbModelCheckpoint(filepath=model_path),
        ],
    )

Hi @alanlivio, that warning shouldn’t prevent the training from resuming.

Thanks for sending the code - it seems that you are running .compile() and .fit() only when initialising a new model, as they are indented inside the else statement.

If you change the final section of the code as shown below, training should resume successfully:

...
else:
    # initialize new model
    a = keras.layers.Input(shape=(32,))
    b = keras.layers.Dense(10)(a)
    model = keras.models.Model(inputs=a, outputs=b)

# removing indentation to ensure the model is compiled and trained when resuming
model.compile("adam", loss="mse")
model.fit(
    np.random.rand(100, 32),
    np.random.rand(100, 10),
    # set the resumed epoch
    initial_epoch=wandb.run.step,
    epochs=300,
    # save the best model if it improved each epoch
    callbacks=[
        WandbMetricsLogger(log_freq=10),
        WandbModelCheckpoint(filepath=model_path),
    ],
)

Thanks @fmamberti-wandb. Your change indeed worked. I simulated a failure by killing a first execution with ctrl+c; the second execution then resumed from the last epoch.

I suggest removing the part that restores the model from the local folder if it exists, and only restoring from a model saved in W&B. This is closer to the example using WandbCallback, and it allows multiple experiments to run in the same folder without getting confused by sharing the same “model_checkpoint” folder.

So my final code is below:

import keras
import numpy as np
import tensorflow
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint

import wandb

wandb.init(project="preemptible", resume=True)

ent_id, proj_id, run_id = wandb.run.entity, wandb.run.project, wandb.run.id
model_path = f"{wandb.run.dir}/model_checkpoint"

if wandb.run.resumed:
    # download the latest model checkpoint from W&B
    api = wandb.Api()
    artifact = api.artifact(f"{ent_id}/{proj_id}/run_{run_id}_model:latest")
    artifact_dir = artifact.download()
    model = keras.models.load_model(artifact_dir)
else:
    # initialize new model
    a = keras.layers.Input(shape=(32,))
    b = keras.layers.Dense(10)(a)
    model = keras.models.Model(inputs=a, outputs=b)

# compile and fit at the top level so the model is also trained when resuming
model.compile("adam", loss="mse")
model.fit(
    np.random.rand(100, 32),
    np.random.rand(100, 10),
    # set the resumed epoch
    initial_epoch=wandb.run.step,
    epochs=300,
    # save the best model if it improved each epoch
    callbacks=[
        WandbMetricsLogger(log_freq=10),
        WandbModelCheckpoint(filepath=model_path),
    ],
)

Hi @alanlivio, it’s great to know this is now working for you!

Thank you again for your suggestion and feedback, and for surfacing the issue with the documentation. We really appreciate that.

I will now mark this request as solved, but feel free to get back in touch if you have any further questions.
