Hi W&B Team
I’ve managed to implement the PR curve and ROC curve for a multiclass classifier and have them displaying nicely in my dashboard. The only slight problem is that I can’t get the class names to display in the plots.
I’m retrieving the class names from the PyTorch dataset, which returns them as a list of strings:
class_names = dataset.classes
['Akashiwo', 'Alexandrium', 'Amoeba', 'Amphidinium', 'Apedinella', … ]
I then pass this to the PR curve function:
wandb.log({"pr": wandb.plot.pr_curve(all_test_labels.cpu(), all_test_output.cpu(), labels=class_names)})
However, when the PR curve is rendered on the dashboard, it still displays a list of indices instead of the corresponding class names.
Still learning as I go so hoping there is a very simple fix for this.
Many thanks
Gary
hey @gg_sams - what wandb SDK version are you running? I recommend taking a look at this Colab and ensuring your inputs match the types of those specified in this line (since it seems to be the same classes you’re using as well):
wandb.log({"pr_curve": wandb.plot.pr_curve(ground_truth_class_ids,
                                           val_predictions,
                                           labels=self.flat_class_names)})
please let me know how this goes!
Hi Uma
Thanks for the link to the Colab. I compared the class names list from the Colab notebook with the one in my own code and they were exactly the same datatype (a Python list of strings). However, I noticed that the tensor containing my labels showed a decimal point after each index, indicating that it held floats rather than ints. When creating the tensor I hadn’t specified a data type explicitly (so I guess it defaulted to floats).
I explicitly created the tensor using an int dtype and everything now works as expected, with the correct labels appearing in the graphs:
all_test_labels = torch.tensor([], dtype=torch.int32)
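For anyone hitting the same thing, here is a minimal sketch of how to check for and fix the dtype problem. The label values and the name float_labels are made up for illustration, and whether to use int32 or int64 is just an assumption on my part; either is an integer dtype.

```python
import torch

# An empty tensor created without a dtype gets PyTorch's default dtype,
# which is a floating-point type (float32 unless it has been changed).
empty = torch.tensor([])
print(empty.dtype, empty.dtype.is_floating_point)

# Hypothetical labels that ended up stored as floats, e.g. 0., 2., 1.
float_labels = torch.tensor([0.0, 2.0, 1.0])

# Option 1: convert an existing float tensor to an integer dtype.
int_labels = float_labels.to(torch.int64)

# Option 2: create the accumulator with an integer dtype from the start.
all_test_labels = torch.tensor([], dtype=torch.int32)

print(int_labels.dtype, all_test_labels.dtype)
```

Checking all_test_labels.dtype just before the wandb.log call is a quick way to confirm the labels really are integers.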
Many thanks
Gary