I'm seeing the weirdest global_step restarts during my training runs, and it does affect the optimizer: the warmup restarts every time global_step drops back to zero.
This is in JAX, using a joined schedule from Optax to build the learning rate function. I log the step number separately when doing evals and that counter seems fine, but for some reason the values logged to wandb look pretty strange.
The issue you're experiencing is likely due to multiple calls to wandb.log for the same training step. The wandb SDK keeps its own internal step counter that is incremented on every wandb.log call, so it can drift out of alignment with the training step in your training loop.
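For example, logging more than once per iteration is enough to cause the drift. A minimal sketch, where train_step, lr_schedule, batch, and num_steps stand in for your own code:

```python
# Two wandb.log calls per loop iteration advance wandb's internal step
# counter by 2, so it no longer matches global_step.
for global_step in range(num_steps):
    loss = train_step(batch)
    wandb.log({"train/loss": loss})              # internal wandb step +1
    wandb.log({"lr": lr_schedule(global_step)})  # internal wandb step +1 again
```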
To avoid this, you can explicitly define the x-axis for your charts using wandb.define_metric. You only need to do this once, after wandb.init is called. Here is an example (the project and metric names are placeholders for your own):
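```python
import wandb

run = wandb.init(project="my-project")  # placeholder project name

# Register "global_step" as a metric, then make every other metric use it
# as its x-axis instead of wandb's internal step counter.
wandb.define_metric("global_step")
wandb.define_metric("*", step_metric="global_step")

# In the training loop, log global_step explicitly alongside each metric.
for global_step in range(1000):
    loss = 1.0 / (global_step + 1)  # dummy value; use your real loss here
    wandb.log({"train/loss": loss, "global_step": global_step})
```

With this in place, the charts are plotted against the global_step you log, no matter how many wandb.log calls happen per training step.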
The glob pattern, “*”, means that every metric will use “global_step” as the x-axis in your charts. If you only want certain metrics to be logged against “global_step”, you can specify them instead:
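A short sketch, assuming your training metrics live under a "train/" prefix:

```python
# Only metrics whose names start with "train/" are plotted against "global_step";
# everything else keeps wandb's default internal step as the x-axis.
wandb.define_metric("train/*", step_metric="global_step")

# Individual metrics can also be named explicitly, e.g. an eval metric:
wandb.define_metric("eval/accuracy", step_metric="global_step")
```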