pytorch_lightning_spells.metrics module
Classes:
- AUC: Area Under the ROC Curve
- FBeta: The F-beta score is the weighted harmonic mean of precision and recall
- GlobalMetric
- SpearmanCorrelation
- class pytorch_lightning_spells.metrics.AUC(compute_on_step=False, dist_sync_on_step=False, process_group=None)
Bases: GlobalMetric
AUC: Area Under the ROC Curve
Binary mode
>>> auc = AUC()
>>> _ = auc(torch.tensor([0.3, 0.8, 0.2]), torch.tensor([0, 1, 0]))
>>> _ = auc(torch.tensor([0.3, 0.3, 0.9]), torch.tensor([1, 1, 0]))
>>> round(auc.compute().item(), 2)
0.56
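The 0.56 in the binary-mode example can be checked by hand. The sketch below is not the library's implementation. It uses the standard rank-based definition of AUC (the share of positive/negative pairs ordered correctly, with ties counted as half) and assumes the two batches are pooled before the score is computed. `pairwise_auc` is a hypothetical helper name.

```python
def pairwise_auc(scores, labels):
    """Rank-based AUC: fraction of (positive, negative) pairs ordered correctly."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    # A correctly ordered pair scores 1, a tie scores 0.5, a reversed pair scores 0.
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# The two batches from the binary-mode example, pooled (assumption).
scores = [0.3, 0.8, 0.2, 0.3, 0.3, 0.9]
labels = [0, 1, 0, 1, 1, 0]
auc_value = pairwise_auc(scores, labels)
print(round(auc_value, 2))  # 0.56
```

Five of the nine positive/negative pairs count as correctly ordered once ties are halved, which gives 5/9, or about 0.56.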
Multi-class mode
In this mode, the first column is treated as the negative class, and the remaining columns are treated collectively as the positive class.
>>> auc = AUC()
>>> _ = auc(torch.tensor([[0.3, 0.8, 0.2], [0.2, 0.1, 0.1], [0.5, 0.1, 0.7]]).t(), torch.tensor([0, 1, 0]))
>>> _ = auc(torch.tensor([[0.3, 0.3, 0.8], [0.2, 0.6, 0.1], [0.5, 0.1, 0.1]]).t(), torch.tensor([1, 1, 0]))
>>> round(auc.compute().item(), 2)
0.39
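One reading of the multi-class rule that reproduces the 0.39 in this example is to score the positive case as one minus the first column and to treat every non-zero label as positive. The sketch below works under that assumption, which may not be exactly how the library collapses the columns. `collapse_to_binary` and `pairwise_auc` are hypothetical helper names.

```python
def pairwise_auc(scores, labels):
    # Rank-based AUC: share of (positive, negative) pairs ordered correctly, ties count half.
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def collapse_to_binary(rows, labels):
    # Assumption: positive score = 1 - (first column); any non-zero label is positive.
    scores = [1.0 - row[0] for row in rows]
    binary_labels = [int(y != 0) for y in labels]
    return scores, binary_labels


# Rows are samples and columns are classes (the example's tensors after .t()).
rows = [
    [0.3, 0.2, 0.5], [0.8, 0.1, 0.1], [0.2, 0.1, 0.7],  # first batch
    [0.3, 0.2, 0.5], [0.3, 0.6, 0.1], [0.8, 0.1, 0.1],  # second batch
]
labels = [0, 1, 0, 1, 1, 0]

scores, binary_labels = collapse_to_binary(rows, labels)
multiclass_auc = pairwise_auc(scores, binary_labels)
print(round(multiclass_auc, 2))  # 0.39
```

Using one minus the first column rather than summing the other columns keeps tied samples exactly tied in floating point, which matters because ties contribute half a pair each.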
- Parameters:
compute_on_step (bool) –
dist_sync_on_step (bool) –
process_group (Any | None) –
- class pytorch_lightning_spells.metrics.FBeta(step=0.02, beta=2, compute_on_step=False, dist_sync_on_step=False, process_group=None)
Bases: GlobalMetric
The F-beta score is the weighted harmonic mean of precision and recall, where beta sets how much more weight recall receives than precision.
Binary mode
>>> fbeta = FBeta()
>>> _ = fbeta(torch.tensor([0.3, 0.8, 0.2]), torch.tensor([0, 1, 0]))
>>> _ = fbeta(torch.tensor([0.3, 0.3, 0.9]), torch.tensor([1, 1, 0]))
>>> round(fbeta.compute().item(), 2)
0.88
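The 0.88 in the binary-mode example is consistent with taking the best F-beta score over a grid of decision thresholds. The sketch below assumes that `step` is the spacing of that grid over [0, 1) and that a sample is predicted positive when its score is strictly above the threshold. Neither detail is confirmed by this page, and `fbeta_from_counts` and `best_fbeta` are hypothetical helper names.

```python
def fbeta_from_counts(tp, fp, fn, beta):
    # F-beta = (1 + beta^2) * TP / ((1 + beta^2) * TP + beta^2 * FN + FP)
    b2 = beta * beta
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom else 0.0


def best_fbeta(scores, labels, step=0.02, beta=2):
    # Assumption: sweep thresholds 0, step, 2*step, ... below 1 and keep the best score.
    best = 0.0
    for i in range(int(round(1 / step))):
        threshold = i * step
        tp = sum(1 for s, y in zip(scores, labels) if s > threshold and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s > threshold and y == 0)
        fn = sum(1 for s, y in zip(scores, labels) if s <= threshold and y == 1)
        best = max(best, fbeta_from_counts(tp, fp, fn, beta))
    return best


# The two batches from the binary-mode example, pooled (assumption).
scores = [0.3, 0.8, 0.2, 0.3, 0.3, 0.9]
labels = [0, 1, 0, 1, 1, 0]
best_binary = best_fbeta(scores, labels)
print(round(best_binary, 2))  # 0.88
```

With beta=2, a threshold between 0.2 and 0.3 yields three true positives, two false positives, and no false negatives, so the score is 15/17, or about 0.88.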
Multi-class mode
In this mode, the first column is treated as the negative class, and the remaining columns are treated collectively as the positive class.
>>> fbeta = FBeta()
>>> _ = fbeta(torch.tensor([[0.8, 0.3, 0.7], [0.1, 0.1, 0.1], [0.1, 0.6, 0.2]]).t(), torch.tensor([0, 1, 0]))
>>> _ = fbeta(torch.tensor([[0.3, 0.7, 0.8], [0.2, 0.2, 0.1], [0.5, 0.1, 0.1]]).t(), torch.tensor([1, 1, 0]))
>>> round(fbeta.compute().item(), 4)
0.9375
- Parameters:
step (float) –
beta (int) –
compute_on_step (bool) –
dist_sync_on_step (bool) –
process_group (Any | None) –
- class pytorch_lightning_spells.metrics.GlobalMetric(compute_on_step=False, dist_sync_on_step=False, process_group=None)
Bases: Metric
- Parameters:
compute_on_step (bool) –
dist_sync_on_step (bool) –
process_group (Any | None) –
- class pytorch_lightning_spells.metrics.SpearmanCorrelation(sigmoid=False, compute_on_step=False, dist_sync_on_step=False, process_group=None)
Bases: GlobalMetric
- Parameters:
sigmoid (bool) –
compute_on_step (bool) –
dist_sync_on_step (bool) –
process_group (Any | None) –