rishav
August 22, 2022, 4:46pm
Hi,
I am trying to plot a confusion matrix using wandb's API, but I am getting the following error:
  File "/home/ubuntu/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/wandb/plot/confusion_matrix.py", line 72, in confusion_matrix
    counts[class_mapping[y_true[i]], class_mapping[preds[i]]] += 1
KeyError: tensor([0])
My validation loop looks like this:
for batch_idx, (data, target) in enumerate(loader['valid']):
    output = model(data)
    preds = torch.max(output, dim=1, keepdim=True)[1]
    wandb.log({"conf_mat": wandb.plot.confusion_matrix(y_true=target, preds=preds,  # noqa
                                                       class_names=class_names)})
Hi @rishav, this happens when a value in your prediction array is not a valid index into class_names. For example, the following would produce a KeyError: 5, because class_names has five entries and its valid indices run from 0 to 4, so there is no index 5:
preds = [1,2,3,4,5]
class_names = ["one","two","three","four","five"]
The predictions should instead be zero-based:
preds = [0,1,2,3,4]
class_names = ["one","two","three","four","five"]
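Note also that the key in the traceback is tensor([0]) rather than a plain integer, which suggests the labels are reaching the plot function as PyTorch tensors. Tensors hash by object identity, not by value, so tensor([0]) will not match the integer key 0 even when the index itself is valid. Below is a minimal sketch of one way to guard against both problems before plotting. The helper names to_int_labels and check_labels are hypothetical (they are not part of wandb), and the sketch assumes your labels are either nested lists or objects that expose .tolist(), as torch tensors and NumPy arrays do.

```python
from typing import Any, List, Sequence


def to_int_labels(values: Any) -> List[int]:
    """Flatten labels into a plain list of Python ints.

    Hypothetical helper: anything exposing .tolist() (for example a torch
    tensor or NumPy array) is converted first, then nested lists such as
    [[0], [2]] (the shape produced by keepdim=True) are flattened.
    """
    if hasattr(values, "tolist"):
        values = values.tolist()
    if isinstance(values, (list, tuple)):
        flat: List[int] = []
        for item in values:
            flat.extend(to_int_labels(item))
        return flat
    return [int(values)]


def check_labels(labels: Sequence[int], class_names: Sequence[str]) -> None:
    """Raise ValueError if any label is not a valid index into class_names."""
    n = len(class_names)
    bad = sorted({v for v in labels if not 0 <= v < n})
    if bad:
        raise ValueError(f"labels {bad} are outside the valid range 0..{n - 1}")


class_names = ["one", "two", "three", "four", "five"]

# Nested list standing in for a (N, 1) prediction tensor after .tolist().
preds = to_int_labels([[0], [1], [4]])
check_labels(preds, class_names)  # passes silently: all indices are in 0..4
print(preds)
```

In a validation loop like the one above, you could extend two running lists with to_int_labels(target) and to_int_labels(preds) for each batch, then call wandb.plot.confusion_matrix once after the loop with those flat lists, so the plot covers the whole validation set instead of a single batch.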
Hi @rishav, since we have not heard back from you, we are going to close this request. If you would like to re-open the conversation, please let us know!