"""Source code for pytorch_lightning_spells.metrics."""

import warnings
from typing import Any, cast

import torch
import numpy as np
from torchmetrics import Metric
from scipy.stats import spearmanr
from pytorch_lightning.utilities import rank_zero_warn
from sklearn.metrics import fbeta_score, roc_auc_score
from sklearn.exceptions import UndefinedMetricWarning


class GlobalMetric(Metric):
    """Base class for metrics that need every prediction before computing.

    Each ``update`` call appends the batch's predictions and targets to the
    list states ``preds`` and ``targets``; subclasses concatenate those lists
    inside their own ``compute``. Because the whole history is buffered,
    memory grows with the number of samples seen since the last reset.

    Args:
        dist_sync_on_step: Forwarded to the torchmetrics ``Metric`` constructor.
        process_group: Forwarded to the torchmetrics ``Metric`` constructor.
    """

    def __init__(
        self,
        dist_sync_on_step: bool = False,
        process_group: Any = None,
    ):
        super().__init__(
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )
        # List-valued states with no reduction function; concatenation is
        # left to each subclass's ``compute``.
        self.add_state("preds", default=[], dist_reduce_fx=None)
        self.add_state("targets", default=[], dist_reduce_fx=None)
        rank_zero_warn(
            "This metric will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint."
        )

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Buffer one batch of predictions and targets.

        The tensors are detached before being stored. Otherwise, buffering
        predictions that still require grad (e.g. logits produced in a
        training step) would keep their autograd graphs alive for as long
        as the buffer exists.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        cast(list[torch.Tensor], cast(object, self.preds)).append(preds.detach())
        cast(list[torch.Tensor], cast(object, self.targets)).append(target.detach())
class AUC(GlobalMetric):
    """AUC: Area Under the ROC Curve

    **Binary mode**

    >>> auc = AUC()
    >>> _ = auc(torch.tensor([0.3, 0.8, 0.2]), torch.tensor([0, 1, 0]))
    >>> _ = auc(torch.tensor([0.3, 0.3, 0.9]), torch.tensor([1, 1, 0]))
    >>> round(auc.compute().item(), 2)
    0.56

    **Multi-class mode**

    This will use the first column as the negative case, and the rest
    collectively as the positive case.

    >>> auc = AUC()
    >>> _ = auc(torch.tensor([[0.3, 0.8, 0.2], [0.2, 0.1, 0.1], [0.5, 0.1, 0.7]]).t(), torch.tensor([0, 1, 0]))
    >>> _ = auc(torch.tensor([[0.3, 0.3, 0.8], [0.2, 0.6, 0.1], [0.5, 0.1, 0.1]]).t(), torch.tensor([1, 1, 0]))
    >>> round(auc.compute().item(), 2)
    0.39

    When the (binarized) targets contain only one class, ROC AUC is
    undefined and ``compute`` returns ``0.0``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        message = (
            "This is a specialized metric that coalesces all classes except the first into a single positive class. "
            "Please use BinaryAUROC from torchmetrics if you have a binary classification problem: "
            "'from torchmetrics.classification import BinaryAUROC'."
        )
        rank_zero_warn(message)

    def compute(self):
        """Compute ROC AUC over every buffered batch.

        Returns:
            A scalar tensor placed on the device of the first buffered
            prediction batch. It is a float ``0.0`` when the targets hold a
            single class.
        """
        preds_list = cast(list[torch.Tensor], cast(object, self.preds))
        target_list = cast(list[torch.Tensor], cast(object, self.targets))
        targets_np = torch.cat(target_list, dim=0).cpu().long().numpy()
        # nan_to_num replaces NaN/inf scores with finite values before sklearn sees them.
        preds_np = torch.nan_to_num(torch.cat(preds_list, dim=0).float().cpu()).numpy()
        device = preds_list[0].device
        if preds_np.ndim > 1:
            # Multi-class: positive score is 1 - first-column score; any non-zero label is positive.
            preds_np = 1 - preds_np[:, 0]
            targets_np = (targets_np != 0).astype(int)
        if len(np.unique(targets_np)) == 1:
            # Float (not int64) so this path yields the same kind of tensor as the normal path.
            return torch.tensor(0.0, device=device)
        return torch.tensor(roc_auc_score(targets_np, preds_np), device=device)
class SpearmanCorrelation(GlobalMetric):
    """Spearman rank correlation between buffered predictions and targets.

    When the targets are all identical the correlation is undefined and
    ``compute`` returns ``0.0`` without calling ``spearmanr``.

    Args:
        sigmoid: If True, apply ``torch.sigmoid`` to the predictions before
            ranking. Sigmoid is monotonic, so this only affects the result
            where it maps distinct inputs to equal floating-point outputs.
        compute_on_step: Accepted for backward compatibility; it is not used.
        dist_sync_on_step: Forwarded to the base metric.
        process_group: Forwarded to the base metric.
    """

    def __init__(
        self,
        sigmoid: bool = False,
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: Any = None,
    ):
        super().__init__(
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )
        self.sigmoid = sigmoid

    def compute(self):
        """Compute the Spearman correlation over every buffered batch.

        Returns:
            A scalar tensor placed on the device of the first buffered
            prediction batch.
        """
        preds_list = cast(list[torch.Tensor], cast(object, self.preds))
        target_list = cast(list[torch.Tensor], cast(object, self.targets))
        preds_tensor = torch.cat(preds_list, dim=0).float()
        if self.sigmoid:
            preds_tensor = torch.sigmoid(preds_tensor)
        preds_np = preds_tensor.cpu().numpy()
        targets_np = torch.cat(target_list, dim=0).cpu().float().numpy()
        device = preds_list[0].device
        # Check the degenerate case first: the correlation would be discarded anyway.
        if len(np.unique(targets_np)) == 1:
            return torch.tensor(0.0, device=device)
        spearman_score = spearmanr(targets_np, preds_np).correlation  # pyright: ignore[reportAttributeAccessIssue]
        return torch.tensor(spearman_score, device=device)
class FBeta(GlobalMetric):
    """The F-beta score is the weighted harmonic mean of precision and recall

    This is a specialized implementation: the classification threshold is
    chosen automatically by a grid search that maximizes the score.

    **Binary mode**

    >>> fbeta = FBeta()
    >>> _ = fbeta(torch.tensor([0.3, 0.8, 0.2]), torch.tensor([0, 1, 0]))
    >>> _ = fbeta(torch.tensor([0.3, 0.3, 0.9]), torch.tensor([1, 1, 0]))
    >>> round(fbeta.compute().item(), 2)
    0.88

    **Multi-class mode**

    This will **use the first column as the negative case, and the rest
    collectively as the positive case**. Use MulticlassFBetaScore from
    torchmetrics for the true multiclass F-score metric.

    >>> fbeta = FBeta()
    >>> _ = fbeta(torch.tensor([[0.8, 0.3, 0.7], [0.1, 0.1, 0.1], [0.1, 0.6, 0.2]]).t(), torch.tensor([0, 1, 0]))
    >>> _ = fbeta(torch.tensor([[0.3, 0.7, 0.8], [0.2, 0.2, 0.1], [0.5, 0.1, 0.1]]).t(), torch.tensor([1, 1, 0]))
    >>> round(fbeta.compute().item(), 4)
    0.9375

    Args:
        step: Spacing of the threshold grid ``step, 2*step, ...`` (below 1).
            Must lie in the open interval (0, 1).
        beta: Weight of recall relative to precision.
        compute_on_step: Accepted for backward compatibility; it is not used.
        dist_sync_on_step: Forwarded to the base metric.
        process_group: Forwarded to the base metric.

    Raises:
        ValueError: If ``step`` is not in (0, 1). Such a step would produce
            an empty threshold grid and a silently meaningless score.
    """

    def __init__(
        self,
        step: float = 0.02,
        beta: float = 2,
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: Any = None,
    ):
        if not 0 < step < 1:
            raise ValueError(f"step must be in the open interval (0, 1), got {step!r}.")
        super().__init__(
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )
        self.step = step
        self.beta = beta
        rank_zero_warn(
            (
                "This is a specialized metric that coalesces all classes except the first into the positive class and automatically optimizes the classification threshold. "
                "Please use FBetaScore from torchmetrics for general use cases."
            )
        )

    def find_best_fbeta_threshold(self, truth, probs):
        """Grid-search the decision threshold that maximizes the F-beta score.

        Thresholds ``step, 2*step, ...`` strictly below 1 are tried. A sample
        is predicted positive when its score is ``>=`` the threshold.

        Args:
            truth: Binary ground-truth labels.
            probs: Positive-class scores, one per label.

        Returns:
            ``(best_score, best_threshold)``. The result is ``(0, -1)`` if no
            threshold yields a score above 0.
        """
        best, best_thres = 0, -1
        with warnings.catch_warnings():
            # Degenerate thresholds are expected during the sweep; silence sklearn's warnings about them.
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            for thres in np.arange(self.step, 1, self.step):
                current = fbeta_score(
                    truth,
                    (probs >= thres).astype("int8"),
                    beta=self.beta,
                    average="binary",
                )
                if current > best:
                    best = current
                    best_thres = thres
        return best, best_thres

    def compute(self):
        """Compute the best F-beta score over every buffered batch.

        Returns:
            A scalar tensor placed on the device of the first buffered
            prediction batch. It is a float ``0.0`` when the targets hold a
            single class.
        """
        preds_list = cast(list[torch.Tensor], cast(object, self.preds))
        target_list = cast(list[torch.Tensor], cast(object, self.targets))
        preds_np = torch.cat(preds_list, dim=0).float().cpu().numpy()
        targets_np = torch.cat(target_list, dim=0).cpu().long().numpy()
        device = preds_list[0].device
        if preds_np.ndim > 1:
            # Multi-class: positive score is 1 - first-column score; any non-zero label is positive.
            preds_np = 1 - preds_np[:, 0]
            targets_np = (targets_np != 0).astype(int)
        if len(np.unique(targets_np)) == 1:
            # Float (not int64) so this path yields the same kind of tensor as the normal path.
            return torch.tensor(0.0, device=device)
        best_fbeta, _ = self.find_best_fbeta_threshold(targets_np, preds_np)
        return torch.tensor(best_fbeta, device=device)