pytorch_lightning_spells.metrics module
Classes:
| Class | Description |
| --- | --- |
| AUC | Area Under the ROC Curve |
| FBeta | The F-beta score is the weighted harmonic mean of precision and recall |
| GlobalMetric | |
| SpearmanCorrelation | |
- class pytorch_lightning_spells.metrics.AUC(*args, **kwargs)
Bases: GlobalMetric
AUC: Area Under the ROC Curve
Binary mode
>>> import torch
>>> from pytorch_lightning_spells.metrics import AUC
>>> auc = AUC()
>>> _ = auc(torch.tensor([0.3, 0.8, 0.2]), torch.tensor([0, 1, 0]))
>>> _ = auc(torch.tensor([0.3, 0.3, 0.9]), torch.tensor([1, 1, 0]))
>>> round(auc.compute().item(), 2)
0.56
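The 0.56 in this doctest can be checked by hand. AUC equals the probability that a randomly chosen positive sample is scored above a randomly chosen negative one, with ties counting as half. The pure-Python sketch below computes that quantity over all six samples from both calls; it illustrates the value being reported, not the library's implementation, and `pairwise_auc` is a hypothetical helper name.

```python
from itertools import product


def pairwise_auc(scores, targets):
    """AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair. `pairwise_auc` is a hypothetical
    helper for illustration, not part of pytorch_lightning_spells.
    """
    positives = [s for s, t in zip(scores, targets) if t == 1]
    negatives = [s for s, t in zip(scores, targets) if t == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative")
    wins = 0.0
    for p, n in product(positives, negatives):
        if p > n:
            wins += 1.0
        elif p == n:
            wins += 0.5
    return wins / (len(positives) * len(negatives))


# Both doctest batches together, since compute() reports one score
# over every sample passed in so far.
scores = [0.3, 0.8, 0.2] + [0.3, 0.3, 0.9]
targets = [0, 1, 0] + [1, 1, 0]
print(round(pairwise_auc(scores, targets), 2))  # 0.56
```

Of the 9 positive/negative pairs, 4 are ranked correctly and 2 are ties, giving 5/9.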
Multi-class mode
This will use the first column as the negative case, and the rest collectively as the positive case.
>>> auc = AUC()
>>> _ = auc(torch.tensor([[0.3, 0.8, 0.2], [0.2, 0.1, 0.1], [0.5, 0.1, 0.7]]).t(), torch.tensor([0, 1, 0]))
>>> _ = auc(torch.tensor([[0.3, 0.3, 0.8], [0.2, 0.6, 0.1], [0.5, 0.1, 0.1]]).t(), torch.tensor([1, 1, 0]))
>>> round(auc.compute().item(), 2)
0.39
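The multi-class reduction can be sketched in plain Python. Two details here are assumptions, not documented behaviour: each row's positive score is taken as one minus its first column (equal to the sum of the remaining columns when a row sums to 1), and every non-zero label counts as positive. All helper names are hypothetical. With the doctest inputs, the sketch reproduces the 0.39 above.

```python
def collapse_to_binary(rows, labels):
    """Reduce multi-class probabilities to a binary problem.

    Assumption: column 0 is the negative class, so the positive score is
    1 - row[0], and every label other than 0 counts as positive.
    """
    scores = [1.0 - row[0] for row in rows]
    targets = [0 if label == 0 else 1 for label in labels]
    return scores, targets


def pairwise_auc(scores, targets):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for s, t in zip(scores, targets) if t == 1]
    neg = [s for s, t in zip(scores, targets) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def transpose(matrix):
    """Plain-Python equivalent of the doctest's .t() call."""
    return [list(col) for col in zip(*matrix)]


batch1 = transpose([[0.3, 0.8, 0.2], [0.2, 0.1, 0.1], [0.5, 0.1, 0.7]])
batch2 = transpose([[0.3, 0.3, 0.8], [0.2, 0.6, 0.1], [0.5, 0.1, 0.1]])
scores, targets = collapse_to_binary(batch1 + batch2, [0, 1, 0] + [1, 1, 0])
print(round(pairwise_auc(scores, targets), 2))  # 0.39
```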
- class pytorch_lightning_spells.metrics.FBeta(step=0.02, beta=2, compute_on_step=False, dist_sync_on_step=False, process_group=None)
Bases: GlobalMetric
The F-beta score is the weighted harmonic mean of precision and recall.
This is a specialized implementation.
Binary mode
>>> import torch
>>> from pytorch_lightning_spells.metrics import FBeta
>>> fbeta = FBeta()
>>> _ = fbeta(torch.tensor([0.3, 0.8, 0.2]), torch.tensor([0, 1, 0]))
>>> _ = fbeta(torch.tensor([0.3, 0.3, 0.9]), torch.tensor([1, 1, 0]))
>>> round(fbeta.compute().item(), 2)
0.88
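The signature exposes a `step` parameter, and the 0.88 above is well above the F2 score at a fixed 0.5 threshold (about 0.36). That suggests the score is maximized over a grid of decision thresholds. The sketch below assumes exactly that: thresholds 0, step, 2 * step, and so on below 1, a sample predicted positive when its score exceeds the threshold, and the best F-beta reported. This is an inference from the doctest values, not a description of the actual implementation; `best_fbeta` and `fbeta_from_counts` are hypothetical names.

```python
def fbeta_from_counts(tp, fp, fn, beta):
    """F-beta = (1 + b^2) TP / ((1 + b^2) TP + b^2 FN + FP); 0 if undefined."""
    b2 = beta * beta
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom else 0.0


def best_fbeta(scores, targets, step=0.02, beta=2):
    """Hypothetical threshold sweep: the best F-beta over a grid of cut-offs.

    Assumption: a sample is predicted positive when score > threshold,
    and thresholds run 0, step, 2 * step, ... up to (but excluding) 1.
    """
    best = 0.0
    for i in range(round(1 / step)):
        threshold = i * step
        tp = fp = fn = 0
        for s, t in zip(scores, targets):
            predicted = s > threshold
            if predicted and t == 1:
                tp += 1
            elif predicted:
                fp += 1
            elif t == 1:
                fn += 1
        best = max(best, fbeta_from_counts(tp, fp, fn, beta))
    return best


scores = [0.3, 0.8, 0.2] + [0.3, 0.3, 0.9]
targets = [0, 1, 0] + [1, 1, 0]
print(round(best_fbeta(scores, targets), 2))  # 0.88
```

The best cut-off falls between 0.2 and 0.3: all three positives are caught (recall 1) at the cost of two false positives (precision 0.6), and with beta = 2 recall dominates, giving 15/17.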
Multi-class mode
This will use the first column as the negative case, and the rest collectively as the positive case.
Use MulticlassFBetaScore from torchmetrics for the true multiclass F-score metric.
>>> fbeta = FBeta()
>>> _ = fbeta(torch.tensor([[0.8, 0.3, 0.7], [0.1, 0.1, 0.1], [0.1, 0.6, 0.2]]).t(), torch.tensor([0, 1, 0]))
>>> _ = fbeta(torch.tensor([[0.3, 0.7, 0.8], [0.2, 0.2, 0.1], [0.5, 0.1, 0.1]]).t(), torch.tensor([1, 1, 0]))
>>> round(fbeta.compute().item(), 4)
0.9375
- Parameters:
step (float)
beta (int)
compute_on_step (bool)
dist_sync_on_step (bool)
process_group (Any)
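The multi-class FBeta value of 0.9375 can be reproduced under three assumptions that are not stated in this reference: the positive score of each row is one minus its first column, any non-zero label is positive, and the reported value is the best F-beta over thresholds spaced `step` apart. The sketch below combines them; every helper name is hypothetical.

```python
def fbeta_from_counts(tp, fp, fn, beta):
    """F-beta from confusion counts; 0 when the score is undefined."""
    b2 = beta * beta
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom else 0.0


def best_fbeta(scores, targets, step=0.02, beta=2):
    """Best F-beta over thresholds 0, step, 2 * step, ... (assumed grid)."""
    best = 0.0
    for i in range(round(1 / step)):
        threshold = i * step
        pairs = list(zip(scores, targets))
        tp = sum(1 for s, t in pairs if s > threshold and t == 1)
        fp = sum(1 for s, t in pairs if s > threshold and t == 0)
        fn = sum(1 for s, t in pairs if s <= threshold and t == 1)
        best = max(best, fbeta_from_counts(tp, fp, fn, beta))
    return best


def transpose(matrix):
    """Plain-Python equivalent of the doctest's .t() call."""
    return [list(col) for col in zip(*matrix)]


batch1 = transpose([[0.8, 0.3, 0.7], [0.1, 0.1, 0.1], [0.1, 0.6, 0.2]])
batch2 = transpose([[0.3, 0.7, 0.8], [0.2, 0.2, 0.1], [0.5, 0.1, 0.1]])
labels = [0, 1, 0] + [1, 1, 0]

# Assumption: column 0 is the negative class, so the positive score is
# 1 - row[0]; every label other than 0 counts as positive.
scores = [1.0 - row[0] for row in batch1 + batch2]
targets = [0 if label == 0 else 1 for label in labels]
print(round(best_fbeta(scores, targets), 4))  # 0.9375
```

At the best cut-off there are 3 true positives, 1 false positive and no false negatives, so F2 = 15/16 = 0.9375 exactly.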
- class pytorch_lightning_spells.metrics.GlobalMetric(dist_sync_on_step=False, process_group=None)
Bases: Metric
- Parameters:
dist_sync_on_step (bool)
process_group (Any)
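GlobalMetric is documented here only by its parameters. The AUC doctest hints at what "global" means: after two calls, compute() returns one score over all six samples. Scoring the two AUC batches separately gives 1.0 and 0.0, whose mean of 0.5 differs from the reported 0.56. The sketch below shows that accumulate-then-compute pattern in plain Python. `AccumulatingMetric` and its methods are hypothetical; they do not mirror GlobalMetric's actual API and do not handle distributed synchronization.

```python
class AccumulatingMetric:
    """Hypothetical sketch: buffer every batch, compute once over all of them."""

    def __init__(self, score_fn):
        self.score_fn = score_fn
        self.reset()

    def reset(self):
        self.preds = []
        self.targets = []

    def __call__(self, preds, targets):
        # Store the batch instead of scoring it on its own.
        self.preds.extend(preds)
        self.targets.extend(targets)

    def compute(self):
        # One score over every sample seen since the last reset().
        return self.score_fn(self.preds, self.targets)


def pairwise_auc(scores, targets):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for s, t in zip(scores, targets) if t == 1]
    neg = [s for s, t in zip(scores, targets) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


auc = AccumulatingMetric(pairwise_auc)
auc([0.3, 0.8, 0.2], [0, 1, 0])
auc([0.3, 0.3, 0.9], [1, 1, 0])
print(round(auc.compute(), 2))  # 0.56, versus 0.5 from averaging per-batch AUCs
```

The trade-off of this design is memory: every prediction is kept until compute(), which is what makes ranking metrics such as AUC exact over the whole epoch.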
- class pytorch_lightning_spells.metrics.SpearmanCorrelation(sigmoid=False, compute_on_step=False, dist_sync_on_step=False, process_group=None)
Bases: GlobalMetric
- Parameters:
sigmoid (bool)
compute_on_step (bool)
dist_sync_on_step (bool)
process_group (Any)
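SpearmanCorrelation is also listed with parameters only. For reference, Spearman's rank correlation is the Pearson correlation computed on ranks, with tied values conventionally given the average of the ranks they span. The pure-Python sketch below follows that textbook definition. It is not the library's implementation and does not model the `sigmoid` flag; the function names are hypothetical.

```python
import math


def average_ranks(values):
    """1-based ranks; tied values share the mean of the ranks they span."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of equal values starting at position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def pearson(x, y):
    """Pearson correlation; raises ZeroDivisionError for constant input."""
    mx = sum(x) / len(x)
    my = sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def spearman(preds, targets):
    """Spearman's rho: Pearson correlation of the average ranks."""
    return pearson(average_ranks(preds), average_ranks(targets))


print(round(spearman([0.1, 0.4, 0.35, 0.8], [1, 2, 3, 4]), 4))  # 0.8
```

Here the predictions rank as [1, 3, 2, 4] against targets [1, 2, 3, 4]; only the middle pair is swapped, so rho = 1 - 6 * 2 / (4 * 15) = 0.8.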