Source code for pytorch_lightning_spells.losses

from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class Poly1CrossEntropyLoss(nn.Module):
    """Poly-1 Cross-Entropy Loss

    Adapted from `abhuse/polyloss-pytorch <https://github.com/abhuse/polyloss-pytorch>`_.

    Reference: `PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions <http://arxiv.org/abs/2204.12511>`_.
    """

    def __init__(
        self,
        epsilon: float = 1.0,
        reduction: str = "none",
        weight: Optional[Tensor] = None,
    ):
        """
        Create instance of Poly1CrossEntropyLoss

        :param epsilon: poly loss epsilon
        :param reduction: one of none|sum|mean, apply reduction to final loss tensor
        :param weight: manual rescaling weight for each class, passed to Cross-Entropy loss
        """
        super(Poly1CrossEntropyLoss, self).__init__()
        self.epsilon = epsilon
        self.reduction = reduction
        self.weight = weight
    def forward(self, logits, labels, **kwargs):
        """
        Forward pass

        :param logits: tensor of shape [N, num_classes]
        :param labels: tensor of shape [N]
        :return: poly cross-entropy loss
        """
        probs = F.softmax(logits, dim=-1)
        if self.weight is not None:
            self.weight = self.weight.to(labels.device)
            probs = probs * self.weight.unsqueeze(0) / self.weight.mean()
        pt = torch.gather(probs, -1, labels.unsqueeze(1))[:, 0]
        CE = F.cross_entropy(
            input=logits, target=labels, reduction="none", weight=self.weight
        )
        poly1 = CE + self.epsilon * (1 - pt)
        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()
        return poly1
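# Usage sketch (illustrative addition, not part of the upstream module): applying
# Poly1CrossEntropyLoss to a hypothetical batch of 8 samples over 5 classes.
#
#     loss_fn = Poly1CrossEntropyLoss(epsilon=1.0, reduction="mean")
#     logits = torch.randn(8, 5)          # [N, num_classes]
#     labels = torch.randint(0, 5, (8,))  # [N] class ids
#     loss = loss_fn(logits, labels)      # scalar because reduction="mean"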
class Poly1FocalLoss(nn.Module):
    """Poly-1 Focal Loss

    Adapted from `abhuse/polyloss-pytorch <https://github.com/abhuse/polyloss-pytorch>`_.

    Reference: `PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions <http://arxiv.org/abs/2204.12511>`_.
    """

    def __init__(
        self,
        num_classes: int,
        epsilon: float = 1.0,
        alpha: float = 0.25,
        gamma: float = 2.0,
        reduction: str = "none",
        weight: Optional[Tensor] = None,
        label_is_onehot: bool = False,
    ):
        """
        Create instance of Poly1FocalLoss

        :param num_classes: number of classes
        :param epsilon: poly loss epsilon
        :param alpha: focal loss alpha
        :param gamma: focal loss gamma
        :param reduction: one of none|sum|mean, apply reduction to final loss tensor
        :param weight: manual rescaling weight for each class, passed to binary Cross-Entropy loss
        :param label_is_onehot: set to True if labels are one-hot encoded
        """
        super(Poly1FocalLoss, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.weight = weight
        self.label_is_onehot = label_is_onehot
    def forward(self, logits, labels):
        """
        Forward pass

        :param logits: output of neural network of shape [N, num_classes] or [N, num_classes, ...]
        :param labels: ground truth tensor of shape [N] or [N, ...] with class ids if label_is_onehot
            was set to False, otherwise one-hot encoded tensor of the same shape as logits
        :return: poly focal loss
        """
        # focal loss implementation taken from
        # https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/focal_loss.py
        p = torch.sigmoid(logits)

        if not self.label_is_onehot:
            # if labels are of shape [N]
            # convert to one-hot tensor of shape [N, num_classes]
            if labels.ndim == 1:
                labels = F.one_hot(labels, num_classes=self.num_classes)
            # if labels are of shape [N, ...] e.g. segmentation task
            # convert to one-hot tensor of shape [N, num_classes, ...]
            else:
                labels = (
                    F.one_hot(labels.unsqueeze(1), self.num_classes)
                    .transpose(1, -1)
                    .squeeze_(-1)
                )

        labels = labels.to(device=logits.device, dtype=logits.dtype)

        ce_loss = F.binary_cross_entropy_with_logits(
            input=logits,
            target=labels,
            reduction="none",
            weight=self.weight.to(logits.device) if self.weight is not None else None,
        )
        pt = labels * p + (1 - labels) * (1 - p)
        FL = ce_loss * ((1 - pt) ** self.gamma)

        if self.alpha >= 0:
            alpha_t = self.alpha * labels + (1 - self.alpha) * (1 - labels)
            FL = alpha_t * FL

        poly1 = FL + self.epsilon * torch.pow(1 - pt, self.gamma + 1)

        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()

        return poly1
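# Usage sketch (illustrative addition, not part of the upstream module): the focal
# variant needs `num_classes` because integer labels are one-hot encoded internally;
# the batch below is hypothetical.
#
#     loss_fn = Poly1FocalLoss(num_classes=5, reduction="mean")
#     logits = torch.randn(8, 5)          # [N, num_classes]
#     labels = torch.randint(0, 5, (8,))  # [N] class ids (label_is_onehot=False)
#     loss = loss_fn(logits, labels)      # scalar because reduction="mean"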
def _linear_combination(x, y, epsilon):
    return epsilon * x + (1 - epsilon) * y
class LabelSmoothCrossEntropy(nn.Module):
    """Cross Entropy with Label Smoothing

    Reference: `wangleiofficial/label-smoothing-pytorch <https://github.com/wangleiofficial/label-smoothing-pytorch>`_

    The ground truth label will have a value of `1-eps` in the target vector.

    Args:
        eps (float): the smoothing factor.
    """

    def __init__(self, eps: float):
        super().__init__()
        self.eps = eps
    def forward(self, preds, targets, weight=None):
        n = preds.size()[-1]
        log_preds = F.log_softmax(preds, dim=-1)
        # smoothing term: (weighted) average of the log-probabilities over all classes
        if weight is None:
            loss = -log_preds.sum(dim=-1).mean()
        else:
            loss = -(log_preds.sum(dim=-1) * weight.unsqueeze(0)).mean() / weight.mean()
        # standard negative log-likelihood on the ground-truth classes
        nll = F.nll_loss(log_preds, targets, weight=weight)
        return _linear_combination(loss / n, nll, self.eps)
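# Usage sketch (illustrative addition, not part of the upstream module): with
# eps=0.1 the returned value is 0.1 * (uniform-target term) + 0.9 * NLL, per
# `_linear_combination` above; the batch below is hypothetical.
#
#     loss_fn = LabelSmoothCrossEntropy(eps=0.1)
#     preds = torch.randn(8, 5)            # [N, num_classes] raw logits
#     targets = torch.randint(0, 5, (8,))  # [N] class ids
#     loss = loss_fn(preds, targets)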
class MixupSoftmaxLoss(nn.Module):
    """A softmax loss that supports MixUp augmentation.

    It requires the input batch to be manipulated into a certain format.
    Works best with MixUpCallback, CutMixCallback, and SnapMixCallback.

    Reference: `Fast.ai's implementation <https://github.com/fastai/fastai/blob/master/fastai/callbacks/mixup.py#L6>`_

    Args:
        class_weights (torch.Tensor, optional): The weight of each class. Defaults to the same weight.
        reduction (str, optional): Loss reduction method. Defaults to 'mean'.
        label_smooth_eps (float, optional): If larger than zero, use `LabelSmoothCrossEntropy`
            instead of `CrossEntropy`. Defaults to 0.
        poly1_eps (float, optional): If non-zero, use `Poly1CrossEntropyLoss` instead of
            `CrossEntropy`. Defaults to 0.
    """

    def __init__(
        self,
        class_weights: Optional[torch.Tensor] = None,
        reduction: str = "mean",
        label_smooth_eps: float = 0,
        poly1_eps: float = 0,
    ):
        super().__init__()
        # setattr(self.crit, 'reduction', 'none')
        self.reduction = reduction
        self.weight = class_weights
        assert not (
            (label_smooth_eps > 0) and (poly1_eps != 0)
        ), "You cannot set both `label_smooth_eps` and `poly1_eps` to non-default values!"
        if label_smooth_eps > 0:
            self.loss_fn: Callable = LabelSmoothCrossEntropy(eps=label_smooth_eps)
        elif poly1_eps != 0:
            self.loss_fn = Poly1CrossEntropyLoss(epsilon=poly1_eps, weight=self.weight)
        else:
            self.loss_fn = F.cross_entropy
    def forward(self, output: torch.Tensor, target):
        """The feed-forward.

        The target tensor should have three columns:

        1. the first class.
        2. the second class.
        3. the lambda value to mix the above two classes.

        A four-column variant is also supported, where columns 3 and 4 are the
        mixing weights of the two classes, respectively.

        Args:
            output (torch.Tensor): the model output.
            target (torch.Tensor): Shaped (batch_size, 3) or (batch_size, 4).

        Returns:
            torch.Tensor: the result loss
        """
        if self.weight is not None:
            self.weight = self.weight.to(output.device)
        weight = self.weight
        if len(target.size()) == 2:
            loss1 = self.loss_fn(output, target[:, 0].long(), weight=weight)
            loss2 = self.loss_fn(output, target[:, 1].long(), weight=weight)
            assert target.size(1) in (3, 4)
            if target.size(1) == 3:
                lambda_ = target[:, 2]
                d = (loss1 * lambda_ + loss2 * (1 - lambda_)).mean()
            else:
                lamb_1, lamb_2 = target[:, 2], target[:, 3]
                d = (loss1 * lamb_1 + loss2 * lamb_2).mean()
        else:
            # This handles the cases without MixUp for backward compatibility
            d = self.loss_fn(output, target, weight=weight)
        if self.reduction == "mean":
            return d.mean()
        elif self.reduction == "sum":
            return d.sum()
        return d
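# Usage sketch (illustrative addition, not part of the upstream module): the target
# carries two class ids plus the MixUp lambda, in the format produced by callbacks
# such as MixUpCallback; the tensors below are hypothetical.
#
#     loss_fn = MixupSoftmaxLoss()
#     output = torch.randn(8, 5)                   # model logits, [N, num_classes]
#     class_a = torch.randint(0, 5, (8,)).float()
#     class_b = torch.randint(0, 5, (8,)).float()
#     lam = torch.rand(8)
#     target = torch.stack([class_a, class_b, lam], dim=1)  # [N, 3]
#     loss = loss_fn(output, target)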