Source code for detrex.data.detr_dataset_mapper

# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All Rights Reserved
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/facebookresearch/detr/blob/main/d2/detr/dataset_mapper.py
# ------------------------------------------------------------------------------------------------

import copy
import logging
import numpy as np
import torch

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T

__all__ = ["DetrDatasetMapper"]


class DetrDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into the format used by DETR.

    The callable currently does the following:

    1. Reads the image from "file_name".
    2. Applies geometric transforms to the image and annotations.
    3. Finds and applies suitable cropping to the image and annotations.
    4. Prepares the image and annotations as Tensors.

    Args:
        augmentation (list[detectron2.data.transforms.Augmentation]): The geometric
            transforms for the input raw image and annotations.
        augmentation_with_crop (list[detectron2.data.transforms.Augmentation]): The
            geometric transforms with crop.
        is_train (bool): Whether to load the train set or the val set. Default: True.
        mask_on (bool): Whether to return the mask annotations. Default: False.
        img_format (str): The format of the input raw images. Default: RGB.

    Because detectron2 does not implement the `RandomSelect` augmentation used by DETR,
    we provide both `augmentation` and `augmentation_with_crop` here and randomly apply
    one of them to the input raw images.
    """

    def __init__(
        self,
        augmentation,
        augmentation_with_crop,
        is_train=True,
        mask_on=False,
        img_format="RGB",
    ):
        self.mask_on = mask_on
        self.augmentation = augmentation
        self.augmentation_with_crop = augmentation_with_crop
        logging.getLogger(__name__).info(
            "Full TransformGens used in training: {}, crop: {}".format(
                str(self.augmentation), str(self.augmentation_with_crop)
            )
        )

        self.img_format = img_format
        self.is_train = is_train

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        # Randomly pick one of the two pipelines (a stand-in for DETR's `RandomSelect`)
        # when a cropping pipeline is provided.
        if self.augmentation_with_crop is None:
            image, transforms = T.apply_transform_gens(self.augmentation, image)
        else:
            if np.random.rand() > 0.5:
                image, transforms = T.apply_transform_gens(self.augmentation, image)
            else:
                image, transforms = T.apply_transform_gens(self.augmentation_with_crop, image)

        image_shape = image.shape[:2]  # h, w

        # PyTorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                if not self.mask_on:
                    anno.pop("segmentation", None)
                anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(obj, transforms, image_shape)
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = utils.annotations_to_instances(annos, image_shape)
            dataset_dict["instances"] = utils.filter_empty_instances(instances)
        return dataset_dict
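

# ------------------------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module): how a DetrDatasetMapper is
# typically constructed with DETR-style train-time augmentations. The scale and crop values
# below are assumptions chosen to mirror the original DETR recipe, not values taken from the
# detrex configs.
# ------------------------------------------------------------------------------------------------
def _example_build_mapper():
    """Hypothetical helper: builds a train-time DetrDatasetMapper."""
    # Multi-scale resize plus horizontal flip (assumed DETR-style scales).
    augmentation = [
        T.RandomFlip(),
        T.ResizeShortestEdge(
            [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800],
            max_size=1333,
            sample_style="choice",
        ),
    ]
    # Same pipeline with an intermediate resize and random crop; when both lists
    # are given, __call__ picks one of them with probability 0.5 per image.
    augmentation_with_crop = [
        T.RandomFlip(),
        T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
        T.RandomCrop("absolute_range", (384, 600)),
        T.ResizeShortestEdge(
            [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800],
            max_size=1333,
            sample_style="choice",
        ),
    ]
    return DetrDatasetMapper(
        augmentation=augmentation,
        augmentation_with_crop=augmentation_with_crop,
        is_train=True,
        mask_on=False,
        img_format="RGB",
    )


# A mapper built this way can be passed to detectron2's build_detection_train_loader via its
# `mapper=` argument; calling it on one Detectron2-format dataset dict returns a dict with a
# CHW image tensor under "image" and the transformed boxes under "instances".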