Best practice to efficiently log GPU PyTorch tensors to wandb?

Hi there,

I am using reinforcement learning and have quite a complicated training procedure. To make sure everything is working properly, it is important to me to log as much as I can to my wandb dashboard. However, most of these quantities are PyTorch tensors on the GPU, and the way I am logging them seems quite inefficient.

My current logging setup looks something like this:

batch_d = dict()
batch_d['logp'] = logp.detach().cpu().tolist()
batch_d['loss'] = loss.detach().cpu().tolist()
batch_d['loss_sum'] = loss_sum.detach().cpu().tolist()
batch_d['loss_batch'] = loss_batch.detach().cpu().tolist()
# ... ~10 other similar things tracked here too
# ... convert quantities to wandb Histograms and similar
wandb.log(batch_d)

However, all this detaching and moving to the CPU slows down training (e.g. as mentioned by this article on PyTorch efficiency). Hence I was wondering whether there is a better way to log all these quantities.

Thanks,

Tom

Any ideas would be much appreciated.

To be honest, I’ve found that real-time batch training metrics are a bit overrated, and sometimes it’s better to just train faster. What I personally do is create an entirely separate “training_eval” workflow (identical to running validation with no gradients, but executed on the training dataset). This might get you the speedup you need, depending on where your bottleneck is, and you can always set this workflow to run only every few epochs (e.g. every 5) so that its cost is paid less often.

Of course, you also lose a bit of logging, but in my view training metrics are only a supplement to compare against validation metrics. Sometimes it’s better to train faster than to log everything.
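
As a rough sketch of that schedule (not the poster's actual code): every name below (run_training, train_one_epoch, training_eval, log_fn, eval_every) is hypothetical. The assumption is that training_eval would run the model over the training set inside torch.no_grad() and return a dict of plain Python numbers, so the GPU-to-CPU transfers happen once per evaluation rather than once per batch.

```python
def run_training(num_epochs, train_one_epoch, training_eval, log_fn, eval_every=5):
    """Train every epoch, but compute and log training metrics only occasionally.

    train_one_epoch(epoch): runs the optimisation steps and logs nothing.
    training_eval(epoch):   returns a dict of metrics (assumed to be computed
                            without gradients, on the training data).
    log_fn(metrics):        e.g. wandb.log.
    """
    for epoch in range(1, num_epochs + 1):
        train_one_epoch(epoch)
        # Pay the logging cost only on every eval_every-th epoch.
        if epoch % eval_every == 0:
            log_fn(training_eval(epoch))
```

With eval_every=5 and 10 epochs, the metrics pass runs twice (after epochs 5 and 10) instead of on every batch.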

Even something as simple as print(some_tensor_on_gpu) has to copy the tensor from GPU to CPU, so that transfer overhead is unavoidable when logging.

Off the top of my head, the only thing I can think of is to move all the detach(), cpu(), numpy(), tolist(), etc. calls, followed by wandb.log, into a thread that runs asynchronously in the background and so does not block your main training code.

Thanks guys for your replies.

@dealer56 I like the async thread idea, but I’ve never done anything like this. Any ideas for a good starting place? I’m wondering if it’s something that PyTorch supports easily.
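
One possible starting point for the background-thread idea, using only Python's standard threading and queue modules rather than anything PyTorch-specific. This is a minimal sketch under assumptions, not a tested recipe: AsyncLogger and to_host are hypothetical names, log_fn stands in for wandb.log, and it assumes log_fn is safe to call from a non-main thread (I have not verified that for wandb). It also assumes the queued tensors are not modified in place afterwards, since detach() shares storage with the original tensor.

```python
import queue
import threading


def to_host(value):
    """Convert a tensor-like value to plain Python; pass anything else through."""
    if hasattr(value, "detach"):
        return value.detach().cpu().tolist()
    return value


class AsyncLogger:
    """Hypothetical helper: does the conversion and the log call on a worker thread."""

    def __init__(self, log_fn, max_pending=64):
        self._log_fn = log_fn
        # Bounded queue: log() blocks if the worker falls too far behind,
        # so pending tensors cannot pile up without limit.
        self._queue = queue.Queue(maxsize=max_pending)
        self._thread = threading.Thread(target=self._worker, daemon=True)
        self._thread.start()

    def log(self, metrics):
        # Only a shallow dict copy happens on the training thread.
        self._queue.put(dict(metrics))

    def _worker(self):
        while True:
            item = self._queue.get()
            if item is None:  # sentinel sent by close()
                break
            # The detach/cpu/tolist calls and the log call run here.
            self._log_fn({key: to_host(val) for key, val in item.items()})

    def close(self):
        """Flush everything still queued, then stop the worker."""
        self._queue.put(None)
        self._thread.join()
```

Usage would look something like logger = AsyncLogger(wandb.log) once at startup, logger.log({'logp': logp, 'loss': loss}) inside the training loop with the raw GPU tensors, and logger.close() at the end of training so nothing queued is lost. The device-to-host copy still has to happen, so this hides the wait from the training loop rather than removing it.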