Hi wandb community. I wrote some code to parallelize my wandb sweeps, because the model I am working with takes a long time to converge and I have a lot of subprocesses to sweep through. Basically, I don't have the luxury of time right now. Here's a generalized snippet of my code:
import concurrent.futures
import json
import os
from functools import partial

import wandb


def run_pipeline(args):
    # Stuff happens here
    # Wandb init
    group = "within_session" if session_config["within_session"] else "across_session"
    run = wandb.init(name=f"{sessions[i]}_{group}_decoder_run", group=group, config=sweep_config, reinit=True)
    # Model training
    return results


def run_pipeline_wrapper(args):
    # Stuff happens here
    run_pipeline(args)
    return None


if __name__ == "__main__":
    total_runs = 30
    agents = 5
    runs_per_agent = total_runs // agents
    sweep_config = {'method': 'random'}
    parameters_dict = {
        # Lots of parameters to sweep
    }
    sweep_config['parameters'] = parameters_dict
    # Create a JSON file that caches sweep ids
    sweep_id_json_path = 'sweep_id.json'
    if not os.path.exists(sweep_id_json_path):
        with open(sweep_id_json_path, 'w') as f:
            json.dump({}, f)
    with open(sweep_id_json_path, 'r') as f:
        sweep_id_json = json.load(f)
    # sessions_list holds the unique datasets that I need to run my sweeps on
    for i in range(len(sessions_list)):
        # Preparing a partial method to pass
        run_pipeline_with_args = partial(run_pipeline_wrapper, args)
        # I cache the existing sweep_ids in a json file to help in attaching sweep ids if I rerun the code again
        if f"{sessions_list[i]}_{is_within}" not in sweep_id_json:
            sweep_id = wandb.sweep(sweep_config, project=f"HPC_model_{sess}_session_{data}_{data_type}")
        else:
            sweep_id = wandb.sweep(sweep_config, project=f"HPC_model_{sess}_session_{data}_{data_type}",
                                   prior_runs=sweep_id_json[f"{sessions_list[i]}_{is_within}"])
        # This is the parallelization logic, where I parallelize the sweeps
        with concurrent.futures.ThreadPoolExecutor(max_workers=agents) as executor:
            futures = [
                executor.submit(wandb.agent, sweep_id, run_pipeline_with_args, count=runs_per_agent)
                for _ in range(agents)
            ]
            concurrent.futures.wait(futures)
When I run this code, it gets stuck on wandb.init(), and that process is eventually terminated due to a timeout. I don't think increasing wandb's timeout would fix this. How do I fix it? Could my parallelization logic be the problem? If so, how do you devs parallelize your wandb sweeps in-code?
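One variant that may be worth comparing against the thread pool is to give each agent its own process, so that no two wandb.init() calls share interpreter state. The following is only a minimal sketch under that assumption, not a confirmed fix: split_runs, agent_worker, launch_agents and the inner train function are hypothetical names, and train stands in for run_pipeline_wrapper(args). The only wandb calls used are wandb.agent, wandb.init and run.finish.

```python
import multiprocessing as mp


def split_runs(total_runs, agents):
    """Spread total_runs over agents so integer division drops no runs."""
    base, extra = divmod(total_runs, agents)
    return [base + (1 if k < extra else 0) for k in range(agents)]


def agent_worker(sweep_id, project, count):
    # Import inside the child so each process sets up its own wandb state.
    import wandb

    def train():
        # Hypothetical stand-in for run_pipeline_wrapper(args).
        run = wandb.init()
        # ... model training would go here ...
        run.finish()

    wandb.agent(sweep_id, function=train, project=project, count=count)


def launch_agents(sweep_id, project, total_runs=30, agents=5):
    # "spawn" starts each child with a fresh interpreter (an assumption that
    # this avoids the shared state the threaded version may be tripping over).
    ctx = mp.get_context("spawn")
    procs = [
        ctx.Process(target=agent_worker, args=(sweep_id, project, count))
        for count in split_runs(total_runs, agents)
        if count > 0
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    # A non-zero exit code marks an agent process that failed.
    return [p.exitcode for p in procs]
```

In the snippet above, launch_agents(sweep_id, project) would replace the ThreadPoolExecutor block inside the session loop, and it has to stay under the if __name__ == "__main__": guard because spawned children re-import the main module.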
Attached logs: