Thanks Ramit!
Here's how everything currently works. In SageMaker, I define my TensorFlow estimator as follows:
tf_estimator = TensorFlow(
    ...
    checkpoint_s3_uri=f's3://ds-models-1/{model_name}/{run_name}',
    checkpoint_local_path='/opt/ml/checkpoints/')
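For context, the elided arguments are just the usual entry-point and instance settings. A fuller sketch of the estimator looks roughly like this (the entry point, role, instance settings, and framework versions below are placeholders, not my exact values):

from sagemaker.tensorflow import TensorFlow

# Illustrative values only -- entry_point, role, instance settings, and
# framework versions are placeholders, not the exact ones I use.
tf_estimator = TensorFlow(
    entry_point='train.py',
    role=sagemaker_role,
    instance_count=1,
    instance_type='ml.p3.2xlarge',
    framework_version='2.8',
    py_version='py39',
    # the two arguments that matter for checkpointing:
    checkpoint_s3_uri=f's3://ds-models-1/{model_name}/{run_name}',
    checkpoint_local_path='/opt/ml/checkpoints/')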
Then, in my training script, I create the ModelCheckpoint callback:
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='/opt/ml/checkpoints/',
    save_weights_only=False,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)
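The callback then just gets passed to model.fit() in the usual way (train_ds and val_ds here are stand-ins for my actual data pipelines):

model.fit(
    train_ds,
    validation_data=val_ds,  # required so 'val_accuracy' exists to monitor
    epochs=epochs,
    callbacks=[model_checkpoint_callback])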
So the model checkpoints are saved to '/opt/ml/checkpoints/', and SageMaker watches that directory and syncs its contents to my S3 bucket.
After training completes in the script, I create my W&B artifact:
art = wandb.Artifact(job_name, type="model")
art.add_reference(f's3://ds-models-1/{model_name}/{run_name}')
wandb.log_artifact(art)
All of that works well: I can see my model uploaded to S3, and it's listed as an artifact in W&B.
When I then go into my notebook to try to download the model, here’s what I get:
run = wandb.init(name="download_model", project=project, job_type="testing")
model_path = run.use_artifact(wandb_model_name, type='model')
model_path.download()
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_11366/3182984502.py in <cell line: 4>()
      2
      3 model_path = run.use_artifact('distributedspectrum/RadioML-Experimentation/RadioML-1d-conv-training-run-1:v0', type='model')
----> 4 model_path.download()

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/wandb/apis/public.py in download(self, root, recursive)
     4521
     4522         pool = multiprocessing.dummy.Pool(32)
---> 4523         pool.map(
     4524             partial(self._download_file, root=dirpath, download_logger=download_logger),
     4525             manifest.entries,

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/multiprocessing/pool.py in map(self, func, iterable, chunksize)
      362         in a list that is returned.
      363         '''
--->  364         return self._map_async(func, iterable, mapstar, chunksize).get()
      365
      366     def starmap(self, func, iterable, chunksize=None):

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/multiprocessing/pool.py in get(self, timeout)
      769             return self._value
      770         else:
--->  771             raise self._value
      772
      773     def _set(self, i, obj):

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/multiprocessing/pool.py in worker(inqueue, outqueue, initializer, initargs, maxtasks, wrap_exception)
      123         job, i, func, args, kwds = task
      124         try:
--->  125             result = (True, func(*args, **kwds))
      126         except Exception as e:
      127             if wrap_exception and func is not _helper_reraises_exception:

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/multiprocessing/pool.py in mapstar(args)
       46
       47 def mapstar(args):
--->   48     return list(map(*args))
       49
       50 def starmapstar(args):

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/wandb/apis/public.py in _download_file(self, name, root, download_logger)
     4620         ):
     4621             # download file into cache and copy to target dir
---> 4622             downloaded_path = self.get_path(name).download(root)
     4623             if download_logger is not None:
     4624                 download_logger.notify_downloaded()

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/wandb/apis/public.py in download(self, root)
     3943         manifest = self._parent_artifact._load_manifest()
     3944         if self.entry.ref is not None:
---> 3945             cache_path = manifest.storage_policy.load_reference(
     3946                 self._parent_artifact,
     3947                 self.name,

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/wandb/sdk/wandb_artifacts.py in load_reference(self, artifact, name, manifest_entry, local)
      962         local: bool = False,
      963     ) -> str:
--->  964         return self._handler.load_path(artifact, manifest_entry, local)
      965
      966     def _file_url(

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/wandb/sdk/wandb_artifacts.py in load_path(self, artifact, manifest_entry, local)
     1120                 'No storage handler registered for scheme "%s"' % str(url.scheme)
     1121             )
---> 1122         return self._handlers[str(url.scheme)].load_path(
     1123             artifact, manifest_entry, local=local
     1124         )

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/wandb/sdk/wandb_artifacts.py in load_path(self, artifact, manifest_entry, local)
     1430             )
     1431         else:
---> 1432             raise ValueError(
     1433                 "Digest mismatch for object %s: expected %s but found %s"
     1434                 % (manifest_entry.ref, manifest_entry.digest, etag)
ValueError: Digest mismatch for object s3://ds-models-1/RadioML-1d-conv/training-run-1/saved_model.pb: expected bd308623d5d8db45d883aa98e570147d but found 584fe63cdb5cff0663c058e57527dded
Sometimes it complains about the "saved_model.pb" file, and sometimes it's one of the Keras variable files.
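As I understand it, add_reference records each object's current S3 ETag as its digest in the artifact manifest, so one way to dig into the mismatch (a quick sketch using boto3, assuming read access to the bucket; the bucket and key are taken from the error message above) is to compare that "expected" digest against what S3 reports for the object now:

import boto3

s3 = boto3.client('s3')
head = s3.head_object(
    Bucket='ds-models-1',
    Key='RadioML-1d-conv/training-run-1/saved_model.pb')
# Compare against the "expected" value from the ValueError
print(head['ETag'].strip('"'))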