Source code for detrex.modeling.backbone.timm_backbone

# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------------------------
# Support TIMM Backbone
# Modified from:
# https://github.com/open-mmlab/mmclassification/blob/master/mmcls/models/backbones/timm_backbone.py
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/backbone.py
# ------------------------------------------------------------------------------------------------

import warnings
from typing import Tuple
import torch.nn as nn

from detectron2.modeling.backbone import Backbone
from detectron2.utils import comm
from detectron2.utils.logger import setup_logger

try:
    import timm
except ImportError:
    timm = None


def log_timm_feature_info(feature_info):
    """Log the ``feature_info`` of a timm backbone as a development/debug aid.

    Three input shapes are handled:

    * ``None``: a warning is logged that the backbone has no feature info.
    * a ``list``: every entry is logged together with its position.
    * anything else: the ``out_indices`` attribute and the results of
      ``channels()`` and ``reduction()`` are logged; if any of those is
      missing, a warning about the unexpected format is logged instead.

    Args:
        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
            feature_info of timm backbone.
    """
    timm_logger = setup_logger(name="timm backbone")

    # Guard clause: nothing to describe.
    if feature_info is None:
        timm_logger.warning("This backbone does not have feature_info")
        return

    # Plain list form: dump each entry as-is.
    if isinstance(feature_info, list):
        for position, entry in enumerate(feature_info):
            timm_logger.info(f"backbone feature_info[{position}]: {entry}")
        return

    # Object form: each lookup is immediately followed by its log call so
    # that a failure part-way through still leaves the earlier lines logged.
    try:
        selected = feature_info.out_indices
        timm_logger.info(f"backbone out_indices: {selected}")
        widths = feature_info.channels()
        timm_logger.info(f"backbone out_channels: {widths}")
        strides = feature_info.reduction()
        timm_logger.info(f"backbone out_strides: {strides}")
    except AttributeError:
        timm_logger.warning("Unexpected format of backbone feature_info")


class TimmBackbone(Backbone):
    """A wrapper for using backbone from timm library.

    Please see the document for `feature extraction with timm
    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_
    for more details.

    The forward pass returns a dict whose keys are ``"p{index}"`` for every
    index in ``out_indices``, e.g. ``out_indices=(0, 1)`` yields
    ``{"p0": feature from index 0, "p1": feature from index 1}``.

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to extract feature pyramid (multi-scale
            feature maps from the deepest layer of each stage). Default: True.
        pretrained (bool): Whether to load pretrained weights. Must be a
            ``bool``; a checkpoint path is not accepted here. Default: False.
        checkpoint_path (str): Checkpoint path forwarded to
            ``timm.create_model``. Default: "" (no checkpoint).
        in_channels (int): The number of input channels. Default: 3.
        out_indices (tuple[int]): The extracted feature indices which select
            specific feature levels or limit the stride of the feature
            extractor. Default: (0, 1, 2, 3).
        norm_layer (nn.Module): Set the specified norm layer for feature
            extractor, e.g., set ``norm_layer=FrozenBatchNorm2d`` to freeze
            the norm layer in feature extractor. Default: None.

    Raises:
        RuntimeError: If timm could not be imported.
        TypeError: If ``pretrained`` is not a ``bool``.
        AttributeError: If model creation failed with an error mentioning
            ``feature_info``.
        ValueError: If model creation failed with an error mentioning
            ``norm_layer``.
        Exception: Any other error raised while creating the model is logged
            and re-raised unchanged.
    """

    def __init__(
        self,
        model_name: str,
        features_only: bool = True,
        pretrained: bool = False,
        checkpoint_path: str = "",
        in_channels: int = 3,
        out_indices: Tuple[int, ...] = (0, 1, 2, 3),
        norm_layer: nn.Module = None,
    ):
        super().__init__()
        logger = setup_logger(name="timm backbone")

        if timm is None:
            raise RuntimeError('Failed to import timm. Please run "pip install timm". ')
        if not isinstance(pretrained, bool):
            raise TypeError("pretrained must be bool, not str for model path")
        if features_only and checkpoint_path:
            warnings.warn(
                "Using both features_only and checkpoint_path may cause error"
                " in timm. See "
                "https://github.com/rwightman/pytorch-image-models/issues/488"
            )

        try:
            self.timm_model = timm.create_model(
                model_name=model_name,
                features_only=features_only,
                pretrained=pretrained,
                in_chans=in_channels,
                out_indices=out_indices,
                checkpoint_path=checkpoint_path,
                norm_layer=norm_layer,
            )
        except Exception as error:
            # Known failure modes are translated into clearer errors; the
            # original exception is chained so its traceback is not lost.
            if "feature_info" in str(error):
                raise AttributeError(
                    "Using features_only may cause attribute error"
                    " in timm, cause there's no feature_info attribute in some models. See "
                    "https://github.com/rwightman/pytorch-image-models/issues/1438"
                ) from error
            if "norm_layer" in str(error):
                raise ValueError(
                    f"{model_name} does not support specified norm layer, please set 'norm_layer=None'"
                ) from error
            # Anything else is re-raised rather than calling ``exit()``, which
            # would terminate the process with a success status and hide the
            # real traceback from the caller.
            logger.error(f"Failed to create timm model '{model_name}': {error}")
            raise

        self.out_indices = out_indices

        feature_info = getattr(self.timm_model, "feature_info", None)
        # Only the main process logs, to avoid duplicated output per rank.
        if comm.get_rank() == 0:
            log_timm_feature_info(feature_info)

        # NOTE: when the model exposes no ``feature_info`` the output shape
        # attributes below are left unset.
        if feature_info is not None:
            feature_names = [f"p{index}" for index in out_indices]
            # Queried once; position i is paired with out_indices[i].
            channels = feature_info.channels()
            strides = feature_info.reduction()

            # An ordered, de-duplicated list keeps iteration order
            # deterministic (a set of strings iterates in hash order, which
            # varies between interpreter runs).
            self._out_features = list(dict.fromkeys(feature_names))
            self._out_feature_channels = {
                name: channels[i] for i, name in enumerate(feature_names)
            }
            self._out_feature_strides = {
                name: strides[i] for i, name in enumerate(feature_names)
            }

    def forward(self, x):
        """Forward function of `TimmBackbone`.

        Args:
            x (torch.Tensor): the input tensor for feature extraction.

        Returns:
            dict[str->Tensor]: mapping from feature name (e.g., "p1") to tensor
        """
        features = self.timm_model(x)
        # The i-th model output is stored under the name of the i-th
        # requested index.
        return {f"p{index}": features[i] for i, index in enumerate(self.out_indices)}