Runs are overwritten when launched with wandb sweep

I am attempting to perform cross-validation over N folds for a given model configuration. The script runs the folds sequentially, and each fold is associated with an independent WandB Run. When the script is run manually for a single model configuration, each fold is uploaded as an independent Run, as expected. However, when the script is launched automatically via wandb sweep, each fold of each model configuration overwrites the previous one, so only a single Run is uploaded.

The script includes a while loop that waits until the Run instance has finished. However, the loop's print statement never appears in the output files, which indicates that the loop is never entered (i.e. the Run instance is already closed once wnb.finish() returns).

(The following MWE can also be reproduced through this GitHub repo.)

Here’s the script:

# run.py
import argparse
import numpy as np
import time
import wandb

def main(d1: int, d2: int, n_folds: int, seed: int = 0, project: str | None = None):

    rng = np.random.default_rng(seed=seed)

    for i in range(n_folds):
        print()

        # Setup WandB
        wnb = wandb.init(project=project, config=dict(d1=d1, d2=d2))

        # Report folds as 1..n_folds so the last fold reads "n/n"
        print(f"[Fold {i + 1}/{n_folds}]: Starting Run (id={wnb.id})...")

        # Do something
        arr = rng.normal(size=(d1,d2))

        # Log some value
        wnb.summary['dummy'] = arr.sum()

        # Close wandb, then wait until the global run handle has been cleared
        wnb.finish()
        while wandb.run is not None:
            print(f"{time.asctime()} Waiting for wandb to finish...")
            time.sleep(5)  # units: seconds

    return

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
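    # These three have no defaults, so fail early if any is missing
    parser.add_argument('--d1', type=int, required=True)
    parser.add_argument('--d2', type=int, required=True)
    parser.add_argument('--n_folds', type=int, required=True)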
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--project', type=str, default='debug_wandb_sweep_multirun')

    args: dict = vars(parser.parse_args())

    main(**args)

When I run the above script from the command line,

python run.py --d1=3 --d2=2 --n_folds=3

I get 3 different Runs, as expected.

However, when I launch the same script with the same arguments via wandb sweep, I only get a single Run. The output messages show that each Run overwrites the previous one.

# config.yaml
program: run.py
project: debug_wandb_sweep_multirun
method: grid
parameters:
  d1:
    values: [3]
  d2:
    values: [2]
command:
  - ${env}
  - ${interpreter}
  - ${program}
  - ${args}
  - "--n_folds=3"
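
One diagnostic that might help narrow this down is sketched below. It rests on an assumption I have not confirmed: that the sweep agent hands identifiers to the launched script through the documented WANDB_RUN_ID and WANDB_SWEEP_ID environment variables, in which case every wandb.init() call in the same process would see the same run ID. The helper name wandb_env_snapshot is hypothetical and not part of wandb.

```python
import os

# Documented W&B environment variable names. Whether the sweep agent sets
# them for the launched script is an assumption to check, not a confirmed cause.
WANDB_ENV_KEYS = ("WANDB_RUN_ID", "WANDB_SWEEP_ID")


def wandb_env_snapshot(environ=None) -> dict:
    """Return the W&B identifier variables from the environment (None if unset)."""
    environ = os.environ if environ is None else environ
    return {key: environ.get(key) for key in WANDB_ENV_KEYS}
```

Printing wandb_env_snapshot() at the top of each fold, next to wnb.id, would show whether the run ID reported for every fold matches a value inherited from the environment when the script is launched by the sweep, and whether both variables are unset when the script is run manually.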

Launching this script via sweeps would really assist in cross-validating over all possible model hyperparameters. Any help or suggestions would be appreciated!


Wandb Project


Versions

  • python==3.12.8
  • wandb==0.19.4
  • numpy==2.2.1