Resuming training

Hey everyone, I'm new to WandB and would love some advice.

This is my current setup:

  1. Run the model for the first time and save it every epoch (based on a variable) using the following:

# log wandb artifact

            model_artifact = wandb.Artifact(
                aliases=[f'step_{global_step}', f'epoch_{epoch}']

  2. I have resume set to ‘True’ in the configs.
  3. I then load the last saved model (I am using a diffusion pipeline from Hugging Face):

print("Resuming run...")
artifact_name = args.model_resume_name
artifact = wandb.use_artifact(artifact_name)

    # Download the model file(s) and return the path to the downloaded artifact
    artifact_dir =

    pipeline = AudioDiffusionPipeline.from_pretrained(artifact_dir)

    mel = pipeline.mel
    model = pipeline.unet

How do I continue training from the last epoch I left off at? Is step 3 above even necessary? Does resume load the optimizer settings and the learning rate at a specific epoch?

The docs are not very clear.

I hope I am articulating myself properly.


Hi Mark,

I responded to your issue via email a short while ago, but will respond here as well for visibility.

It looks like you’re storing your epochs as aliases, so to resume training from a given epoch you need to read the epoch back explicitly under your if line. One way to do that is to add the following line there:

start_epoch = int(next(alias for alias in artifact.aliases if alias.startswith('epoch_')).split('_')[1])

which would give you the correct epoch to start at going forward.
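
A slightly more defensive version of that lookup could look like the sketch below. The helper name `epoch_from_aliases` and the `prefix` parameter are made up for illustration, and it assumes the aliases follow the `epoch_<number>` pattern from the original post.

```python
from typing import Iterable, Optional


def epoch_from_aliases(aliases: Iterable[str], prefix: str = "epoch_") -> Optional[int]:
    """Return the epoch number encoded in an alias such as 'epoch_7'.

    Returns None when no alias matches, so the caller can fall back to a
    fresh run instead of crashing.
    """
    epochs = []
    for alias in aliases:
        suffix = alias[len(prefix):]
        # Skip unrelated aliases such as 'latest' or 'step_1200'.
        if alias.startswith(prefix) and suffix.isdecimal():
            epochs.append(int(suffix))
    # If several epoch aliases are attached, take the highest epoch.
    return max(epochs) if epochs else None
```

It would be called as `start_epoch = epoch_from_aliases(artifact.aliases)`, with a `None` result treated as "nothing to resume from".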

Let me know if you need anything else!


Where would I use start_epoch once I have it? Where in the training loop would I apply it to make it work?

I assume it would be in the training loop, changing the loop to:

for epoch in range(start_epoch, args.num_epochs):

Hi Mark,

You are definitely correct to assume it would be in the training loop and that you need to change that particular line of code. Since you save the most recent epoch number when you save the artifact, the loop should start at start_epoch + 1, the epoch after the one you saved.
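
As a rough sketch of that off-by-one, assuming zero-based epochs, a checkpoint written at the end of each epoch, and `num_epochs` meaning the total number of epochs (the helper name `epochs_to_run` is hypothetical):

```python
from typing import Optional


def epochs_to_run(last_saved_epoch: Optional[int], num_epochs: int) -> range:
    """Epoch indices still to train, given the last epoch that was checkpointed."""
    if last_saved_epoch is None:
        # Nothing to resume from: train every epoch.
        return range(num_epochs)
    # The saved epoch is already done, so continue with the one after it.
    return range(last_saved_epoch + 1, num_epochs)
```

The training loop would then read `for epoch in epochs_to_run(start_epoch, args.num_epochs):`, which also covers a run that has no checkpoint yet.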

Also, since you are calling scheduler.step() and optimizer.step(), be sure to save the scheduler and optimizer state as well (as an artifact or by any other means you choose) so that you resume from a specific epoch with the correct values.
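
One possible way to bundle that state is sketched below. It assumes the optimizer and scheduler expose `state_dict()` and `load_state_dict()`, as PyTorch optimizers and learning-rate schedulers do. The function names and the dictionary keys are illustrative only.

```python
from typing import Any, Dict


def pack_training_state(epoch: int, optimizer: Any, scheduler: Any) -> Dict[str, Any]:
    """Collect what is needed to resume after `epoch` into a single dict."""
    return {
        "epoch": epoch,
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }


def restore_training_state(state: Dict[str, Any], optimizer: Any, scheduler: Any) -> int:
    """Load the saved state in place and return the next epoch to train."""
    optimizer.load_state_dict(state["optimizer"])
    scheduler.load_state_dict(state["scheduler"])
    # The stored epoch has already been trained, so resume one past it.
    return state["epoch"] + 1
```

With PyTorch, the dict could be written with `torch.save(state, path)` and added to the same artifact via `model_artifact.add_file(path)` before logging. On resume, it would be read back with `torch.load` from the directory returned by `artifact.download()`.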


