pytorch_lightning_spells.metrics module

Classes:

AUC([compute_on_step, dist_sync_on_step, ...])

AUC: Area Under the ROC Curve

FBeta([step, beta, compute_on_step, ...])

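The F-beta score is the weighted harmonic mean of precision and recall.

GlobalMetric([compute_on_step, ...])

Base class for metrics computed over all accumulated predictions and targets.

SpearmanCorrelation([sigmoid, ...])

Spearman's rank correlation coefficient between predictions and targets.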

class pytorch_lightning_spells.metrics.AUC(compute_on_step=False, dist_sync_on_step=False, process_group=None)

Bases: GlobalMetric

AUC: Area Under the ROC Curve

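Binary mode

preds holds one score per sample for the positive class, and target holds 0/1 labels. Successive calls accumulate, so compute() scores all batches together.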

>>> import torch
>>> from pytorch_lightning_spells.metrics import AUC
>>> auc = AUC()
>>> _ = auc(torch.tensor([0.3, 0.8, 0.2]), torch.tensor([0, 1, 0]))
>>> _ = auc(torch.tensor([0.3, 0.3, 0.9]), torch.tensor([1, 1, 0]))
>>> round(auc.compute().item(), 2)
0.56
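The binary value can be checked by hand: ROC AUC equals the probability that a randomly chosen positive is scored above a randomly chosen negative, with ties counting one half. The pure-Python sketch below is independent of the library (`pairwise_auc` is a hypothetical helper, not part of `pytorch_lightning_spells`) and reproduces 0.56 from the two batches concatenated.

```python
from itertools import product


def pairwise_auc(scores, labels):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly.

    A pair where the positive scores higher counts 1; a tie counts 0.5.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative")
    wins = 0.0
    for p, n in product(pos, neg):
        if p > n:
            wins += 1.0
        elif p == n:
            wins += 0.5
    return wins / (len(pos) * len(neg))


# Both doctest batches, concatenated as a global metric would see them.
scores = [0.3, 0.8, 0.2] + [0.3, 0.3, 0.9]
labels = [0, 1, 0] + [1, 1, 0]
print(round(pairwise_auc(scores, labels), 2))  # 0.56 (= 5/9)
```

Of the nine positive/negative pairs, four are ranked correctly and two are ties, giving 5/9.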

Multi-class mode

preds has one row per sample and one column per class (hence the .t() in the example). The first column is treated as the negative case and the remaining columns collectively as the positive case.

>>> auc = AUC()
>>> _ = auc(torch.tensor([[0.3, 0.8, 0.2], [0.2, 0.1, 0.1], [0.5, 0.1, 0.7]]).t(), torch.tensor([0, 1, 0]))
>>> _ = auc(torch.tensor([[0.3, 0.3, 0.8], [0.2, 0.6, 0.1], [0.5, 0.1, 0.1]]).t(), torch.tensor([1, 1, 0]))
>>> round(auc.compute().item(), 2)
0.39
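One way to read "the rest collectively as the positive case": if each row holds class probabilities summing to 1, the positive score is 1 - p[0], and any nonzero label counts as positive. The sketch below assumes exactly that reduction (`collapse_to_binary` is a hypothetical helper, not the library's code); with the doctest's rows written out after the `.t()`, it reproduces 0.39.

```python
from itertools import product


def pairwise_auc(scores, labels):
    """ROC AUC via pairwise ranking; ties count 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p, n in product(pos, neg))
    return wins / (len(pos) * len(neg))


def collapse_to_binary(prob_rows, targets):
    """Assumed reduction: P(positive) = 1 - P(class 0); label != 0 is positive."""
    scores = [1.0 - row[0] for row in prob_rows]
    labels = [int(t != 0) for t in targets]
    return scores, labels


# One row per sample (the doctest tensors after .t()), both batches concatenated.
rows = [[0.3, 0.2, 0.5], [0.8, 0.1, 0.1], [0.2, 0.1, 0.7],
        [0.3, 0.2, 0.5], [0.3, 0.6, 0.1], [0.8, 0.1, 0.1]]
targets = [0, 1, 0, 1, 1, 0]
scores, labels = collapse_to_binary(rows, targets)
print(round(pairwise_auc(scores, labels), 2))  # 0.39 (= 3.5/9)
```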
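
Parameters:
  • compute_on_step (bool) – if False, forward() only calls update() and returns None; call compute() for the result.

  • dist_sync_on_step (bool) – synchronize metric state across processes at each forward() call before returning the step value.

  • process_group (Any | None) – process group on which state synchronization is performed; None selects the entire world.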

compute()

Compute the ROC AUC over all accumulated predictions and targets.

class pytorch_lightning_spells.metrics.FBeta(step=0.02, beta=2, compute_on_step=False, dist_sync_on_step=False, process_group=None)

Bases: GlobalMetric

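The F-beta score is the weighted harmonic mean of precision and recall. compute() reports the highest F-beta score found over a grid of decision thresholds, not the score at a fixed cutoff such as 0.5.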

Binary mode

>>> import torch
>>> from pytorch_lightning_spells.metrics import FBeta
>>> fbeta = FBeta()
>>> _ = fbeta(torch.tensor([0.3, 0.8, 0.2]), torch.tensor([0, 1, 0]))
>>> _ = fbeta(torch.tensor([0.3, 0.3, 0.9]), torch.tensor([1, 1, 0]))
>>> round(fbeta.compute().item(), 2)
0.88
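In terms of confusion counts, F-beta = (1 + beta^2) * TP / ((1 + beta^2) * TP + beta^2 * FN + FP), so beta = 2 charges a missed positive four times as much as a false alarm. The score depends on where the scores are cut. On the six binary scores above, a cutoff of 0.25 gives 0.88, matching the doctest, while 0.5 gives only 0.36. A minimal sketch (`fbeta_at_threshold` is a hypothetical helper):

```python
def fbeta_at_threshold(scores, labels, threshold, beta=2):
    """F-beta of the hard predictions `score >= threshold` against 0/1 labels."""
    tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 1)
    fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 0)
    fn = sum(1 for s, y in zip(scores, labels) if s < threshold and y == 1)
    b2 = beta ** 2
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom else 0.0


# Both doctest batches, concatenated.
scores = [0.3, 0.8, 0.2] + [0.3, 0.3, 0.9]
labels = [0, 1, 0] + [1, 1, 0]
print(round(fbeta_at_threshold(scores, labels, 0.25), 2))  # 0.88 (TP=3, FP=2, FN=0)
print(round(fbeta_at_threshold(scores, labels, 0.5), 2))   # 0.36 (TP=1, FP=1, FN=2)
```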

Multi-class mode

preds has one row per sample and one column per class (hence the .t() in the example). The first column is treated as the negative case and the remaining columns collectively as the positive case.

>>> fbeta = FBeta()
>>> _ = fbeta(torch.tensor([[0.8, 0.3, 0.7], [0.1, 0.1, 0.1], [0.1, 0.6, 0.2]]).t(), torch.tensor([0, 1, 0]))
>>> _ = fbeta(torch.tensor([[0.3, 0.7, 0.8], [0.2, 0.2, 0.1], [0.5, 0.1, 0.1]]).t(), torch.tensor([1, 1, 0]))
>>> round(fbeta.compute().item(), 4)
0.9375
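
Parameters:
  • step (float) – spacing of the candidate thresholds searched by find_best_fbeta_threshold().

  • beta (int) – the beta in F-beta; values above 1 weight recall more heavily than precision (the default 2 gives the F2 score).

  • compute_on_step (bool) – if False, forward() only calls update() and returns None; call compute() for the result.

  • dist_sync_on_step (bool) – synchronize metric state across processes at each forward() call before returning the step value.

  • process_group (Any | None) – process group on which state synchronization is performed; None selects the entire world.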

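compute()

Return the highest F-beta score over the threshold grid, evaluated on all accumulated predictions and targets.

find_best_fbeta_threshold(truth, probs)

Search the threshold grid for the threshold that maximizes the F-beta score of probs against truth.
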
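A threshold search of this kind can be sketched as a plain sweep, assuming candidate cutoffs step, 2*step, ... below 1 and keeping the best score. `best_fbeta` and its `(score, threshold)` return value are hypothetical and not the library's signature. Applied to the multi-class doctest after collapsing each row to 1 - p[0] (an assumed reduction), it reproduces 0.9375.

```python
def fbeta_at_threshold(scores, labels, threshold, beta=2):
    """F-beta of the hard predictions `score >= threshold` against 0/1 labels."""
    tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 1)
    fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 0)
    fn = sum(1 for s, y in zip(scores, labels) if s < threshold and y == 1)
    b2 = beta ** 2
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom else 0.0


def best_fbeta(scores, labels, step=0.02, beta=2):
    """Sweep thresholds step, 2*step, ... (< 1); return (best score, its threshold)."""
    best_score, best_threshold = 0.0, None
    k = 1
    while k * step < 1:
        threshold = k * step  # multiply rather than accumulate to limit float drift
        score = fbeta_at_threshold(scores, labels, threshold, beta)
        if score > best_score:  # strict: keep the lowest threshold among ties
            best_score, best_threshold = score, threshold
        k += 1
    return best_score, best_threshold


# Multi-class doctest rows (after .t()), both batches concatenated.
rows = [[0.8, 0.1, 0.1], [0.3, 0.1, 0.6], [0.7, 0.1, 0.2],
        [0.3, 0.2, 0.5], [0.7, 0.2, 0.1], [0.8, 0.1, 0.1]]
targets = [0, 1, 0, 1, 1, 0]
scores = [1.0 - r[0] for r in rows]
labels = [int(t != 0) for t in targets]
score, threshold = best_fbeta(scores, labels)
print(round(score, 4))  # 0.9375 (TP=3, FP=1, FN=0)
```

A coarser `step` means fewer evaluations but may miss the best cutoff when scores are closely spaced.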
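class pytorch_lightning_spells.metrics.GlobalMetric(compute_on_step=False, dist_sync_on_step=False, process_group=None)

Bases: Metric

Base class for metrics that must see every prediction at once, such as ROC AUC. update() stores each batch of predictions and targets, and subclasses implement compute() over the full accumulated set.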

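Parameters:
  • compute_on_step (bool) – if False, forward() only calls update() and returns None; call compute() for the result.

  • dist_sync_on_step (bool) – synchronize metric state across processes at each forward() call before returning the step value.

  • process_group (Any | None) – process group on which state synchronization is performed; None selects the entire world.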

update(preds, target)

Update state with predictions and targets.

Parameters:
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values
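Accumulating matters for metrics that do not decompose over batches. In the AUC doctest, the first batch alone has AUC 1.0 and the second 0.0, so averaging per-batch values gives 0.5, while scoring all six samples together gives 0.56. The sketch below imitates the store-then-compute pattern in plain Python; `GlobalAUCSketch` is hypothetical and omits the distributed state synchronization a real `Metric` performs.

```python
from itertools import product


def pairwise_auc(scores, labels):
    """ROC AUC via pairwise ranking; ties count 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p, n in product(pos, neg))
    return wins / (len(pos) * len(neg))


class GlobalAUCSketch:
    """Hypothetical stand-in: keep every batch, compute once over all of them."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.preds, self.target = [], []

    def update(self, preds, target):
        # Store raw batches; nothing is reduced until compute().
        self.preds.append(list(preds))
        self.target.append(list(target))

    def compute(self):
        scores = [s for batch in self.preds for s in batch]
        labels = [y for batch in self.target for y in batch]
        return pairwise_auc(scores, labels)


batches = [([0.3, 0.8, 0.2], [0, 1, 0]), ([0.3, 0.3, 0.9], [1, 1, 0])]
metric = GlobalAUCSketch()
for p, t in batches:
    metric.update(p, t)
per_batch_mean = sum(pairwise_auc(p, t) for p, t in batches) / len(batches)
print(round(metric.compute(), 2), per_batch_mean)  # 0.56 0.5
```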

class pytorch_lightning_spells.metrics.SpearmanCorrelation(sigmoid=False, compute_on_step=False, dist_sync_on_step=False, process_group=None)

Bases: GlobalMetric

Spearman's rank correlation coefficient between the accumulated predictions and targets.

Parameters:
  • sigmoid (bool) –

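  • compute_on_step (bool) – if False, forward() only calls update() and returns None; call compute() for the result.

  • dist_sync_on_step (bool) – synchronize metric state across processes at each forward() call before returning the step value.

  • process_group (Any | None) – process group on which state synchronization is performed; None selects the entire world.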

compute()

Compute Spearman's rank correlation over all accumulated predictions and targets.
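Spearman's rho is the Pearson correlation of the ranks of the predictions and the targets, with tied values sharing their average rank. The self-contained sketch below implements that definition of the statistic. It is not the class's own code, and the helper names are hypothetical.

```python
import math


def average_ranks(values):
    """1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1  # sorted positions i..j are tied
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def pearson(x, y):
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    vx = sum((a - mx) ** 2 for a in x)
    vy = sum((b - my) ** 2 for b in y)
    if vx == 0 or vy == 0:
        raise ValueError("correlation is undefined for a constant input")
    return cov / math.sqrt(vx * vy)


def spearman(preds, target):
    """Spearman's rho: Pearson correlation of the rank-transformed data."""
    return pearson(average_ranks(preds), average_ranks(target))


preds = [0.1, 0.4, 0.35, 0.8]
target = [1, 2, 3, 4]
print(round(spearman(preds, target), 4))  # 0.8 (one adjacent pair swapped)
```

Because only ranks enter the formula, any strictly increasing transform of the predictions, a sigmoid for example, leaves rho unchanged.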