# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/necks/channel_mapper.py
# ------------------------------------------------------------------------------------------------
import copy
from typing import Dict, List
import torch.nn as nn
from detrex.layers import ConvNormAct, ShapeSpec
[docs]class ChannelMapper(nn.Module):
"""Channel Mapper for reduce/increase channels of backbone features. Modified
from `mmdet <https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/necks/channel_mapper.py>`_.
This is used to reduce/increase the channels of backbone features.
Args:
input_shape (Dict[str, ShapeSpec]): A dict which contains the backbone features meta infomation,
e.g. ``input_shape = {"res5": ShapeSpec(channels=2048)}``.
in_features (List[str]): A list contains the keys which maps the features output from the backbone,
e.g. ``in_features = ["res"]``.
out_channels (int): Number of output channels for each scale.
kernel_size (int, optional): Size of the convolving kernel for each scale.
Default: 3.
stride (int, optional): Stride of convolution for each scale. Default: 1.
bias (bool, optional): If True, adds a learnable bias to the output of each scale.
Default: True.
groups (int, optional): Number of blocked connections from input channels to
output channels for each scale. Default: 1.
dilation (int, optional): Spacing between kernel elements for each scale.
Default: 1.
norm_layer (nn.Module, optional): The norm layer used for each scale. Default: None.
activation (nn.Module, optional): The activation layer used for each scale. Default: None.
num_outs (int, optional): Number of output feature maps. There will be ``extra_convs`` when
``num_outs`` is larger than the length of ``in_features``. Default: None.
Examples:
>>> import torch
>>> import torch.nn as nn
>>> from detrex.modeling import ChannelMapper
>>> from detectron2.modeling import ShapeSpec
>>> input_features = {
... "p0": torch.randn(1, 128, 128, 128),
... "p1": torch.randn(1, 256, 64, 64),
... "p2": torch.randn(1, 512, 32, 32),
... "p3": torch.randn(1, 1024, 16, 16),
... }
>>> input_shapes = {
... "p0": ShapeSpec(channels=128),
... "p1": ShapeSpec(channels=256),
... "p2": ShapeSpec(channels=512),
... "p3": ShapeSpec(channels=1024),
... }
>>> in_features = ["p0", "p1", "p2", "p3"]
>>> neck = ChannelMapper(
... input_shapes=input_shapes,
... in_features=in_features,
... out_channels=256,
... norm_layer=nn.GroupNorm(num_groups=32, num_channels=256)
>>> outputs = neck(input_features)
>>> for i in range(len(outputs)):
... print(f"output[{i}].shape = {outputs[i].shape}")
output[0].shape = torch.Size([1, 256, 128, 128])
output[1].shape = torch.Size([1, 256, 64, 64])
output[2].shape = torch.Size([1, 256, 32, 32])
output[3].shape = torch.Size([1, 256, 16, 16])
"""
def __init__(
self,
input_shapes: Dict[str, ShapeSpec],
in_features: List[str],
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
bias: bool = True,
groups: int = 1,
dilation: int = 1,
norm_layer: nn.Module = None,
activation: nn.Module = None,
num_outs: int = None,
**kwargs,
):
super(ChannelMapper, self).__init__()
self.extra_convs = None
in_channels_per_feature = [input_shapes[f].channels for f in in_features]
if num_outs is None:
num_outs = len(input_shapes)
self.convs = nn.ModuleList()
for in_channel in in_channels_per_feature:
self.convs.append(
ConvNormAct(
in_channels=in_channel,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
bias=bias,
groups=groups,
dilation=dilation,
norm_layer=copy.deepcopy(norm_layer),
activation=copy.deepcopy(activation),
)
)
if num_outs > len(in_channels_per_feature):
self.extra_convs = nn.ModuleList()
for i in range(len(in_channels_per_feature), num_outs):
if i == len(in_channels_per_feature):
in_channel = in_channels_per_feature[-1]
else:
in_channel = out_channels
self.extra_convs.append(
ConvNormAct(
in_channels=in_channel,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1,
bias=bias,
groups=groups,
dilation=dilation,
norm_layer=copy.deepcopy(norm_layer),
activation=copy.deepcopy(activation),
)
)
self.input_shapes = input_shapes
self.in_features = in_features
self.out_channels = out_channels
[docs] def forward(self, inputs):
"""Forward function for ChannelMapper
Args:
inputs (Dict[str, torch.Tensor]): The backbone feature maps.
Return:
tuple(torch.Tensor): A tuple of the processed features.
"""
assert len(inputs) == len(self.convs)
outs = [self.convs[i](inputs[self.in_features[i]]) for i in range(len(inputs))]
if self.extra_convs:
for i in range(len(self.extra_convs)):
if i == 0:
outs.append(self.extra_convs[0](inputs[self.in_features[-1]]))
else:
outs.append(self.extra_convs[i](outs[-1]))
return tuple(outs)