Strange global_step restarts affecting learning rates and performance?

I’m seeing the weirdest global_step restarts during my training runs, and they do affect the optimizer: the warmup restarts every time global_step drops back to zero.

This is in JAX, using a joined schedule from Optax to create the learning rate function. I log the step number separately when doing evals and that counter seems fine, but for some reason the values logged in wandb are pretty strange.
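
The schedule is built along these lines (assuming optax.join_schedules with a linear warmup; the decay type and all values below are just placeholders):

import optax

warmup_steps = 1_000   # placeholder
total_steps = 100_000  # placeholder
peak_lr = 3e-4         # placeholder

# Linear warmup followed by a decay schedule, joined at warmup_steps.
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps)
decay_fn = optax.cosine_decay_schedule(init_value=peak_lr, decay_steps=total_steps - warmup_steps)
learning_rate_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

# The resulting schedule is driven by the optimizer's own step count,
# independently of whatever step wandb uses for plotting.
optimizer = optax.adamw(learning_rate=learning_rate_fn)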

Hi @versae,

Thank you for reaching out for support. I’ll check this on our end and we’ll get back to you with updates.

Hi @versae,

The issue you’re experiencing might be due to multiple calls to wandb.log for the same training step. The wandb SDK keeps its own internal step counter, which is incremented every time wandb.log is called, so it can fall out of alignment with the training step in your training loop.
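
For example, a loop like this one (stand-in values; only the logging pattern matters) advances wandb’s internal step twice whenever an eval runs:

import wandb

wandb.init(project="demo")  # placeholder project name

num_steps, eval_every = 100, 10
for step in range(num_steps):
    loss = float(step)                      # stand-in for the real training loss
    wandb.log({"train/loss": loss})         # internal wandb step += 1
    if step % eval_every == 0:
        wandb.log({"eval/loss": loss * 2})  # internal wandb step += 1 again
# wandb's internal step ends up larger than `step`, so charts plotted against
# the default step no longer line up with the training step.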

To avoid this, you can specifically define your x-axis step using wandb.define_metric. You only need to do this once, after wandb.init is called. Here is an example:

wandb.init(...)
wandb.define_metric("*", step_metric="global_step")

The glob pattern, “*”, means that every metric will use “global_step” as the x-axis in your charts. If you only want certain metrics to be logged against “global_step”, you can specify them instead:

wandb.define_metric("train/loss", step_metric="global_step")

With this in place, your charts are plotted against your own global_step rather than the internal wandb step counter, so they should no longer appear to restart from zero.
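
One thing to note: when a step_metric is set this way, you also include "global_step" as a value in each wandb.log call so the charts have something to plot against. A minimal sketch (metric names here are illustrative):

wandb.log({"train/loss": loss, "learning_rate": lr, "global_step": step})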

Hi @versae,

I just want to follow up to see if this helped and whether you still need assistance.

Regards,
Carlo Argel

Hi!

We’re still investigating. Thanks for your help!

Cheers.

Hi @versae,

Thank you for informing us. We are closing this out for internal tracking purposes. You can write back anytime you are ready to proceed.

Regards,
Carlo Argel
