I am using wandb to upload model checkpoints for my experiments.
When I check the logs, it looks like the files have been uploaded:
2024-02-11 00:46:58,557 INFO SenderThread:3056091 [dir_watcher.py:finish():402] scan save: /scratch/ktes/code/jaxmarl-transfer/wandb/run-20240210_154852-4op00c4x/files/models/multiwalker_v9__ippo_three_walkers_baselines_shared_no_agent_id__30__1707580131/ippo_three_walkers_baselines_shared_no_agent_id_15000576_steps_14648_updates.agent_30_seed/_METADATA models/multiwalker_v9__ippo_three_walkers_baselines_shared_no_agent_id__30__1707580131/ippo_three_walkers_baselines_shared_no_agent_id_15000576_steps_14648_updates.agent_30_seed/_METADATA
...
2024-02-11 00:46:59,818 INFO wandb-upload_2:3056091 [upload_job.py:push():131] Uploaded file /scratch/ktes/code/jaxmarl-transfer/wandb/run-20240210_154852-4op00c4x/files/models/multiwalker_v9__ippo_three_walkers_baselines_shared_no_agent_id__30__1707580131/ippo_three_walkers_baselines_shared_no_agent_id_19999744_steps_19530_updates.agent_30_seed/checkpoints
But when I check the UI, these files are not present anywhere. I believe they should be here: Weights & Biases. Where can I find these files?
Hi @kaleabtessera , the log indicates that wandb is successfully creating a tracked run job under an artifact, and not a checkpoint. How are you saving your checkpoints within your code?
Hi @mohammadbakir ,
I save my checkpoints like this:
from orbax.checkpoint import checkpointer
from orbax.checkpoint.pytree_checkpoint_handler import PyTreeCheckpointHandler
from flax.training import orbax_utils
checkpointers = {}
for agent in range(num_agents):
    agent_key = f"agent_{agent}"
    checkpointers[agent_key] = checkpointer.Checkpointer(
        PyTreeCheckpointHandler(aggregate_filename=f"checkpoints")
    )
wb_run = wandb.init(
    entity=config["ENTITY"],
    project=config["PROJECT"],
    tags=["IPPO", "FF", "NO-PS"],
    config=config,
    mode=config["WANDB_MODE"],
    name=run_name,
    save_code=True,
    reinit=True,
    group=group_name,
)
...
chp_dir = f"{wandb.run.dir}/models/{config['RUN_NAME']}"
# Should we save the model
if final_update or (next_checkpoint_step and global_step >= next_checkpoint_step):
    for agent in range(num_agents):
        agent_key = f"agent_{agent}"
        agent_identity = f"{agent_key}_{config['SEED']}_seed"
        # Save at intervals
        model_path = f"{chp_dir}/{config['EXP_NAME']}_{global_step}_steps_{update}_updates.{agent_identity}"
        save_args = orbax_utils.save_args_from_target(runner_state[0][agent].params)
        checkpointers[agent_key].save(model_path, runner_state[0][agent].params, save_args=save_args)
        print(f"model saved to {model_path} at step {global_step}")
        if config["WANDB_MODE"] == "online":
            # os.path.join since model_path is a folder
            wb_run.save(  # type: ignore
                f"{model_path}/agent_checkpoints",
                base_path=model_path,
                policy="now",
            )
Is there a better way to do this in jax using orbax & wandb? The weird thing is that sometimes I can see it in Files/models, but sometimes not.
Hi @kaleabtessera , I verified that the save functionality works as intended. The only time a file would not be saved is if {model_path}/agent_checkpoints doesn't exist or if incorrect paths are being passed to the save function. Could you try running a toy example on your end? Create some files, call save on them, and see whether you can get this to work. I've looked internally and nothing has been flagged regarding failures to save files in the specified environments.
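A minimal sketch of such a toy check, using only the standard library so that it runs without a wandb login. The folder layout mimics the log above (one folder per checkpoint containing `_METADATA` and `checkpoints`); the helper `matching_files`, the `toy_*` names, and the `*` pattern are assumptions made for illustration and are not part of the original code. It only shows which paths exist on disk for a given pattern; it does not establish why the uploads are intermittent.

```python
import glob
import os
import tempfile


def matching_files(glob_str):
    """Return the files on disk that match glob_str, sorted."""
    return sorted(p for p in glob.glob(glob_str) if os.path.isfile(p))


# Build a toy checkpoint folder that mimics the layout seen in the log:
# one folder per checkpoint containing `_METADATA` and `checkpoints`.
toy_root = tempfile.mkdtemp()
model_path = os.path.join(toy_root, "models", "toy_run", "toy_100_steps.agent_0_seed")
os.makedirs(model_path)
for filename in ("_METADATA", "checkpoints"):
    with open(os.path.join(model_path, filename), "w") as f:
        f.write("placeholder")

# The file name passed to save in the post, versus a pattern for the whole folder.
pattern_from_post = os.path.join(model_path, "agent_checkpoints")
pattern_whole_dir = os.path.join(model_path, "*")

print(matching_files(pattern_from_post))  # nothing in the toy folder has this name
print(matching_files(pattern_whole_dir))  # the two placeholder files

# With wandb installed and a run active, the upload step could then be, e.g.:
#   wandb.save(pattern_whole_dir, base_path=toy_root, policy="now")
```

If the first list is empty for the real `model_path` as well, the path passed to save is worth double-checking against the `aggregate_filename` given to the checkpointer.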
Hi @kaleabtessera , wanted to check in to see if running a toy example worked and/or if you are still running into issues.
Hi @mohammadbakir , I think the path is being set correctly, since this is an intermittent issue: sometimes it logs, and other times it doesn't. I think it could be related to multi-processing, and possibly also to this issue: "OSError Too many open files: /tmp/tmphv67gzd0wandb-media · Issue #2825 · wandb/wandb · GitHub", since I have also intermittently seen the error OSError: [Errno 24] Too many open files.
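For the Errno 24 part, one generic thing to check is the per-process limit on open file descriptors. The sketch below is an assumption on my side, not something confirmed in this thread or in the linked issue: it uses Python's standard `resource` module (Unix only) to read the limit and raise the soft limit toward a hypothetical `target`, capped by the hard limit. It does not address whatever is opening the files in the first place.

```python
import resource


def raise_open_file_limit(target=4096):
    """Raise the soft RLIMIT_NOFILE toward target, capped by the hard limit.

    Returns the (soft, hard) limits in effect afterwards.
    """
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if hard != resource.RLIM_INFINITY:
        target = min(target, hard)
    # Nothing to do if the soft limit is already unlimited or high enough.
    if soft == resource.RLIM_INFINITY or soft >= target:
        return soft, hard
    try:
        resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
    except (ValueError, OSError):
        # The platform refused the change; report the limits still in effect.
        pass
    return resource.getrlimit(resource.RLIMIT_NOFILE)


soft, hard = raise_open_file_limit()
print(f"open-file limits: soft={soft}, hard={hard}")
```

Calling this once near the start of the training script, before `wandb.init`, would show whether the soft limit is unusually low on the affected machine.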
I have switched to logging checkpoints only once at the end, and it seems to work consistently:
artifact = wandb.Artifact(name=f'checkpoint_{config["RUN_NAME"]}', type='checkpoint')
artifact.add_dir(local_path=config["CHP_DIR"])
run.log_artifact(artifact)
Is this the recommended way to save Jax model checkpoints?
Hi @kaleabtessera , there might be a multi-processing component involved, especially if different processes are attempting to access or upload the same file. Also, yes, the recommended method is to use artifacts to log your checkpoints. You can log all checkpoints under a single artifact or under different artifacts; it is the user's choice.
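The two options could be sketched as below, building on the `wandb.Artifact` / `add_dir` / `log_artifact` calls from the previous post. The helper names `artifact_name` and `log_checkpoint_dir` and the `step` parameter are hypothetical. The name sanitising assumes that artifact names should contain only alphanumeric characters, dashes, underscores, and dots.

```python
import re


def artifact_name(run_name, step=None):
    """Build an artifact name, replacing characters outside [A-Za-z0-9_.-]."""
    if step is None:
        base = f"checkpoint_{run_name}"
    else:
        base = f"checkpoint_{run_name}_step_{step}"
    return re.sub(r"[^A-Za-z0-9_.\-]", "_", base)


def log_checkpoint_dir(run, checkpoint_dir, run_name, step=None):
    """Log a checkpoint directory as a wandb artifact of type 'checkpoint'."""
    import wandb  # imported here so the naming helper works without wandb

    artifact = wandb.Artifact(name=artifact_name(run_name, step), type="checkpoint")
    artifact.add_dir(local_path=checkpoint_dir)
    return run.log_artifact(artifact)


print(artifact_name("my run/30"))      # one artifact for the whole run
print(artifact_name("my run/30", 10))  # one artifact per checkpoint step
```

Calling `log_checkpoint_dir` with `step=None` each time reuses a single artifact name, so changed contents should show up as new versions of that artifact; passing a `step` gives a separate artifact per checkpoint.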
If you would like us to investigate the multiprocessing component and its possible impact, could you provide the debug.log and debug-internal.log files, located in the wandb working directory under wandb/<run-path>/logs, for the affected runs? This will help us identify any errors. If so, feel free to send them to support@wandb.com and reference this community post.
@kaleabtessera, since we have not heard back from you we are going to close this request. If you would like to re-open the conversation, please let us know!