Source code for pytorch_lightning_spells.metrics

import warnings
from typing import Optional, Any

import torch
import numpy as np
from torchmetrics import Metric
from scipy.stats import spearmanr
from pytorch_lightning.utilities import rank_zero_warn
from sklearn.metrics import fbeta_score, roc_auc_score
from sklearn.exceptions import UndefinedMetricWarning


class GlobalMetric(Metric):
    """Base class that buffers all predictions and targets so the metric can be computed globally."""

    def __init__(
        self,
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )
        self.add_state("preds", default=[], dist_reduce_fx=None)
        self.add_state("target", default=[], dist_reduce_fx=None)
        rank_zero_warn(
            "This Metric will save all targets and predictions in buffer."
            " For large datasets this may lead to large memory footprint."
        )

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        self.preds.append(preds)
        self.target.append(target)
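
# Illustrative sketch (not part of this module): a ``GlobalMetric`` subclass only
# needs to implement ``compute``; the buffering of every batch of predictions and
# targets is inherited from ``update`` above. The class name ``GlobalAccuracy``
# is hypothetical.
#
#     class GlobalAccuracy(GlobalMetric):
#         def compute(self):
#             preds = torch.cat(self.preds, dim=0).float().cpu()
#             target = torch.cat(self.target, dim=0).cpu()
#             # Threshold binary probabilities at 0.5 before comparing.
#             return ((preds >= 0.5).long() == target.long()).float().mean()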


class AUC(GlobalMetric):
    """AUC: Area Under the ROC Curve

    **Binary mode**

    >>> auc = AUC()
    >>> _ = auc(torch.tensor([0.3, 0.8, 0.2]), torch.tensor([0, 1, 0]))
    >>> _ = auc(torch.tensor([0.3, 0.3, 0.9]), torch.tensor([1, 1, 0]))
    >>> round(auc.compute().item(), 2)
    0.56

    **Multi-class mode**

    This will use the first column as the negative case, and the rest
    collectively as the positive case.

    >>> auc = AUC()
    >>> _ = auc(torch.tensor([[0.3, 0.8, 0.2], [0.2, 0.1, 0.1], [0.5, 0.1, 0.7]]).t(), torch.tensor([0, 1, 0]))
    >>> _ = auc(torch.tensor([[0.3, 0.3, 0.8], [0.2, 0.6, 0.1], [0.5, 0.1, 0.1]]).t(), torch.tensor([1, 1, 0]))
    >>> round(auc.compute().item(), 2)
    0.39
    """

    def compute(self):
        target = torch.cat(self.target, dim=0).cpu().long().numpy()
        preds = torch.nan_to_num(torch.cat(self.preds, dim=0).float().cpu()).numpy()
        if len(preds.shape) > 1:
            # Multi-class mode: treat the first column as the negative class and
            # use one minus its probability as the score for the positive case.
            preds = 1 - preds[:, 0]
            target = (target != 0).astype(int)
        if len(np.unique(target)) == 1:
            # ROC AUC is undefined when only one class is present in the targets.
            return torch.tensor(0, device=self.preds[0].device)
        return torch.tensor(roc_auc_score(target, preds), device=self.preds[0].device)
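
# Worked illustration (comments only; the arrays below are hypothetical) of the
# multi-class reduction in ``AUC.compute``: the score passed to ``roc_auc_score``
# is ``1 - preds[:, 0]``, i.e. one minus the probability of the first ("negative")
# class, and any non-zero label counts as positive.
#
#     preds = np.array([[0.7, 0.2, 0.1],   # P(class 0) = 0.7 -> score 0.3
#                       [0.1, 0.8, 0.1],   # P(class 0) = 0.1 -> score 0.9
#                       [0.4, 0.3, 0.3]])  # P(class 0) = 0.4 -> score 0.6
#     target = np.array([0, 1, 2])         # binarized to [0, 1, 1]
#     roc_auc_score((target != 0).astype(int), 1 - preds[:, 0])  # -> 1.0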


class SpearmanCorrelation(GlobalMetric):
    """Spearman rank correlation computed over all buffered predictions and targets.

    Set ``sigmoid=True`` to pass the predictions through a sigmoid before scoring
    (e.g. when the model outputs raw logits).
    """

    def __init__(
        self,
        sigmoid: bool = False,
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )
        self.sigmoid = sigmoid

    def compute(self):
        preds = torch.cat(self.preds, dim=0).float()
        if self.sigmoid:
            preds = torch.sigmoid(preds)
        preds = preds.cpu().numpy()
        target = torch.cat(self.target, dim=0).cpu().float().numpy()
        spearman_score = spearmanr(target, preds).correlation
        if len(np.unique(target)) == 1:
            return torch.tensor(0, device=self.preds[0].device)
        return torch.tensor(spearman_score, device=self.preds[0].device)
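
# Usage sketch (hypothetical values, not a doctest from this module): with
# ``sigmoid=True`` raw logits are squashed before the rank correlation is
# computed; because the sigmoid is monotonic it does not change the ranking
# itself.
#
#     metric = SpearmanCorrelation(sigmoid=True)
#     metric.update(torch.tensor([-2.0, 0.5, 1.5]), torch.tensor([0.0, 1.0, 2.0]))
#     metric.update(torch.tensor([0.1, 3.0]), torch.tensor([1.5, 3.0]))
#     metric.compute()  # rank correlation over all five buffered pairs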


class FBeta(GlobalMetric):
    """The F-beta score is the weighted harmonic mean of precision and recall.

    **Binary mode**

    >>> fbeta = FBeta()
    >>> _ = fbeta(torch.tensor([0.3, 0.8, 0.2]), torch.tensor([0, 1, 0]))
    >>> _ = fbeta(torch.tensor([0.3, 0.3, 0.9]), torch.tensor([1, 1, 0]))
    >>> round(fbeta.compute().item(), 2)
    0.88

    **Multi-class mode**

    This will use the first column as the negative case, and the rest
    collectively as the positive case.

    >>> fbeta = FBeta()
    >>> _ = fbeta(torch.tensor([[0.8, 0.3, 0.7], [0.1, 0.1, 0.1], [0.1, 0.6, 0.2]]).t(), torch.tensor([0, 1, 0]))
    >>> _ = fbeta(torch.tensor([[0.3, 0.7, 0.8], [0.2, 0.2, 0.1], [0.5, 0.1, 0.1]]).t(), torch.tensor([1, 1, 0]))
    >>> round(fbeta.compute().item(), 4)
    0.9375
    """

    def __init__(
        self,
        step: float = 0.02,
        beta: int = 2,
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )
        self.step = step
        self.beta = beta

    def find_best_fbeta_threshold(self, truth, probs):
        best, best_thres = 0, -1
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            for thres in np.arange(self.step, 1, self.step):
                current = fbeta_score(
                    truth,
                    (probs >= thres).astype("int8"),
                    beta=self.beta,
                    average="binary",
                )
                if current > best:
                    best = current
                    best_thres = thres
        return best, best_thres
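
    # Sketch of the sweep above with hypothetical inputs (not a doctest from
    # this module): thresholds 0.02, 0.04, ..., 0.98 are tried and the first
    # one reaching the best F-beta score is kept. With beta=2, recall is
    # weighted more heavily than precision.
    #
    #     truth = np.array([0, 1, 1, 0])
    #     probs = np.array([0.2, 0.7, 0.5, 0.25])
    #     # Any threshold in (0.25, 0.5] labels exactly the two positives,
    #     # so the best F2 score is 1.0, reached near 0.26 on this grid.
    #     FBeta().find_best_fbeta_threshold(truth, probs)  # ~ (1.0, 0.26)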

    def compute(self):
        preds = torch.cat(self.preds, dim=0).float().cpu().numpy()
        target = torch.cat(self.target, dim=0).cpu().long().numpy()
        if len(preds.shape) > 1:
            # Multi-class mode: same reduction as in ``AUC.compute``.
            preds = 1 - preds[:, 0]
            target = (target != 0).astype(int)
        if len(np.unique(target)) == 1:
            return torch.tensor(0, device=self.preds[0].device)
        best_fbeta, best_thres = self.find_best_fbeta_threshold(target, preds)
        return torch.tensor(best_fbeta, device=self.preds[0].device)
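

# End-to-end usage sketch (hypothetical class and attribute names; the hooks
# follow standard PyTorch Lightning conventions and nothing below is part of
# this module):
#
#     class LitClassifier(pl.LightningModule):
#         def __init__(self):
#             super().__init__()
#             self.model = ...
#             self.val_auc = AUC()
#
#         def validation_step(self, batch, batch_idx):
#             x, y = batch
#             probs = torch.sigmoid(self.model(x)).squeeze(-1)
#             self.val_auc.update(probs, y)
#
#         def on_validation_epoch_end(self):
#             # compute() concatenates every buffered batch before scoring,
#             # so the score is global rather than averaged per batch.
#             self.log("val_auc", self.val_auc.compute())
#             self.val_auc.reset()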