Several people seem to be hitting this problem, and the debug logs are not very informative.
The problem appears to be specific to sweeps, since plain (non-sweep) W&B logging works fine.
I've attached the logs anyway in case anyone wants to look at them.
Code
I’m doing something like:
def init_wandb_sweep(self, args, wandb=None, *, project="<...>", entity="<...>") -> tuple:
    """Register a W&B hyperparameter sweep and return its ID.

    Call only after ``wandb_logger()`` has been called.

    This function only *registers* the sweep. It deliberately does not read
    ``wandb.config``. The hyperparameters sampled by the sweep are only
    available inside a run started by ``wandb.agent``, i.e. after the
    function given to the agent calls ``wandb.init()``. Reading the config
    here would return the values of whatever run is already active (or
    raise ``AttributeError`` if the keys are missing), not the sweep's
    values. Copy ``lr``, ``drop_rate``, ``weight_decay`` and ``grad_clip``
    from the run config inside the function passed to ``wandb.agent``
    instead.

    Args:
        args: Object holding the training hyperparameters as attributes.
            It is returned unchanged.
        wandb: The ``wandb`` module, or any object exposing a compatible
            ``sweep(sweep=..., project=..., entity=...)`` method.
        project: W&B project the sweep is created in.
        entity: W&B entity (user or team) that owns the sweep.

    Returns:
        ``(sweep_id, args)``: the ID returned by ``wandb.sweep`` and the
        untouched ``args`` object.

    Raises:
        ValueError: If ``wandb`` is None (logger not initialized).
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if wandb is None:
        raise ValueError("Wandb logger not initialized")

    sweep_configuration = {
        "method": "random",
        "name": "sweep",
        # NOTE(review): optimizing a training metric tends to favour
        # overfitting configurations; a validation metric may suit better.
        "metric": {"goal": "maximize", "name": "Train/acc"},
        "parameters": {
            # NOTE(review): with min/max only, values are sampled uniformly;
            # for ranges spanning several decades (lr, weight_decay), consider
            # "distribution": "log_uniform_values".
            "lr": {"max": 1e-2, "min": 1e-5},
            "drop_rate": {"max": 0.2, "min": 0.0},
            "weight_decay": {"max": 1e-3, "min": 1e-5},
            "grad_clip": {"max": 1.0, "min": 0.1},
        },
    }
    sweep_id = wandb.sweep(sweep=sweep_configuration, project=project,
                           entity=entity)
    return sweep_id, args
which is called by:
if args.tune_hyperparams:
    # Registering a sweep creates server-side state, so only do it when tuning.
    sweep_id, args = logger.init_wandb_sweep(args, wandb_logger)

    def train_sweep_run():
        """Train one sweep trial using the hyperparameters W&B sampled for it."""
        # The sampled values only exist in the run the agent starts, so
        # wandb.init() must be called here before reading the config, and the
        # Trainer must be built afterwards from the updated args.
        with wandb_logger.init() as run:
            for name in ("lr", "drop_rate", "weight_decay", "grad_clip"):
                setattr(args, name, run.config[name])
            trainer = Trainer(args, logger=(my_logger, wandb_logger),
                              loaders=(trainloader, valloader),
                              decode_fn=train_dataset.tok.decode,
                              shard=shard,
                              key=key)
            trainer.train()

    # NOTE(review): if an earlier wandb_logger() call started a run in this
    # process, it may still be active when the agent calls wandb.init() above,
    # and the trial might then not receive the sweep's values. Consider calling
    # wandb_logger.finish() before launching the agent; confirm against your
    # wandb version.
    wandb_logger.agent(sweep_id, function=train_sweep_run, count=1)
else:
    # Plain (non-sweep) training: no sweep or agent involved.
    trainer = Trainer(args, logger=(my_logger, wandb_logger),
                      loaders=(trainloader, valloader),
                      decode_fn=train_dataset.tok.decode,
                      shard=shard,
                      key=key)
    trainer.train()