Source code for torchflare.criterion.focal_loss

"""Implements variants for Focal loss."""
import torch
import torch.nn as nn
import torch.nn.functional as F


class BCEFocalLoss(nn.Module):
    """Implementation of Focal Loss for Binary Classification Problems.

    Focal loss was proposed in `Focal Loss for Dense Object Detection
    <https://arxiv.org/abs/1708.02002>`_.
    """

    def __init__(self, gamma=0, eps=1e-7, reduction="mean"):
        """Constructor method for BCEFocalLoss class.

        Args:
            gamma: The focal parameter. Defaults to 0, which makes the loss
                equal to plain binary cross entropy with logits.
            eps: Constant for computational stability. It is stored on the
                module but is not used by ``forward``.
            reduction: ``"mean"`` or ``"sum"``. Any other value returns the
                unreduced, element-wise loss.
        """
        super(BCEFocalLoss, self).__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.eps = eps
        # Element-wise BCE is needed so the focal modulating factor can be
        # applied per element before any reduction happens.
        self.bce = torch.nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Forward method.

        Args:
            logits: The raw (pre-sigmoid) logits from the network.
            targets: The binary targets. Must have the same number of
                elements as ``logits``; they are reshaped to ``logits.shape``.
                Integer or boolean targets are cast to the dtype of ``logits``.

        Returns:
            The computed loss value, reduced according to ``self.reduction``.
        """
        # reshape (rather than view) also accepts non-contiguous targets.
        targets = targets.reshape(logits.shape)
        # BCEWithLogitsLoss needs floating-point targets; integer label
        # tensors would otherwise raise a dtype error.
        if not targets.is_floating_point():
            targets = targets.to(logits.dtype)

        logp = self.bce(logits, targets)
        # exp(-BCE) is the probability the model assigns to the true label.
        p = torch.exp(-logp)
        loss = (1 - p) ** self.gamma * logp

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
class FocalLoss(nn.Module):
    """Multi-class Focal Loss built on top of cross entropy.

    Focal loss was proposed in `Focal Loss for Dense Object Detection
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        gamma: The focal parameter. Defaults to 0.
        eps: Constant for computational stability. Stored on the module;
            ``forward`` does not read it.
        reduction: ``"mean"`` or ``"sum"``; any other value leaves the loss
            unreduced.
    """

    def __init__(self, gamma=0, eps=1e-7, reduction="mean"):
        """Constructor method for FocalLoss class."""
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.eps = eps
        # Unreduced cross entropy: the focal weight is applied per sample.
        self.ce = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss.

        Args:
            logits: The raw logits from the network of shape (N, C, k),
                where C = number of classes and k = extra dims.
            targets: The targets of shape (N, k).

        Returns:
            The computed loss value.
        """
        nll = self.ce(logits, targets)
        # exp(-CE) recovers the probability assigned to the true class.
        prob_true = torch.exp(-nll)
        modulator = (1 - prob_true) ** self.gamma
        focal = modulator * nll

        if self.reduction == "mean":
            return focal.mean()
        if self.reduction == "sum":
            return focal.sum()
        return focal
class FocalCosineLoss(nn.Module):
    """Implementation of Focal cosine loss.

    The loss is the sum of a cosine embedding loss between the logits and the
    one-hot targets and a focal loss computed on the normalized logits, the
    latter scaled by ``xent``.

    Source : https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/203271
    """

    def __init__(self, alpha: float = 1, gamma: float = 2, xent: float = 0.1, reduction="mean"):
        """Constructor for FocalCosineLoss.

        Args:
            alpha: Multiplicative weight applied to the focal term.
            gamma: The focal parameter (exponent of ``1 - pt``).
            xent: Weight of the focal term relative to the cosine term.
            reduction: ``"mean"``, ``"sum"`` or ``"none"``. It is forwarded to
                ``F.cosine_embedding_loss`` and also applied to the focal term.
        """
        super(FocalCosineLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.xent = xent
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Forward Method.

        Args:
            logits: The raw logits from the network; the last dimension is
                used as the number of classes for one-hot encoding.
            target: The class-index targets.

        Returns:
            The combined loss: a scalar for ``"mean"``/``"sum"``, otherwise
            the per-sample loss.
        """
        # A target of 1 asks cosine_embedding_loss to pull each logit vector
        # towards its one-hot label (loss = 1 - cosine similarity).
        cosine_loss = F.cosine_embedding_loss(
            logits,
            torch.nn.functional.one_hot(target, num_classes=logits.size(-1)),
            torch.tensor([1], device=target.device),
            reduction=self.reduction,
        )

        cent_loss = F.cross_entropy(F.normalize(logits), target, reduction="none")
        pt = torch.exp(-cent_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * cent_loss

        # Reduce the focal term the same way the cosine term was reduced.
        # Without the "sum" branch a scalar cosine loss would be broadcast
        # onto a per-sample focal vector.
        if self.reduction == "mean":
            focal_loss = torch.mean(focal_loss)
        elif self.reduction == "sum":
            focal_loss = torch.sum(focal_loss)

        return cosine_loss + self.xent * focal_loss