Source code for detrex.modeling.losses.focal_loss

# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/facebookresearch/detr/blob/main/models/segmentation.py
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/focal_loss.py
# ------------------------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import weight_reduce_loss


def sigmoid_focal_loss(
    preds,
    targets,
    weight=None,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "mean",
    avg_factor: int = None,
):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        preds (torch.Tensor): A float tensor of arbitrary shape.
            The predictions (raw logits) for each example.
        targets (torch.Tensor): A float tensor with the same shape as ``preds``.
            Stores the binary classification label for each element in ``preds``
            (0 for the negative class and 1 for the positive class).
        weight (torch.Tensor, optional): Sample-wise loss weight. Default: None.
        alpha (float, optional): Weighting factor in range (0, 1) to balance
            positive vs negative examples. Default: 0.25.
        gamma (float, optional): Exponent of the modulating factor (1 - p_t) to
            balance easy vs hard examples. Default: 2.
        reduction (str, optional): 'none' | 'mean' | 'sum'
            'none': No reduction will be applied to the output.
            'mean': The output will be averaged.
            'sum': The output will be summed.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Default: None.

    Returns:
        torch.Tensor: The computed sigmoid focal loss with the reduction option applied.
    """
    preds = preds.float()
    targets = targets.float()
    p = torch.sigmoid(preds)
    ce_loss = F.binary_cross_entropy_with_logits(preds, targets, reduction="none")
    # p_t is the model's estimated probability of the ground-truth class.
    p_t = p * targets + (1 - p) * (1 - targets)
    # Down-weight easy examples with the modulating factor (1 - p_t) ** gamma.
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if weight is not None:
        assert weight.ndim == loss.ndim
    # Apply the optional element-wise weight, then the requested reduction
    # (optionally normalized by avg_factor).
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
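
# A minimal usage sketch for `sigmoid_focal_loss` (shapes and values below are
# illustrative assumptions, not from detrex). `preds` holds raw logits and
# `targets` holds {0, 1} labels of the same shape, so no one-hot step is
# needed at this level:
#
#   preds = torch.randn(4, 80)                      # logits, 4 samples x 80 classes
#   targets = torch.randint(0, 2, (4, 80)).float()  # binary label per class
#   loss = sigmoid_focal_loss(preds, targets, reduction="mean")  # scalar tensor
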
def focal_loss_with_prob(
    preds,
    targets,
    weight=None,
    alpha=0.25,
    gamma=2.0,
    reduction="mean",
    avg_factor=None,
):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Different from ``sigmoid_focal_loss``, this function accepts probabilities
    as input.

    Args:
        preds (torch.Tensor): The prediction probability with shape (N, C),
            where C is the number of classes.
        targets (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    num_classes = preds.size(1)
    # One-hot encode class indices; index `num_classes` acts as the
    # "no object" class and is dropped by the slice below.
    targets = F.one_hot(targets, num_classes=num_classes + 1)
    targets = targets[:, :num_classes]
    targets = targets.type_as(preds)

    # p_t is the model's estimated probability of the ground-truth class.
    p_t = preds * targets + (1 - preds) * (1 - targets)
    ce_loss = F.binary_cross_entropy(preds, targets, reduction="none")
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if weight is not None:
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


class FocalLoss(nn.Module):
    """`Focal Loss <https://arxiv.org/abs/1708.02002>`_

    Args:
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
        loss_weight (float, optional): Weight of loss. Defaults to 1.0.
        activated (bool, optional): Whether the input has already been
            activated into probabilities. If True, ``focal_loss_with_prob``
            is used; otherwise ``sigmoid_focal_loss`` is applied to raw
            logits. Defaults to False.
    """

    def __init__(
        self,
        alpha=0.25,
        gamma=2.0,
        reduction="mean",
        loss_weight=1.0,
        activated=False,
    ):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.activated = activated

    def forward(
        self,
        preds,
        targets,
        weight=None,
        avg_factor=None,
    ):
        """Forward function for FocalLoss.

        Args:
            preds (torch.Tensor): The prediction with shape ``(N, C)``, where
                C is the number of classes; raw logits if ``activated`` is
                False, probabilities if True.
            targets (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.

        Returns:
            torch.Tensor: The calculated loss.
        """
        if self.activated:
            # `focal_loss_with_prob` performs its own one-hot conversion.
            loss_func = focal_loss_with_prob
        else:
            num_classes = preds.size(1)
            targets = F.one_hot(targets, num_classes=num_classes + 1)
            targets = targets[:, :num_classes]
            loss_func = sigmoid_focal_loss

        loss_class = self.loss_weight * loss_func(
            preds,
            targets,
            weight,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
            avg_factor=avg_factor,
        )
        return loss_class
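

# ------------------------------------------------------------------------------------------------
# A minimal smoke test for the two loss paths above. This block is an
# illustrative sketch (shapes and values are assumptions, not part of the
# detrex API) and should be run as `python -m detrex.modeling.losses.focal_loss`
# since the module uses a relative import.
# ------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    num_classes = 80
    preds = torch.randn(4, num_classes)                # raw logits, 4 samples
    targets = torch.randint(0, num_classes + 1, (4,))  # class indices; `num_classes` = background

    # Default path: logits in, sigmoid applied inside `sigmoid_focal_loss`.
    criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction="mean")
    print(criterion(preds, targets))  # scalar tensor

    # Activated path: probabilities in, handled by `focal_loss_with_prob`.
    criterion_prob = FocalLoss(activated=True)
    print(criterion_prob(preds.sigmoid(), targets))  # scalar tensor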