pytorch_lightning_spells.metrics module

Classes:

AUC(*args, **kwargs)

AUC: Area Under the ROC Curve

FBeta([step, beta, compute_on_step, ...])

The F-beta score is the weighted harmonic mean of precision and recall.

GlobalMetric([dist_sync_on_step, process_group])

SpearmanCorrelation([sigmoid, ...])

class pytorch_lightning_spells.metrics.AUC(*args, **kwargs)

Bases: GlobalMetric

AUC: Area Under the ROC Curve

Binary mode

>>> auc = AUC()
>>> _ = auc(torch.tensor([0.3, 0.8, 0.2]), torch.tensor([0, 1, 0]))
>>> _ = auc(torch.tensor([0.3, 0.3, 0.9]), torch.tensor([1, 1, 0]))
>>> round(auc.compute().item(), 2)
0.56
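The binary value can be checked by hand. ROC AUC equals the probability that a randomly chosen positive sample scores higher than a randomly chosen negative one, with ties counted as half. The minimal pure-Python sketch below uses a hypothetical pairwise_auc helper that is not part of this module. It pools both doctest batches and reproduces 0.56, which is 5 of the 9 positive/negative pairs.

```python
def pairwise_auc(scores, labels):
    """ROC AUC = P(random positive outscores random negative); ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Both doctest batches, pooled into one set of samples.
scores = [0.3, 0.8, 0.2, 0.3, 0.3, 0.9]
labels = [0, 1, 0, 1, 1, 0]
print(round(pairwise_auc(scores, labels), 2))  # 0.56 (= 5/9)
```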

Multi-class mode

This mode treats the first column as the negative class and the remaining columns collectively as the positive class.

>>> auc = AUC()
>>> _ = auc(torch.tensor([[0.3, 0.8, 0.2], [0.2, 0.1, 0.1], [0.5, 0.1, 0.7]]).t(), torch.tensor([0, 1, 0]))
>>> _ = auc(torch.tensor([[0.3, 0.3, 0.8], [0.2, 0.6, 0.1], [0.5, 0.1, 0.1]]).t(), torch.tensor([1, 1, 0]))
>>> round(auc.compute().item(), 2)
0.39
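The multi-class reduction can be sketched the same way. This minimal sketch assumes the positive score is 1 - p0, which equals the summed probability of the remaining columns when each row sums to 1. It also assumes any non-zero target counts as positive; the doctest only uses targets 0 and 1, so that part is unconfirmed. collapse_to_binary and pairwise_auc are hypothetical helpers. With the rows written one sample per row (the doctest transposes with .t()), the sketch reproduces 0.39.

```python
def pairwise_auc(scores, labels):
    """ROC AUC = P(random positive outscores random negative); ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def collapse_to_binary(rows, targets):
    """Assumed reduction: column 0 is negative, all other columns are positive.

    1 - p0 equals the summed probability of columns 1..K-1 when a row sums to 1.
    """
    scores = [1 - row[0] for row in rows]
    labels = [int(t != 0) for t in targets]
    return scores, labels


# One sample per row: the doctest builds class-major tensors and transposes them.
rows = [[0.3, 0.2, 0.5], [0.8, 0.1, 0.1], [0.2, 0.1, 0.7],
        [0.3, 0.2, 0.5], [0.3, 0.6, 0.1], [0.8, 0.1, 0.1]]
targets = [0, 1, 0, 1, 1, 0]
scores, labels = collapse_to_binary(rows, targets)
print(round(pairwise_auc(scores, labels), 2))  # 0.39 (= 3.5/9)
```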
compute()

class pytorch_lightning_spells.metrics.FBeta(step=0.02, beta=2, compute_on_step=False, dist_sync_on_step=False, process_group=None)

Bases: GlobalMetric

The F-beta score is the weighted harmonic mean of precision and recall.

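This is a specialized implementation: instead of applying a fixed decision threshold, it scans candidate thresholds in increments of the step parameter and reports the highest F-beta score found (see find_best_fbeta_threshold).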
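For reference, the standard definition with precision P and recall R is given below. A beta greater than 1 weights recall more heavily than precision, so the default beta=2 favors recall. As a worked example, the binary doctest that follows reaches P = 0.6 and R = 1 at thresholds between 0.2 and 0.3 (five predicted positives, three of them correct, no positives missed), which yields its 0.88.

```latex
F_\beta = (1 + \beta^2)\,\frac{P \cdot R}{\beta^2 P + R},
\qquad
F_2\big|_{P = 0.6,\; R = 1} = \frac{5 \cdot 0.6}{4 \cdot 0.6 + 1} = \frac{3}{3.4} \approx 0.88
```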

Binary mode

>>> fbeta = FBeta()
>>> _ = fbeta(torch.tensor([0.3, 0.8, 0.2]), torch.tensor([0, 1, 0]))
>>> _ = fbeta(torch.tensor([0.3, 0.3, 0.9]), torch.tensor([1, 1, 0]))
>>> round(fbeta.compute().item(), 2)
0.88

Multi-class mode

This mode treats the first column as the negative class and the remaining columns collectively as the positive class.

For a true multi-class F-beta score, use MulticlassFBetaScore from torchmetrics instead.

>>> fbeta = FBeta()
>>> _ = fbeta(torch.tensor([[0.8, 0.3, 0.7], [0.1, 0.1, 0.1], [0.1, 0.6, 0.2]]).t(), torch.tensor([0, 1, 0]))
>>> _ = fbeta(torch.tensor([[0.3, 0.7, 0.8], [0.2, 0.2, 0.1], [0.5, 0.1, 0.1]]).t(), torch.tensor([1, 1, 0]))
>>> round(fbeta.compute().item(), 4)
0.9375

Parameters:
  • step (float)

  • beta (int)

  • compute_on_step (bool)

  • dist_sync_on_step (bool)

  • process_group (Any)
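Both FBeta doctests can be reproduced with a threshold search. This is a minimal pure-Python sketch that assumes the metric tries thresholds 0, step, 2*step, ..., 1 and predicts positive when a probability is at or above the threshold. The grid and the comparison are assumptions, and best_fbeta and fbeta_from_counts are hypothetical helpers, not this module's find_best_fbeta_threshold. The multi-class case uses the same assumed 1 - p0 reduction as the AUC sketch.

```python
def fbeta_from_counts(tp, fp, fn, beta):
    """F-beta from confusion counts; 0.0 when undefined (no true or predicted positives)."""
    b2 = beta * beta
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom > 0 else 0.0


def best_fbeta(probs, truth, step=0.02, beta=2):
    """Hypothetical helper: scan thresholds 0, step, 2*step, ..., 1.

    A sample is predicted positive when its probability >= threshold.
    Returns (best_score, best_threshold); the first threshold reaching the max wins.
    """
    best_score, best_t = 0.0, 0.0
    for i in range(int(round(1 / step)) + 1):
        t = i * step
        tp = sum(1 for p, y in zip(probs, truth) if p >= t and y == 1)
        fp = sum(1 for p, y in zip(probs, truth) if p >= t and y == 0)
        fn = sum(1 for p, y in zip(probs, truth) if p < t and y == 1)
        score = fbeta_from_counts(tp, fp, fn, beta)
        if score > best_score:
            best_score, best_t = score, t
    return best_score, best_t


# Binary doctest: both batches pooled.
bin_probs = [0.3, 0.8, 0.2, 0.3, 0.3, 0.9]
bin_truth = [0, 1, 0, 1, 1, 0]
bin_score, bin_t = best_fbeta(bin_probs, bin_truth)
print(round(bin_score, 2))  # 0.88 (15/17, first reached between 0.2 and 0.3)

# Multi-class doctest, one sample per row (the doctest transposes with .t()).
# Assumed reduction: positive score = 1 - p0, positive label = target != 0.
rows = [[0.8, 0.1, 0.1], [0.3, 0.1, 0.6], [0.7, 0.1, 0.2],
        [0.3, 0.2, 0.5], [0.7, 0.2, 0.1], [0.8, 0.1, 0.1]]
targets = [0, 1, 0, 1, 1, 0]
mc_score, mc_t = best_fbeta([1 - r[0] for r in rows], [int(t != 0) for t in targets])
print(round(mc_score, 4))  # 0.9375 (= 15/16)
```

Scanning a fixed grid keeps the cost linear in the number of thresholds, at the price of only resolving the optimal threshold to within step.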

compute()

find_best_fbeta_threshold(truth, probs)

class pytorch_lightning_spells.metrics.GlobalMetric(dist_sync_on_step=False, process_group=None)

Bases: Metric

Parameters:
  • dist_sync_on_step (bool)

  • process_group (Any)

update(preds, target)

Update state with predictions and targets.

Parameters:
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values
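The AUC doctest earlier on this page is consistent with GlobalMetric pooling every batch passed to update() and scoring the pooled data once in compute(). Pooling the two batches gives 0.56, while averaging the per-batch AUCs would give 0.5. The sketch below illustrates that accumulate-then-compute pattern in pure Python with a hypothetical PooledMetricSketch class. The real class subclasses torchmetrics' Metric, which presumably also handles distributed syncing (dist_sync_on_step, process_group); the sketch leaves that out.

```python
class PooledMetricSketch:
    """Hypothetical stand-in for an accumulate-then-compute metric.

    update() only stores each batch; compute() scores all stored samples at once.
    """

    def __init__(self, score_fn):
        self.score_fn = score_fn
        self.preds, self.target = [], []

    def update(self, preds, target):
        self.preds.extend(preds)  # keep raw predictions, no per-batch reduction
        self.target.extend(target)

    def compute(self):
        return self.score_fn(self.preds, self.target)

    def reset(self):
        self.preds, self.target = [], []


def pairwise_auc(scores, labels):
    """ROC AUC = P(random positive outscores random negative); ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


batches = [([0.3, 0.8, 0.2], [0, 1, 0]), ([0.3, 0.3, 0.9], [1, 1, 0])]
metric = PooledMetricSketch(pairwise_auc)
for preds, target in batches:
    metric.update(preds, target)
pooled = metric.compute()
per_batch_mean = sum(pairwise_auc(p, t) for p, t in batches) / len(batches)
print(round(pooled, 2), per_batch_mean)  # 0.56 0.5
```

Ranking metrics such as AUC and Spearman correlation are not decomposable over batches, which is why pooling before computing matters.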

class pytorch_lightning_spells.metrics.SpearmanCorrelation(sigmoid=False, compute_on_step=False, dist_sync_on_step=False, process_group=None)

Bases: GlobalMetric

Parameters:
  • sigmoid (bool)

  • compute_on_step (bool)

  • dist_sync_on_step (bool)

  • process_group (Any)

compute()
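Assuming SpearmanCorrelation computes the standard Spearman rank correlation between the pooled predictions and targets, that is, the Pearson correlation of their ranks, the calculation looks like the pure-Python sketch below. average_ranks and spearman are hypothetical helpers. Giving tied values the mean of their positions is an assumption here; it is the convention scipy.stats.spearmanr uses.

```python
import math


def average_ranks(values):
    """1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values equal to values[order[i]].
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # mean of 1-based positions i+1 .. j+1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def spearman(x, y):
    """Spearman's rho as the Pearson correlation of the two rank vectors."""
    rx, ry = average_ranks(x), average_ranks(y)
    n = len(rx)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    vx = sum((a - mx) ** 2 for a in rx)
    vy = sum((b - my) ** 2 for b in ry)
    return cov / math.sqrt(vx * vy)


preds = [0.1, 0.4, 0.35, 0.8]
targets = [1, 2, 3, 4]
print(spearman(preds, targets))  # 0.8 (one adjacent pair is out of order)
```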