Error on torch.load for a run's saved file

Hi,
I want to load a .pt file from a run that I downloaded with the WandB API, but torch.load raises this error:
'utf-8' codec can't decode byte 0xaa in position 4: invalid start byte

My code is:

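import torch
import wandb

api = wandb.Api()
runs = api.runs('USERNAME/PROJ')
# Download the second file of the first run (the .pt checkpoint)
model_path = list(list(runs)[0].files())[1].download()
model = torch.load(model_path)  # raises the UnicodeDecodeError above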

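Hi @sadra-barikbin, it sounds like an encoding/decoding issue to me: in the public API, File.download() returns a file object opened in text mode rather than a path, so torch.load ends up trying to decode the binary checkpoint as UTF-8. Can you try reopening the file in binary mode, like this: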

import io
import torch

# download() returns a text-mode file object; reopen its path in binary mode
with open(model_path.name, 'rb') as f:
    buffer = io.BytesIO(f.read())
model = torch.load(buffer)
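If you are not sure whether model_path is a plain path or an open file object (and in which mode it was opened), a small helper can normalise both into an in-memory binary buffer before calling torch.load. This is a minimal sketch: to_binary_buffer is a hypothetical name, not part of wandb or PyTorch, and it assumes a file object exposes its on-disk path via .name, as objects returned by the built-in open() do.

```python
import io
import os


def to_binary_buffer(source):
    """Return an io.BytesIO holding the raw bytes of `source` (hypothetical helper).

    `source` may be a path (str, bytes or os.PathLike) or a file object
    returned by the built-in open(), in either text or binary mode.
    """
    if isinstance(source, (str, bytes, os.PathLike)):
        path = source
    else:
        # Assumption: the handle came from open(), so .name is its path.
        # Text-mode handles decode on read, so reopen that path in binary mode.
        path = source.name
    with open(path, "rb") as f:
        return io.BytesIO(f.read())


# Usage (assumes torch is installed and model_path came from .download()):
# model = torch.load(to_binary_buffer(model_path))
```

Reading the whole file into memory keeps the helper simple; for a very large checkpoint, passing the binary-mode path straight to torch.load avoids the extra copy.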

If that doesn’t work, check out the PyTorch docs for torch.load to see some other possible fixes.
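To see why the file mode matters, here is a sketch that needs neither W&B nor PyTorch. It uses pickle output as a stand-in for a checkpoint (torch.save builds on pickle), then reads it once through a UTF-8 text wrapper, which is what a text-mode handle does on read(), and once as raw bytes. The function names and the sample payload are illustrative only.

```python
import io
import pickle

# Stand-in "checkpoint": pickle output is raw binary, not UTF-8 text.
checkpoint_bytes = pickle.dumps({"step": 10, "weights": [0.5, -1.25]})


def read_via_text_mode(data: bytes) -> str:
    # Like open(path, "r"): bytes are decoded as UTF-8 while reading,
    # which fails on bytes that are not valid UTF-8.
    return io.TextIOWrapper(io.BytesIO(data), encoding="utf-8").read()


def load_via_binary_mode(data: bytes):
    # Like open(path, "rb"): the bytes reach the unpickler untouched.
    return pickle.load(io.BytesIO(data))


try:
    read_via_text_mode(checkpoint_bytes)
except UnicodeDecodeError as err:
    print("text mode failed:", err)

print("binary mode loaded:", load_via_binary_mode(checkpoint_bytes))
```

The text-mode read fails with the same kind of "'utf-8' codec can't decode byte ..." error (the exact byte and position depend on the file), while the binary-mode load succeeds.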
