Source code for detrex.layers.transformer

# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/bricks/transformer.py
# ------------------------------------------------------------------------------------------------

import copy
import warnings
from typing import List
import torch
import torch.nn as nn


class BaseTransformerLayer(nn.Module):
    # TODO: add more tutorials about BaseTransformerLayer
    """The implementation of the base `TransformerLayer` used in Transformer. Modified
    from `mmcv <https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/bricks/transformer.py>`_.

    It can be built by directly passing the `Attention`, `FFN` and `Norm` modules,
    which supports more flexible customization combined with the `LazyConfig` system.
    `BaseTransformerLayer` also supports `prenorm` when you specify `norm` as the
    first element of `operation_order`. More details about `prenorm`:
    `On Layer Normalization in the Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn (list[nn.Module] | nn.Module): nn.Module or a list containing the
            attention modules used in the TransformerLayer.
        ffn (nn.Module): FFN module used in the TransformerLayer.
        norm (nn.Module): Normalization layer used in the TransformerLayer.
        operation_order (tuple[str]): The execution order of operations in the
            transformer, such as ('self_attn', 'norm', 'ffn', 'norm'). `prenorm`
            is supported when the first element is `norm`. Default: None.
    """

    def __init__(
        self,
        attn: List[nn.Module],
        ffn: nn.Module,
        norm: nn.Module,
        operation_order: tuple = None,
    ):
        super(BaseTransformerLayer, self).__init__()
        assert set(operation_order).issubset({"self_attn", "norm", "cross_attn", "ffn"})

        # count attention nums
        num_attn = operation_order.count("self_attn") + operation_order.count("cross_attn")

        if isinstance(attn, nn.Module):
            attn = [copy.deepcopy(attn) for _ in range(num_attn)]
        else:
            assert len(attn) == num_attn, (
                f"The length of attn (nn.Module or List[nn.Module]) {len(attn)} "
                f"is not consistent with the number of attention modules in "
                f"operation_order {operation_order}"
            )

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.pre_norm = operation_order[0] == "norm"
        self.attentions = nn.ModuleList()
        index = 0
        for operation_name in operation_order:
            if operation_name in ["self_attn", "cross_attn"]:
                self.attentions.append(attn[index])
                index += 1

        self.embed_dim = self.attentions[0].embed_dim

        # count ffn nums
        self.ffns = nn.ModuleList()
        num_ffns = operation_order.count("ffn")
        for _ in range(num_ffns):
            self.ffns.append(copy.deepcopy(ffn))

        # count norm nums
        self.norms = nn.ModuleList()
        num_norms = operation_order.count("norm")
        for _ in range(num_norms):
            self.norms.append(copy.deepcopy(norm))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor = None,
        value: torch.Tensor = None,
        query_pos: torch.Tensor = None,
        key_pos: torch.Tensor = None,
        attn_masks: List[torch.Tensor] = None,
        query_key_padding_mask: torch.Tensor = None,
        key_padding_mask: torch.Tensor = None,
        **kwargs,
    ):
        """Forward function for `BaseTransformerLayer`.

        **kwargs contains the specific arguments of attentions.

        Args:
            query (torch.Tensor): Query embeddings with shape
                `(num_query, bs, embed_dim)` or `(bs, num_query, embed_dim)`,
                depending on the attention module used in `BaseTransformerLayer`.
            key (torch.Tensor): Key embeddings used in `Attention`.
            value (torch.Tensor): Value embeddings with the same shape as `key`.
            query_pos (torch.Tensor): The position embedding for `query`.
                Default: None.
            key_pos (torch.Tensor): The position embedding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): A list of 2D ByteTensors used in
                calculating the corresponding attention. The length of
                `attn_masks` should be equal to the number of attention
                operations in `operation_order`. Default: None.
            query_key_padding_mask (torch.Tensor): ByteTensor for `query`, with
                shape `(bs, num_query)`. Only used in the `self_attn` layer.
                Default: None.
            key_padding_mask (torch.Tensor): ByteTensor for `key`, with shape
                `(bs, num_key)`. Default: None.
        """
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [copy.deepcopy(attn_masks) for _ in range(self.num_attn)]
            warnings.warn(
                f"Use the same attn_mask in all attentions in {self.__class__.__name__}"
            )
        else:
            assert len(attn_masks) == self.num_attn, (
                f"The length of attn_masks {len(attn_masks)} must be equal "
                f"to the number of attention operations in "
                f"operation_order {self.num_attn}"
            )

        for layer in self.operation_order:
            if layer == "self_attn":
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs,
                )
                attn_index += 1
                identity = query

            elif layer == "norm":
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == "cross_attn":
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs,
                )
                attn_index += 1
                identity = query

            elif layer == "ffn":
                query = self.ffns[ffn_index](query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
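
A minimal usage sketch of `BaseTransformerLayer` (not part of the original module). It builds a post-norm encoder-style layer from small wrapper modules whose call signatures match what the forward function above expects: the attention module needs an `embed_dim` attribute and is called as `attn(query, key, value, identity, query_pos=..., key_pos=..., attn_mask=..., key_padding_mask=...)`, and the FFN is called as `ffn(query, identity)`. `SimpleAttention` and `SimpleFFN` are hypothetical stand-ins for the attention/FFN modules shipped with detrex.

    import torch
    import torch.nn as nn


    class SimpleAttention(nn.Module):
        """Hypothetical attention wrapper matching the call signature used in
        BaseTransformerLayer.forward (residual connection applied inside)."""

        def __init__(self, embed_dim, num_heads):
            super().__init__()
            self.embed_dim = embed_dim
            self.attn = nn.MultiheadAttention(embed_dim, num_heads)

        def forward(self, query, key, value, identity=None, query_pos=None,
                    key_pos=None, attn_mask=None, key_padding_mask=None, **kwargs):
            if identity is None:
                identity = query
            q = query if query_pos is None else query + query_pos
            k = key if key_pos is None else key + key_pos
            out, _ = self.attn(q, k, value, attn_mask=attn_mask,
                               key_padding_mask=key_padding_mask)
            return identity + out


    class SimpleFFN(nn.Module):
        """Hypothetical FFN wrapper called as ffn(query, identity)."""

        def __init__(self, embed_dim, feedforward_dim):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(embed_dim, feedforward_dim),
                nn.ReLU(inplace=True),
                nn.Linear(feedforward_dim, embed_dim),
            )

        def forward(self, x, identity=None):
            if identity is None:
                identity = x
            return identity + self.net(x)


    # Post-norm encoder-style layer: self-attention -> norm -> FFN -> norm.
    layer = BaseTransformerLayer(
        attn=SimpleAttention(embed_dim=256, num_heads=8),
        ffn=SimpleFFN(embed_dim=256, feedforward_dim=1024),
        norm=nn.LayerNorm(256),
        operation_order=("self_attn", "norm", "ffn", "norm"),
    )

    query = torch.randn(100, 2, 256)      # (num_query, bs, embed_dim)
    query_pos = torch.randn(100, 2, 256)  # positional embeddings for the queries
    out = layer(query, query_pos=query_pos)  # same shape as query
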

class TransformerLayerSequence(nn.Module):
    """Base class for TransformerEncoder and TransformerDecoder, which will copy
    the passed `transformer_layers` module `num_layers` times, or save the passed
    list of `transformer_layers` as a parameter named ``self.layers``, which is of
    type ``nn.ModuleList``.
    Users should inherit from `TransformerLayerSequence` and implement their own
    forward function.

    Args:
        transformer_layers (list[BaseTransformerLayer] | BaseTransformerLayer):
            A list of BaseTransformerLayer. If it is an obj:`BaseTransformerLayer`,
            it will be repeated `num_layers` times to form a
            list[BaseTransformerLayer].
        num_layers (int): The number of `TransformerLayer`. Default: None.
    """

    def __init__(
        self,
        transformer_layers=None,
        num_layers=None,
    ):
        super(TransformerLayerSequence, self).__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        if isinstance(transformer_layers, nn.Module):
            for _ in range(num_layers):
                self.layers.append(copy.deepcopy(transformer_layers))
        else:
            assert isinstance(transformer_layers, list) and len(transformer_layers) == num_layers
            # register the provided layers as sub-modules, matching the
            # documented behavior above
            self.layers = nn.ModuleList(transformer_layers)

    def forward(self):
        """Forward function of `TransformerLayerSequence`. Users should inherit from
        `TransformerLayerSequence` and implement their own forward function.
        """
        raise NotImplementedError()
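
A sketch of how `TransformerLayerSequence` is typically subclassed (not part of the original module): the subclass supplies the forward function and simply applies the stacked layers in order, forwarding the same arguments that `BaseTransformerLayer.forward` accepts. `SimpleTransformerEncoder` is a hypothetical name; `layer`, `query` and `query_pos` reuse the objects built in the sketch above.

    class SimpleTransformerEncoder(TransformerLayerSequence):
        """Hypothetical encoder: applies the stacked layers sequentially."""

        def forward(
            self,
            query,
            key=None,
            value=None,
            query_pos=None,
            key_pos=None,
            attn_masks=None,
            query_key_padding_mask=None,
            key_padding_mask=None,
            **kwargs,
        ):
            for layer in self.layers:
                query = layer(
                    query,
                    key,
                    value,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_masks=attn_masks,
                    query_key_padding_mask=query_key_padding_mask,
                    key_padding_mask=key_padding_mask,
                    **kwargs,
                )
            return query


    # The single layer is deep-copied ``num_layers`` times inside
    # TransformerLayerSequence.__init__.
    encoder = SimpleTransformerEncoder(transformer_layers=layer, num_layers=6)
    memory = encoder(query, query_pos=query_pos)  # (num_query, bs, embed_dim)
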