Class labels not displaying for PR Curve or ROC Curve

Hi W&B Team

I’ve managed to implement the PR curve and ROC curve for a multiclass classifier and have them displaying nicely in my dashboard. The only slight problem is that I can’t get the class names to display in the plots.

I’m retrieving the class names from the PyTorch dataset, which returns them as a list of strings:

class_names = dataset.classes

['Akashiwo', 'Alexandrium', 'Amoeba', 'Amphidinium', 'Apedinella', … ]

I then pass this to the PR curve function:

wandb.log({"pr": wandb.plot.pr_curve(all_test_labels.cpu(), all_test_output.cpu(), labels=class_names)})

However, when the PR curve is rendered on the dashboard, it still displays a list of indices instead of the corresponding class names.

Still learning as I go so hoping there is a very simple fix for this.

Many thanks


Hey @gg_sams - which wandb SDK version are you running? I recommend taking a look at this Colab and making sure your inputs match the types specified in this line (it seems to use the same classes as you do):

    wandb.log({"pr_curve" : wandb.plot.pr_curve(ground_truth_class_ids, 

please let me know how this goes!

Hi Uma

Thanks for the link to the Colab. I compared the class names list from the Colab notebook with the one in my own code, and they were exactly the same datatype (a Python list of strings). However, I noticed that the values in my labels tensor were printed with a decimal point after the index, which meant the tensor contained floats rather than ints. When creating the tensor I hadn’t specified a data type explicitly, so it defaulted to floats.

I explicitly created the tensor with an int dtype and everything now works as expected, with the correct class names appearing in the graphs.

all_test_labels = torch.tensor([], dtype=torch.int32)
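
For anyone hitting the same thing, here is a minimal sketch of what seems to have happened in my case. It assumes the labels are accumulated batch by batch with torch.cat; the names batches, acc_float, acc_int and fixed are made up for illustration and are not from my real code or from wandb.

```python
import torch

# Hypothetical stand-ins for the integer label batches a DataLoader would yield.
batches = [torch.tensor([0, 2]), torch.tensor([1, 1])]

# An empty tensor created without a dtype gets the default float dtype.
acc_float = torch.tensor([])
# Giving it an integer dtype up front keeps the accumulated labels as ints.
acc_int = torch.tensor([], dtype=torch.int32)

for batch in batches:
    # Concatenating int labels onto a float tensor promotes the result to float.
    acc_float = torch.cat((acc_float, batch))
    # Matching the accumulator's dtype keeps everything integer.
    acc_int = torch.cat((acc_int, batch.to(torch.int32)))

print(acc_float.dtype, acc_float.tolist())
print(acc_int.dtype, acc_int.tolist())

# An existing float tensor of whole-number labels can also be cast afterwards.
fixed = acc_float.to(torch.int64)
print(fixed.dtype, fixed.tolist())
```

Printing the dtype of the labels tensor before calling wandb.log is a quick way to check for this. Casting afterwards should work too, but I only tested creating the tensor with an int dtype from the start.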

Many thanks