Clarification on Early Termination (Hyperband)

Looking to run a large HP sweep. I’m using PyTorch Lightning (PL) for model training and W&B for experiment management. I already have it up and running (using PL’s WandbLogger), and I’m hoping to use the early termination feature that W&B provides, but I find the documentation a little confusing. I’d like to do something as simple as: check some metric at ~[100, 500, 2500, …] epochs and terminate the runs that aren’t performing well.

The documentation for Hyperband says:

Brackets are based on the number of logged iterations, i.e. elements in the run’s history. Depending on where you are calling wandb.log, these iterations may correspond to steps, epochs, or something in between. The numerical value of the step counter is not used in bracket calculations.
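Going by that description, the brackets are checkpoints counted in history rows, not in epochs and not in the value of `step`. As far as I understand the W&B sweep docs, when you set `min_iter` and `eta`, the first bracket sits at `min_iter` rows and each later bracket is `eta` times the previous one. A minimal sketch of that arithmetic, assuming that reading is correct (`bracket_rows` is a hypothetical helper, not part of wandb):

```python
def bracket_rows(min_iter: int, eta: int, num_brackets: int) -> list[int]:
    """Return the history lengths (number of logged rows) at which
    Hyperband brackets would be evaluated.

    Assumption: bracket k sits at min_iter * eta**k logged rows.
    Hypothetical helper for illustration, not a wandb API.
    """
    return [min_iter * eta**k for k in range(num_brackets)]


# min_iter=3, eta=3 -> checks after 3, 9, 27, 81 logged rows
print(bracket_rows(3, 3, 4))
```

The step counter’s value plays no part here; only how many rows the run history contains at that point.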

I’m having trouble deciphering this.

For example, in my case I only check (and log) validation metrics every 10 epochs (check_val_every_n_epoch=10 is passed to the PL Trainer) to save compute. I also log things at the end of each training batch, training epoch, and validation epoch. I log both dicts (in PL: self.log_dict()) and single values (self.log()). In some cases the logger flag is False in the call: self.log(..., logger=False) or self.log_dict(..., logger=False).
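To see why history rows and epochs can drift apart with a setup like this, here is a rough model of how the history might grow. It is purely an assumption-laden sketch: it supposes each logger write that reaches W&B adds one row, that per-batch logging adds a fixed number of rows per epoch, and that the end of every training epoch and every validation epoch adds one row each. Calls made with logger=False never reach the logger, so they add nothing. The function names and the `batch_rows_per_epoch` parameter are hypothetical; the real row count depends on how the logger calls wandb.log, so it is worth checking against an actual run’s history.

```python
def history_rows_after(epochs: int, batch_rows_per_epoch: int,
                       val_every_n_epochs: int = 10) -> int:
    """Estimate how many history rows a run has after `epochs` epochs.

    Hypothetical model (verify against your own run history):
    - each batch-level log that reaches the logger adds one row,
    - the end of every training epoch adds one row,
    - each validation epoch (every `val_every_n_epochs`) adds one row,
    - self.log(..., logger=False) adds nothing (it never reaches W&B).
    """
    train_rows = epochs * (batch_rows_per_epoch + 1)
    val_rows = epochs // val_every_n_epochs
    return train_rows + val_rows


def epochs_to_reach(rows: int, batch_rows_per_epoch: int,
                    val_every_n_epochs: int = 10) -> int:
    """Smallest number of epochs whose estimated history length is >= rows."""
    epochs = 0
    while history_rows_after(epochs, batch_rows_per_epoch,
                             val_every_n_epochs) < rows:
        epochs += 1
    return epochs


# Under this model, with 4 batch-level rows per epoch, a bracket placed
# at 100 history rows would be checked after only 20 epochs.
print(epochs_to_reach(100, batch_rows_per_epoch=4))  # -> 20
```

Under these assumptions, a bracket expressed in rows only lines up with a bracket in epochs when exactly one row is logged per epoch.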

So how is the number of logged iterations computed? And how do I use it to achieve my original goal: check some metric at [100, 500, 2500, …] epochs and terminate the ‘bad’ ones (as per the Hyperband algorithm)?
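A sketch of a sweep config that would place brackets at 100, 500, and 2500, assuming the bracket rule is min_iter * eta**k and that the run logs exactly one history row per epoch, so rows and epochs coincide. The metric name "val_loss" and the "lr" parameter block are hypothetical placeholders; `early_terminate` with `type`, `min_iter`, and `eta` are keys from the W&B sweep configuration docs.

```python
# Hypothetical sweep config: Hyperband early termination with brackets
# at 100, 500, 2500, ... logged rows (= epochs only if one row per epoch).
sweep_config = {
    "method": "random",
    "metric": {"name": "val_loss", "goal": "minimize"},  # placeholder metric
    "early_terminate": {
        "type": "hyperband",
        "min_iter": 100,  # first bracket after 100 logged rows
        "eta": 5,         # later brackets at 100*5 = 500, 100*25 = 2500, ...
    },
    "parameters": {
        "lr": {"min": 1e-5, "max": 1e-2},  # placeholder search space
    },
}

# Bracket positions implied by the config, under the min_iter * eta**k rule.
et = sweep_config["early_terminate"]
print([et["min_iter"] * et["eta"] ** k for k in range(3)])  # -> [100, 500, 2500]
```

The dict would then be passed to wandb.sweep(sweep_config, project=...). If per-batch logging stays on, min_iter would need to be scaled by however many rows each epoch actually adds.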

Thanks in advance,
Max


Thanks for the Q!
Allow me to look into this and get back to you with an answer :slight_smile:

Any update on this? Would save me lots of GPU hours :upside_down_face: