Wandb sweep using Slurm in a multi-GPU setting

Hi, I am using Slurm to submit a sweep with the following batch script:

```bash
#!/bin/bash
#SBATCH --job-name=distributed
#SBATCH --account=[ACCOUNT]
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:2
#SBATCH --cpus-per-task=8
#SBATCH --mem=32G

#SBATCH --time=00:05:00
#SBATCH --output=logs/gpu_multi_mpi%j.out
#SBATCH --error=logs/gpu_multi_mpi%j.out

module purge
module load cuda
module load python/3.9

source ~/venv/bin/activate

# start an agent that runs up to 10 trials of the sweep
wandb agent --count 10 [SWEEP_ID]
```

Here is the relevant part of main.py:

```python
import argparse
import os

import hostlist  # from the python-hostlist package
import torch.multiprocessing as mp

# get_args_parser() is defined elsewhere in main.py (not shown)

if __name__ == "__main__":
    # parse the arguments
    parser = argparse.ArgumentParser(
        "Distributed Optimization Script", parents=[get_args_parser()]
    )
    args = parser.parse_args()
    mp.set_start_method("spawn", force=True)

    # read the task layout that Slurm exports into the environment
    NODE_ID = os.environ["SLURM_NODEID"]
    rank = int(os.environ["SLURM_PROCID"])
    local_rank = int(os.environ["SLURM_LOCALID"])
    world_size = int(os.environ["SLURM_NTASKS"])
    hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])
    n_nodes = len(hostnames)
    print(NODE_ID, rank, local_rank, world_size, hostnames)

    # get the IDs of the reserved GPUs
    gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",")
    print(f"GPU IDS: {gpu_ids}")
```

When I run this script normally through Slurm (without a sweep), `gpu_ids` is populated and everything works. However, when I launch the sweep with the batch script above, I get a `KeyError`.
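To see which GPU-related variables actually reach the Python process in each case, a small diagnostic like the one below could be run at the top of main.py (or as a separate script started by the sweep). This is a sketch: the helper name `report_gpu_env` is hypothetical, and the variable list is an assumption. As far as the Slurm documentation describes, `SLURM_STEP_GPUS` belongs to a job step (for example one started with `srun`), while `SLURM_JOB_GPUS` and `CUDA_VISIBLE_DEVICES` describe the job's allocation, so the output may differ between the two launch paths.

```python
import os

# Assumed list of variables worth checking; which ones are set may depend on
# whether the process runs inside an srun job step or directly in the batch shell.
GPU_ENV_VARS = (
    "SLURM_STEP_GPUS",
    "SLURM_JOB_GPUS",
    "CUDA_VISIBLE_DEVICES",
    "SLURM_PROCID",
    "SLURM_LOCALID",
    "SLURM_NTASKS",
)


def report_gpu_env(environ=os.environ):
    """Return {variable name: value, or None if unset} for each variable of interest."""
    return {name: environ.get(name) for name in GPU_ENV_VARS}


if __name__ == "__main__":
    # print one line per variable so a normal run and a sweep run can be diffed
    for name, value in report_gpu_env().items():
        print(f"{name}={value if value is not None else '<unset>'}")
```

Running it once in a normal job and once under `wandb agent` and comparing the two logs should show which variable is missing in the sweep case.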


How can I access the IDs of the requested GPUs when running a sweep?
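A minimal sketch of one possible workaround, assuming the process started by the sweep agent still inherits the batch job's environment: instead of indexing `SLURM_STEP_GPUS` directly, read the IDs through a fallback chain. The helper name `get_gpu_ids` and the order of the variables are assumptions, not part of Slurm or wandb. One caveat: when Slurm constrains devices with cgroups, `CUDA_VISIBLE_DEVICES` may hold process-relative indices (0, 1, ...) rather than the global IDs found in `SLURM_JOB_GPUS`, so the sources are not always interchangeable.

```python
import os


def get_gpu_ids(environ=os.environ):
    """Return the reserved GPU IDs as a list of strings.

    Tries the step-level variable first, then the job-level one, then
    CUDA_VISIBLE_DEVICES. Raises RuntimeError if none of them is set.
    """
    for name in ("SLURM_STEP_GPUS", "SLURM_JOB_GPUS", "CUDA_VISIBLE_DEVICES"):
        value = environ.get(name, "").strip()
        if value:
            # e.g. "0,1" -> ["0", "1"]; skip empty entries from stray commas
            return [gpu.strip() for gpu in value.split(",") if gpu.strip()]
    raise RuntimeError(
        "No GPU IDs found in SLURM_STEP_GPUS, SLURM_JOB_GPUS or CUDA_VISIBLE_DEVICES"
    )
```

In main.py, the line `gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",")` would then become `gpu_ids = get_gpu_ids()`, which fails with an explicit message instead of a bare `KeyError` if no allocation information is visible at all.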