Getting KeyError: tensor([0]) while plotting wandb's confusion matrix

Hi,
I am trying to plot a confusion matrix using wandb’s API, but I am getting this error:

  File "/home/ubuntu/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/wandb/plot/confusion_matrix.py", line 72, in confusion_matrix
    counts[class_mapping[y_true[i]], class_mapping[preds[i]]] += 1
KeyError: tensor([0])

My validation loop looks like this:

for batch_idx, (data, target) in enumerate(loader['valid']):
    output = model(data)
    preds = torch.max(output, dim=1, keepdim=True)[1]
    wandb.log({"conf_mat": wandb.plot.confusion_matrix(y_true=target, preds=preds,  # noqa
                                                       class_names=class_names)})
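The key in the error is `tensor([0])`, not the integer `0`. In the loop above, `torch.max(..., keepdim=True)[1]` returns indices of shape `(N, 1)`, so `preds[i]` is a one-element tensor, and `target[i]` is a tensor as well. PyTorch tensors hash by object identity, so a tensor element never matches a plain integer dictionary key, which is what the counting line in the traceback looks up. A minimal sketch, assuming you pass plain Python lists instead: the helper `to_label_list` below is hypothetical (not part of wandb or PyTorch) and flattens anything with a `.tolist()` method (a tensor or NumPy array) or a nested list into a flat list of ints. With PyTorch you can often get the same result directly with `output.argmax(dim=1).tolist()` and `target.tolist()`.

```python
def to_label_list(values):
    """Hypothetical helper: flatten labels into a list of plain Python ints.

    Accepts a torch tensor or NumPy array (anything with .tolist()),
    a scalar, or a (possibly nested) list/tuple, e.g. shape (N,) or (N, 1).
    """
    # Tensors and arrays convert to (nested) Python lists of numbers.
    if hasattr(values, "tolist"):
        values = values.tolist()
    # A 0-d tensor/array, or a bare number, becomes a one-item list.
    if not isinstance(values, (list, tuple)):
        return [int(values)]
    flat = []
    for v in values:
        # Recurse so that [[0], [2]] (shape (N, 1)) becomes [0, 2].
        flat.extend(to_label_list(v))
    return flat


print(to_label_list([[0], [2], [1]]))  # shape (N, 1) style input
```

In the validation loop you would extend two running lists with `to_label_list(target)` and `to_label_list(preds)` for each batch, then call `wandb.plot.confusion_matrix(y_true=..., preds=..., class_names=class_names)` once after the loop, so one chart covers the whole validation set rather than one chart per batch.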

Hi @rishav, this happens when the label values you pass (in preds or y_true) are not valid indices into class_names, that is, plain integers from 0 to len(class_names) - 1. For example:

The following would raise KeyError: 5, because class_names has five entries, so its valid indices are 0 to 4 and there is no index 5:

preds = [1, 2, 3, 4, 5]
class_names = ["one", "two", "three", "four", "five"]

The labels should instead be zero-based:

preds = [0, 1, 2, 3, 4]
class_names = ["one", "two", "three", "four", "five"]
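To see why the labels must be zero-based, here is a simplified re-creation of the counting step from the traceback (`counts[class_mapping[y_true[i]], class_mapping[preds[i]]] += 1`). It is a sketch, not wandb's actual source: it assumes `class_mapping` has exactly one key per class index, `0` to `len(class_names) - 1`. The function name `count_confusions` is hypothetical.

```python
def count_confusions(y_true, preds, class_names):
    """Simplified sketch of the counting loop shown in the traceback.

    Assumption: class_mapping holds one int key per class index, 0..n-1,
    so any label outside that range raises KeyError.
    """
    n = len(class_names)
    class_mapping = {i: i for i in range(n)}
    counts = [[0] * n for _ in range(n)]  # rows: true class, cols: predicted
    for t, p in zip(y_true, preds):
        counts[class_mapping[t]][class_mapping[p]] += 1
    return counts


class_names = ["one", "two", "three", "four", "five"]

# One-based labels: 5 is not a key, so this raises KeyError: 5.
try:
    count_confusions([1, 2, 3, 4, 5], [1, 2, 3, 4, 5], class_names)
except KeyError as err:
    print("KeyError:", err)

# Zero-based labels work; perfect predictions give an identity matrix.
print(count_confusions([0, 1, 2, 3, 4], [0, 1, 2, 3, 4], class_names))
```

If your labels are one-based, shifting them with `[p - 1 for p in preds]` (and the same for y_true) before calling `wandb.plot.confusion_matrix` lines them up with class_names.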

Hi @rishav, since we have not heard back from you, we are going to close this request. If you would like to reopen the conversation, please let us know!