# Source code for detrex.layers.denoising

# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

from detrex.utils import inverse_sigmoid


def apply_label_noise(
    labels: torch.Tensor,
    label_noise_prob: float = 0.2,
    num_classes: int = 80,
):
    """Randomly replace a fraction of classification labels.

    Each label is independently selected with probability ``label_noise_prob``
    and replaced by a class drawn uniformly from ``[0, num_classes)`` (the
    drawn class may coincide with the original one).

    Args:
        labels (torch.Tensor): Classification labels with shape ``(num_labels, )``.
        label_noise_prob (float): The probability of the label being noised.
            Default: 0.2.
        num_classes (int): Number of total categories. Default: 80.

    Returns:
        torch.Tensor: The noised labels, same shape and dtype as ``labels``.
        ``labels`` itself is never modified; when ``label_noise_prob <= 0``
        the input tensor is returned unchanged.
    """
    if label_noise_prob <= 0:
        return labels

    p = torch.rand_like(labels.float())
    noised_index = torch.nonzero(p < label_noise_prob).view(-1)
    # randint_like inherits int64 from the index tensor; scatter requires the
    # source dtype to match the destination, so cast to the labels' dtype.
    new_labels = torch.randint_like(noised_index, 0, num_classes).to(labels.dtype)
    # Out-of-place scatter so the caller's tensor is left untouched.
    return labels.scatter(0, noised_index, new_labels)
def apply_box_noise(
    boxes: torch.Tensor,
    box_noise_scale: float = 0.4,
):
    """Jitter bounding boxes with uniform noise proportional to their size.

    The center ``(x_c, y_c)`` is shifted by up to ``±(w / 2, h / 2) * box_noise_scale``
    and the size ``(w, h)`` by up to ``±(w, h) * box_noise_scale``; the result is
    clamped to ``[0, 1]`` (which presumes normalized coordinates).

    Args:
        boxes (torch.Tensor): Bounding boxes in format ``(x_c, y_c, w, h)``
            with shape ``(num_boxes, 4)``.
        box_noise_scale (float): Scaling factor for box noising. Default: 0.4.

    Returns:
        torch.Tensor: Noised boxes with the same shape as ``boxes``. ``boxes``
        itself is never modified; when ``box_noise_scale <= 0`` the input
        tensor is returned unchanged (and unclamped).
    """
    if box_noise_scale <= 0:
        return boxes

    half_wh = boxes[:, 2:] / 2
    # Max displacement per coordinate: (w/2, h/2) for the center, (w, h) for the size.
    diff = torch.cat([half_wh, boxes[:, 2:]], dim=1)
    noise = torch.mul((torch.rand_like(boxes) * 2 - 1.0), diff) * box_noise_scale
    # Out-of-place add: avoids mutating the caller's tensor and works on
    # leaf tensors that require grad.
    return (boxes + noise).clamp(min=0.0, max=1.0)
class GenerateDNQueries(nn.Module):
    """Generate denoising queries for DN-DETR.

    Ground-truth labels and boxes are repeated ``denoising_groups`` times,
    noised, and written into zero-padded per-image query tensors, together
    with an attention mask that keeps the matching queries from attending to
    the denoising queries and keeps each denoising group from attending to
    the other groups.

    Args:
        num_queries (int): Number of matching queries, placed after the
            denoising queries in the attention mask. Default: 300.
        num_classes (int): Number of total categories. Default: 80.
        label_embed_dim (int): The embedding dimension for label encoding. Default: 256.
        denoising_groups (int): Number of noised ground truth groups. Default: 5.
        label_noise_prob (float): The probability of the label being noised. Default: 0.2.
        box_noise_scale (float): Scaling factor for box noising. Default: 0.4.
        with_indicator (bool): If True, the label embedding uses
            ``label_embed_dim - 1`` channels and a constant-1 indicator channel
            is appended to every noised label query. Default: False.
    """

    def __init__(
        self,
        num_queries: int = 300,
        num_classes: int = 80,
        label_embed_dim: int = 256,
        denoising_groups: int = 5,
        label_noise_prob: float = 0.2,
        box_noise_scale: float = 0.4,
        with_indicator: bool = False,
    ):
        super(GenerateDNQueries, self).__init__()
        self.num_queries = num_queries
        self.num_classes = num_classes
        self.label_embed_dim = label_embed_dim
        self.denoising_groups = denoising_groups
        self.label_noise_prob = label_noise_prob
        self.box_noise_scale = box_noise_scale
        self.with_indicator = with_indicator

        # leave one dim for the indicator mentioned in DN-DETR
        if with_indicator:
            self.label_encoder = nn.Embedding(num_classes, label_embed_dim - 1)
        else:
            self.label_encoder = nn.Embedding(num_classes, label_embed_dim)

    def generate_query_masks(self, max_gt_num_per_image, device):
        """Build the self-attention mask for denoising + matching queries.

        Args:
            max_gt_num_per_image (int): Number of query slots per denoising group.
            device: Device on which the mask is allocated.

        Returns:
            torch.Tensor: Bool tensor of shape ``(tgt_size, tgt_size)`` with
            ``tgt_size = max_gt_num_per_image * denoising_groups + num_queries``.
            ``True`` marks a blocked (query row, key column) pair.
        """
        noised_query_nums = max_gt_num_per_image * self.denoising_groups
        tgt_size = noised_query_nums + self.num_queries
        attn_mask = torch.zeros(tgt_size, tgt_size, dtype=torch.bool, device=device)

        # matching queries cannot see the reconstruction (denoising) queries
        attn_mask[noised_query_nums:, :noised_query_nums] = True

        # each denoising group cannot see any other denoising group
        for i in range(self.denoising_groups):
            start = max_gt_num_per_image * i
            end = max_gt_num_per_image * (i + 1)
            attn_mask[start:end, :start] = True
            attn_mask[start:end, end:noised_query_nums] = True
        return attn_mask

    def forward(
        self,
        gt_labels_list,
        gt_boxes_list,
    ):
        """
        Args:
            gt_labels_list (list[torch.Tensor]): Classification labels per image
                in shape ``(num_gt, )``.
            gt_boxes_list (list[torch.Tensor]): Ground truth bounding boxes per image
                with normalized coordinates in format ``(x_c, y_c, w, h)`` in
                shape ``(num_gts, 4)``.

        Returns:
            tuple: ``(noised_label_queries, noised_box_queries, attn_mask,
            denoising_groups, max_gt_num_per_image)`` where the label queries
            have shape ``(batch_size, max_gt_num_per_image * denoising_groups,
            label_embed_dim)``, the box queries have shape
            ``(batch_size, max_gt_num_per_image * denoising_groups, 4)`` and hold
            ``inverse_sigmoid`` of the noised boxes, and padding slots are zero.
        """
        assert len(gt_labels_list) == len(gt_boxes_list)
        batch_size = len(gt_labels_list)

        # concat ground truth labels and boxes in one batch
        # e.g. [tensor([0, 1, 2]), tensor([2, 3, 4])] -> tensor([0, 1, 2, 2, 3, 4])
        gt_labels = torch.cat(gt_labels_list)
        gt_boxes = torch.cat(gt_boxes_list)

        # Repeat the ground truth once per denoising group, e.g. with 2 groups:
        # tensor([0, 1, 2, 2, 3, 4]) -> tensor([0, 1, 2, 2, 3, 4, 0, 1, 2, 2, 3, 4])
        gt_labels = gt_labels.repeat(self.denoising_groups, 1).flatten()
        gt_boxes = gt_boxes.repeat(self.denoising_groups, 1)

        device = gt_labels.device

        # number of ground truths per image, e.g. [2, 3]
        gt_nums_per_image = [x.numel() for x in gt_labels_list]

        # Add noise on labels and boxes
        noised_labels = apply_label_noise(gt_labels, self.label_noise_prob, self.num_classes)
        noised_boxes = apply_box_noise(gt_boxes, self.box_noise_scale)
        noised_boxes = inverse_sigmoid(noised_boxes)

        # encoding labels
        label_embedding = self.label_encoder(noised_labels)
        query_num = label_embedding.shape[0]

        if self.with_indicator:
            # Match the embedding dtype so half-precision models stay in half.
            indicator = torch.ones(query_num, 1, dtype=label_embedding.dtype, device=device)
            label_embedding = torch.cat([label_embedding, indicator], 1)

        # Each denoising group reserves one slot per ground truth of the
        # largest image; smaller images leave zero padding in their slots.
        max_gt_num_per_image = max(gt_nums_per_image)
        noised_query_nums = max_gt_num_per_image * self.denoising_groups

        # Buffers use the source dtypes: index-put requires source and
        # destination dtypes to match (a fixed float32 buffer breaks fp16/fp64).
        noised_label_queries = torch.zeros(
            batch_size,
            noised_query_nums,
            self.label_embed_dim,
            dtype=label_embedding.dtype,
            device=device,
        )
        noised_box_queries = torch.zeros(
            batch_size, noised_query_nums, 4, dtype=noised_boxes.dtype, device=device
        )

        # Image index of every flattened instance, e.g. gt_nums [2, 3] -> [0, 0, 1, 1, 1],
        # then repeated once per group to line up with the repeated labels/boxes.
        batch_idx_per_instance = torch.repeat_interleave(
            torch.arange(batch_size), torch.tensor(gt_nums_per_image, dtype=torch.long)
        )
        batch_idx_per_group = batch_idx_per_instance.repeat(self.denoising_groups)

        # Slot index inside each image: 0..num_gt-1, offset by the group's
        # start (max_gt_num_per_image * group) for every denoising group.
        valid_index_per_instance = torch.cat(
            [torch.arange(num, dtype=torch.long) for num in gt_nums_per_image]
        )
        group_offsets = torch.arange(self.denoising_groups, dtype=torch.long) * max_gt_num_per_image
        valid_index_per_group = (
            valid_index_per_instance.unsqueeze(0) + group_offsets.unsqueeze(1)
        ).flatten()

        if batch_idx_per_group.numel():
            noised_label_queries[(batch_idx_per_group, valid_index_per_group)] = label_embedding
            noised_box_queries[(batch_idx_per_group, valid_index_per_group)] = noised_boxes

        # generate attention masks for transformer layers
        attn_mask = self.generate_query_masks(max_gt_num_per_image, device)

        return (
            noised_label_queries,
            noised_box_queries,
            attn_mask,
            self.denoising_groups,
            max_gt_num_per_image,
        )
class GenerateCDNQueries(nn.Module): def __init__( self, num_queries: int = 300, num_classes: int = 80, label_embed_dim: int = 256, denoising_nums: int = 100, label_noise_prob: float = 0.5, box_noise_scale: float = 1.0, ): super(GenerateCDNQueries, self).__init__() self.num_queries = num_queries self.num_classes = num_classes self.label_embed_dim = label_embed_dim self.denoising_nums = denoising_nums self.label_noise_prob = label_noise_prob self.box_noise_scale = box_noise_scale self.label_encoder = nn.Embedding(num_classes, label_embed_dim) def forward( self, gt_labels_list, gt_boxes_list, ): denoising_nums = self.denoising_nums * 2