Plug-and-Play CrossFormer: Easy Accuracy Gains!


Paper Introduction

Title: CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention

Paper link: https://arxiv.org/pdf/2108.00154


Key Innovations

  • Cross-scale Embedding Layer (CEL)

    • A new embedding scheme fuses features of different scales, so that every embedding carries cross-scale information. This addresses the inability of existing vision Transformers to model interactions between features at different scales.

    • Features are extracted with convolution kernels of several sizes and concatenated, providing finer-grained cross-scale representations that are especially useful for visual tasks containing objects at multiple scales (a minimal sketch is given right after this list).

  • Long Short Distance Attention (LSDA)

    • The standard self-attention module is split into short distance attention (SDA) and long distance attention (LDA), which model dependencies between neighboring and distant features, respectively.

    • Computational cost is reduced while both fine-grained and coarse-grained features are preserved, enabling cross-scale attention interactions.

  • Dynamic Position Bias (DPB)

    • A dynamic position bias module replaces the conventional fixed position bias, so the model can adapt to inputs of different sizes and to different group sizes.

    • The position bias becomes flexible without giving up the benefits of relative position representations.

  • Generality and performance

    • A single unified Transformer architecture (CrossFormer) handles not only image classification but also object detection, instance segmentation, and semantic segmentation. Across these tasks CrossFormer outperforms existing vision Transformer architectures, with the largest margins on dense prediction tasks such as detection and segmentation.

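To make CEL concrete, here is a minimal sketch (our own simplified example, not the authors' code; the PatchEmbed class in the full listing below is the real implementation): convolutions with different kernel sizes but a shared stride sample the input, and their outputs are concatenated along the channel dimension. Kernel sizes and the channel split follow the Stage-1 setting described in the Method section; the four-branch split is assumed.

      import torch
      import torch.nn as nn

      class TinyCEL(nn.Module):
          """Minimal cross-scale embedding sketch: four kernel sizes, one shared stride."""
          def __init__(self, in_chans=3, embed_dim=96, kernel_sizes=(4, 8, 16, 32), stride=4):
              super().__init__()
              # Channel split across the four branches (the finest kernel gets the largest share),
              # mirroring the allocation used in the PatchEmbed class further below.
              dims = [embed_dim // 2, embed_dim // 4, embed_dim // 8, embed_dim // 8]
              self.projs = nn.ModuleList([
                  nn.Conv2d(in_chans, d, kernel_size=k, stride=stride, padding=(k - stride) // 2)
                  for k, d in zip(kernel_sizes, dims)
              ])

          def forward(self, x):                            # x: (B, 3, H, W)
              feats = [proj(x) for proj in self.projs]     # each: (B, d_i, H/4, W/4)
              return torch.cat(feats, dim=1)               # (B, embed_dim, H/4, W/4)

      y = TinyCEL()(torch.randn(1, 3, 224, 224))
      print(y.shape)                                       # torch.Size([1, 96, 56, 56])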
Method

Overall Architecture

The CrossFormer model adopts a pyramid-style, multi-stage Transformer architecture built from the cross-scale embedding layer (CEL), long short distance attention (LSDA), and dynamic position bias (DPB). CEL extracts and fuses features with multi-scale convolution kernels, LSDA alternately models local and global dependencies, and DPB makes the position representation flexible. Starting from the input image, the model progressively extracts cross-scale features: the resolution drops and the feature dimension grows from stage to stage, and a classification head finally produces the output. The design combines multi-scale feature extraction, efficient computation, and task generality.

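As a quick orientation, the snippet below traces the token-grid resolution and embedding dimension through the four stages, assuming the CrossFormer class defaults from the listing below (224×224 input, embed_dim=96, stride-4 embedding in Stage-1, stride-2 downsampling afterwards); it is only an illustration of the pyramid, not part of the original code.

      # Pyramid shape trace (assumed defaults: 224x224 input, embed_dim=96).
      resolution, dim = 224 // 4, 96                       # Stage-1: stride-4 cross-scale embedding
      for stage in range(1, 5):
          print(f"Stage-{stage}: {resolution}x{resolution} tokens, dim={dim}")
          resolution, dim = resolution // 2, dim * 2       # stride-2 merging halves the grid, doubles the dim
      # Stage-1: 56x56 tokens, dim=96   ...   Stage-4: 7x7 tokens, dim=768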
  1. Cross-scale Embedding Layer (CEL)

    • At the start of each stage, a cross-scale embedding layer (CEL) generates the input embeddings.

    • The CEL samples the input with convolution kernels of several sizes and concatenates the resulting features into a single cross-scale embedding vector.

    • In Stage-1 the CEL extracts cross-scale features from the input image; in later stages it operates on the output of the previous stage.

  2. CrossFormer Block

    • Each stage contains several CrossFormer blocks (Figure 1(b) of the paper shows the internal structure).

    • A CrossFormer block is built from two key components:

      • Long Short Distance Attention (LSDA): short distance attention (SDA) and long distance attention (LDA) alternate across blocks to capture local and global dependencies (a grouping illustration follows the pyramid details below).

      • Dynamic Position Bias (DPB): provides relative position information for the embeddings and works with inputs of different sizes.

  3. Classification Head

    • After the last stage, global average pooling followed by a fully connected layer produces the classification output.

Pyramid structure details

  • The resolution decreases while the feature dimension increases from stage to stage:

    • Stage-1: highest resolution and lowest feature dimension, used to extract fine-grained features.

    • In deeper stages the resolution shrinks and the feature dimension grows, forming the pyramid.

  • CEL kernel sizes and strides:

    • Stage-1 uses multi-scale kernels (4×4, 8×8, 16×16, 32×32) with stride 4.

    • Stages 2, 3, and 4 use smaller kernels (2×2 and 4×4) with stride 2 to halve the resolution.

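The grouping behind SDA and LDA can be illustrated with a few lines of tensor reshaping. The toy example below is our own illustration (it uses the same reshape/permute trick as the CrossFormerBlock class in the listing further down): it prints which spatial positions of an 8×8 feature map end up in the same attention group. SDA groups G×G adjacent positions, while LDA groups positions sampled at a fixed interval of H // G along each axis, so one group spans the whole map.

      import torch

      H = W = 8                                            # toy feature-map size
      G = 4                                                # group size
      idx = torch.arange(H * W).view(1, H, W, 1)           # "tokens" are just their spatial index

      # SDA: adjacent G x G windows
      sda = idx.reshape(1, H // G, G, W // G, G, 1).permute(0, 1, 3, 2, 4, 5).reshape(-1, G * G)
      print(sda[0])   # tensor([ 0,  1,  2,  3,  8,  9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27])

      # LDA: positions spaced H // G apart along each axis
      lda = idx.reshape(1, G, H // G, G, W // G, 1).permute(0, 2, 4, 1, 3, 5).reshape(-1, G * G)
      print(lda[0])   # tensor([ 0,  2,  4,  6, 16, 18, 20, 22, 32, 34, 36, 38, 48, 50, 52, 54])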
Role as a plug-and-play module

CrossFormer as a plug-and-play module offers:

  • Cross-scale feature modeling

    • The cross-scale embedding layer (CEL) fuses features at different scales, helping the model build stronger dependencies across resolutions and granularities.

    • For tasks whose images contain objects at many scales, this markedly improves scene understanding.

  • Efficient capture of global and local features

    • Long short distance attention (LSDA) captures local detail efficiently while still building global context, suiting tasks that need a balance of local and global features.

  • Adapting to different input sizes

    • Dynamic position bias (DPB) provides a flexible position representation, so the model can handle different input sizes without structural changes (a small sketch follows this section).

  • Better performance across tasks

    • Gains are reported on object detection, segmentation, and classification, and are most pronounced on dense prediction tasks such as detection and segmentation.

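To see why DPB removes the need for a fixed-size bias table, the toy module below maps 2-D relative offsets to a per-head bias with a small MLP, so the same parameters serve any group size. This is a simplified, hypothetical sketch of the idea; the actual DynamicPosBias class in the listing below additionally uses LayerNorm between its linear layers.

      import torch
      import torch.nn as nn

      class TinyDPB(nn.Module):
          """Sketch of dynamic position bias: an MLP maps 2-D relative offsets to per-head biases."""
          def __init__(self, num_heads, hidden=16):
              super().__init__()
              self.mlp = nn.Sequential(
                  nn.Linear(2, hidden), nn.ReLU(inplace=True),
                  nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
                  nn.Linear(hidden, num_heads),
              )

          def forward(self, group_size):
              G_h, G_w = group_size
              # all relative offsets that can occur inside a G_h x G_w group
              dh = torch.arange(1 - G_h, G_h, dtype=torch.float32)
              dw = torch.arange(1 - G_w, G_w, dtype=torch.float32)
              offsets = torch.stack(torch.meshgrid(dh, dw, indexing="ij"), dim=-1).reshape(-1, 2)
              return self.mlp(offsets)                     # ((2*G_h-1)*(2*G_w-1), num_heads)

      dpb = TinyDPB(num_heads=3)
      print(dpb((7, 7)).shape)    # torch.Size([169, 3])
      print(dpb((5, 9)).shape)    # torch.Size([153, 3]) -- same module, different group size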
Ablation results

  • Cross-scale embedding strategies

    • Setup: compares different convolution-kernel combinations for the embedding layer, i.e., single-scale embeddings (only a 4×4 or 8×8 kernel) versus multi-scale embeddings (4×4, 8×8, 16×16, 32×32 combined).

    • Result: multi-scale embeddings raise accuracy noticeably (e.g., 81.5% with a single 4×4 kernel versus 82.5% with multi-scale embeddings), showing that cross-scale feature extraction is essential to the model.

  • CEL and LSDA

    • Setup: compares using only CEL, only LSDA, and both together, against the existing PVT (simple downsampling) and Swin (local attention).

    • Result: using CEL and LSDA together gives the highest accuracy (82.5%), clearly above CEL alone (81.5%) or LSDA alone (81.9%), indicating that the two components are complementary and together fully exploit cross-scale features.

Plug-and-play module code

      import torch
      import torch.nn as nn
      import torch.utils.checkpoint as checkpoint
      from timm.models.layers import DropPath, to_2tuple, trunc_normal_

      # Paper link: https://arxiv.org/pdf/2108.00154
      # Paper: CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention (ICLR 2022)
      class Mlp(nn.Module):
          def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
              super().__init__()
              out_features = out_features or in_features
              hidden_features = hidden_features or in_features
              self.fc1 = nn.Linear(in_features, hidden_features)
              self.act = act_layer()
              self.fc2 = nn.Linear(hidden_features, out_features)
              self.drop = nn.Dropout(drop)

          def forward(self, x):
              x = self.fc1(x)
              x = self.act(x)
              x = self.drop(x)
              x = self.fc2(x)
              x = self.drop(x)
              return x

      class DynamicPosBias(nn.Module):
          def __init__(self, dim, num_heads, residual):
              super().__init__()
              self.residual = residual
              self.num_heads = num_heads
              self.pos_dim = dim // 4
              self.pos_proj = nn.Linear(2, self.pos_dim)
              self.pos1 = nn.Sequential(
                  nn.LayerNorm(self.pos_dim),
                  nn.ReLU(inplace=True),
                  nn.Linear(self.pos_dim, self.pos_dim),
              )
              self.pos2 = nn.Sequential(
                  nn.LayerNorm(self.pos_dim),
                  nn.ReLU(inplace=True),
                  nn.Linear(self.pos_dim, self.pos_dim)
              )
              self.pos3 = nn.Sequential(
                  nn.LayerNorm(self.pos_dim),
                  nn.ReLU(inplace=True),
                  nn.Linear(self.pos_dim, self.num_heads)
              )
          def forward(self, biases):
              if self.residual:
                  pos = self.pos_proj(biases) # 2Wh-1 * 2Ww-1, heads
                  pos = pos + self.pos1(pos)
                  pos = pos + self.pos2(pos)
                  pos = self.pos3(pos)
              else:
                  pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
              return pos

          def flops(self, N):
              flops = N * 2 * self.pos_dim
              flops += N * self.pos_dim * self.pos_dim
              flops += N * self.pos_dim * self.pos_dim
              flops += N * self.pos_dim * self.num_heads
              return flops

      class Attention(nn.Module):
          r""" Multi-head self attention module with dynamic position bias.

          Args:
              dim (int): Number of input channels.
              group_size (tuple[int]): The height and width of the group.
              num_heads (int): Number of attention heads.
              qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
              qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
              attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
              proj_drop (float, optional): Dropout ratio of output. Default: 0.0
          """


          def __init__(self, dim, group_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                       position_bias=True):
              super().__init__()
              self.dim = dim
              self.group_size = group_size # Wh, Ww
              self.num_heads = num_heads
              head_dim = dim // num_heads
              self.scale = qk_scale or head_dim ** -0.5
              self.position_bias = position_bias

              if position_bias:
                  self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
                  
                  # generate mother-set
                  position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0])
                  position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1])
                  biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Wh-1, 2Ww-1
                  biases = biases.flatten(1).transpose(0, 1).float()
                  self.register_buffer("biases", biases)

                  # get pair-wise relative position index for each token inside the group
                  coords_h = torch.arange(self.group_size[0])
                  coords_w = torch.arange(self.group_size[1])
                  coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
                  coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
                  relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
                  relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
                  relative_coords[:, :, 0] += self.group_size[0] - 1  # shift to start from 0
                  relative_coords[:, :, 1] += self.group_size[1] - 1
                  relative_coords[:, :, 0] *= 2 * self.group_size[1] - 1
                  relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
                  self.register_buffer("relative_position_index", relative_position_index)

              self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
              self.attn_drop = nn.Dropout(attn_drop)
              self.proj = nn.Linear(dim, dim)
              self.proj_drop = nn.Dropout(proj_drop)

              self.softmax = nn.Softmax(dim=-1)

          def forward(self, x, mask=None):
              """
              Args:
                  x: input features with shape of (num_groups*B, N, C)
                  mask: (0/-inf) mask with shape of (num_groups, Wh*Ww, Wh*Ww) or None
              """

              B_, N, C = x.shape
              qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
              q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)

              q = q * self.scale
              attn = (q @ k.transpose(-2, -1))

              if self.position_bias:
                  pos = self.pos(self.biases) # 2Wh-1 * 2Ww-1, heads
                  # select position bias
                  relative_position_bias = pos[self.relative_position_index.view(-1)].view(
                      self.group_size[0] * self.group_size[1], self.group_size[0] * self.group_size[1], -1) # Wh*Ww,Wh*Ww,nH
                  relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
                  attn = attn + relative_position_bias.unsqueeze(0)

              if mask is not None:
                  nW = mask.shape[0]
                  attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                  attn = attn.view(-1, self.num_heads, N, N)
                  attn = self.softmax(attn)
              else:
                  attn = self.softmax(attn)

              attn = self.attn_drop(attn)

              x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
              x = self.proj(x)
              x = self.proj_drop(x)
              return x

          def extra_repr(self) -> str:
              return f'dim={self.dim}, group_size={self.group_size}, num_heads={self.num_heads}'

          def flops(self, N):
              # calculate flops for 1 group with token length of N
              flops = 0
              # qkv = self.qkv(x)
              flops += N * self.dim * 3 * self.dim
              # attn = (q @ k.transpose(-2, -1))
              flops += self.num_heads * N * (self.dim // self.num_heads) * N
              # x = (attn @ v)
              flops += self.num_heads * N * N * (self.dim // self.num_heads)
              # x = self.proj(x)
              flops += N * self.dim * self.dim
              if self.position_bias:
                  flops += self.pos.flops(N)
              return flops


      class CrossFormerBlock(nn.Module):
          r""" CrossFormer Block.

          Args:
              dim (int): Number of input channels.
              input_resolution (tuple[int]): Input resulotion.
              num_heads (int): Number of attention heads.
              group_size (int): Group size.
              lsda_flag (int): use SDA or LDA, 0 for SDA and 1 for LDA.
              mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
              qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
              qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
              drop (float, optional): Dropout rate. Default: 0.0
              attn_drop (float, optional): Attention dropout rate. Default: 0.0
              drop_path (float, optional): Stochastic depth rate. Default: 0.0
              act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
              norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
          """


          def __init__(self, dim, input_resolution, num_heads, group_size=7, lsda_flag=0,
                       mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                       act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_patch_size=1):
              super().__init__()
              self.dim = dim
              self.input_resolution = input_resolution
              self.num_heads = num_heads
              self.group_size = group_size
              self.lsda_flag = lsda_flag
              self.mlp_ratio = mlp_ratio
              self.num_patch_size = num_patch_size
              if min(self.input_resolution) <= self.group_size:
                  # if group size is larger than input resolution, we don't partition groups
                  self.lsda_flag = 0
                  self.group_size = min(self.input_resolution)

              self.norm1 = norm_layer(dim)

              self.attn = Attention(
                  dim, group_size=to_2tuple(self.group_size), num_heads=num_heads,
                  qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
                  position_bias=True)

              self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
              self.norm2 = norm_layer(dim)
              mlp_hidden_dim = int(dim * mlp_ratio)
              self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

              attn_mask = None
              self.register_buffer("attn_mask", attn_mask)

          def forward(self, x):
              H, W = self.input_resolution
              B, L, C = x.shape
              assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W)

              shortcut = x
              x = self.norm1(x)
              x = x.view(B, H, W, C)

              # group embeddings
              G = self.group_size
              if self.lsda_flag == 0: # 0 for SDA
                  x = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)
              else: # 1 for LDA
                  x = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5)
              x = x.reshape(B * H * W // G**2, G**2, C)

              # multi-head self-attention
              x = self.attn(x, mask=self.attn_mask) # nW*B, G*G, C

              # ungroup embeddings
              x = x.reshape(B, H // G, W // G, G, G, C)
              if self.lsda_flag == 0:
                  x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)
              else:
                  x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, H, W, C)
              x = x.view(B, H * W, C)

              # FFN
              x = shortcut + self.drop_path(x)
              x = x + self.drop_path(self.mlp(self.norm2(x)))

              return x

          def extra_repr(self) -> str:
              return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
                     f"group_size={self.group_size}, lsda_flag={self.lsda_flag}, mlp_ratio={self.mlp_ratio}"

          def flops(self):
              flops = 0
              H, W = self.input_resolution
              # norm1
              flops += self.dim * H * W
              # LSDA
              nW = H * W / self.group_size / self.group_size
              flops += nW * self.attn.flops(self.group_size * self.group_size)
              # mlp
              flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
              # norm2
              flops += self.dim * H * W
              return flops

      class PatchMerging(nn.Module):
          r""" Patch Merging Layer.

          Args:
              input_resolution (tuple[int]): Resolution of input feature.
              dim (int): Number of input channels.
              norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
          """


          def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, patch_size=[2], num_input_patch_size=1):
              super().__init__()
              self.input_resolution = input_resolution
              self.dim = dim
              self.reductions = nn.ModuleList()
              self.patch_size = patch_size
              self.norm = norm_layer(dim)

              for i, ps in enumerate(patch_size):
                  if i == len(patch_size) - 1:
                      out_dim = 2 * dim // 2 ** i
                  else:
                      out_dim = 2 * dim // 2 ** (i + 1)
                  stride = 2
                  padding = (ps - stride) // 2
                  self.reductions.append(nn.Conv2d(dim, out_dim, kernel_size=ps,
                                                      stride=stride, padding=padding))

          def forward(self, x):
              """
              x: B, H*W, C
              """

              H, W = self.input_resolution
              B, L, C = x.shape
              assert L == H * W, "input feature has wrong size"
              assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

              x = self.norm(x)
              x = x.view(B, H, W, C).permute(0, 3, 1, 2)

              xs = []
              for i in range(len(self.reductions)):
                  tmp_x = self.reductions[i](x).flatten(2).transpose(1, 2)
                  xs.append(tmp_x)
              x = torch.cat(xs, dim=2)
              return x

          def extra_repr(self) -> str:
              return f"input_resolution={self.input_resolution}, dim={self.dim}"

          def flops(self):
              H, W = self.input_resolution
              flops = H * W * self.dim
              for i, ps in enumerate(self.patch_size):
                  if i == len(self.patch_size) - 1:
                      out_dim = 2 * self.dim // 2 ** i
                  else:
                      out_dim = 2 * self.dim // 2 ** (i + 1)
                  flops += (H // 2) * (W // 2) * ps * ps * out_dim * self.dim
              return flops


      class Stage(nn.Module):
          """ CrossFormer blocks for one stage.

          Args:
              dim (int): Number of input channels.
              input_resolution (tuple[int]): Input resolution.
              depth (int): Number of blocks.
              num_heads (int): Number of attention heads.
              group_size (int): variable G in the paper, one group has GxG embeddings
              mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
              qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
              qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
              drop (float, optional): Dropout rate. Default: 0.0
              attn_drop (float, optional): Attention dropout rate. Default: 0.0
              drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
              norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
              downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
              use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
          """


          def __init__(self, dim, input_resolution, depth, num_heads, group_size,
                       mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                       drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                       patch_size_end=[4], num_patch_size=None):
              super().__init__()
              self.dim = dim
              self.input_resolution = input_resolution
              self.depth = depth
              self.use_checkpoint = use_checkpoint

              # build blocks
              self.blocks = nn.ModuleList()
              for i in range(depth):
                  lsda_flag = 0 if (i % 2 == 0) else 1
                  self.blocks.append(CrossFormerBlock(dim=dim, input_resolution=input_resolution,
                                       num_heads=num_heads, group_size=group_size,
                                       lsda_flag=lsda_flag,
                                       mlp_ratio=mlp_ratio,
                                       qkv_bias=qkv_bias, qk_scale=qk_scale,
                                       drop=drop, attn_drop=attn_drop,
                                       drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                       norm_layer=norm_layer,
                                       num_patch_size=num_patch_size))

              # patch merging layer
              if downsample is not None:
                  self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer,
                                               patch_size=patch_size_end, num_input_patch_size=num_patch_size)
              else:
                  self.downsample = None

          def forward(self, x):
              for blk in self.blocks:
                  if self.use_checkpoint:
                      x = checkpoint.checkpoint(blk, x)
                  else:
                      x = blk(x)
              if self.downsample is not None:
                  x = self.downsample(x)
              return x

          def extra_repr(self) -> str:
              return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

          def flops(self):
              flops = 0
              for blk in self.blocks:
                  flops += blk.flops()
              if self.downsample is not None:
                  flops += self.downsample.flops()
              return flops


      class PatchEmbed(nn.Module):
          r""" Image to Patch Embedding

          Args:
              img_size (int): Image size. Default: 224.
              patch_size (int): Patch token size. Default: [4].
              in_chans (int): Number of input image channels. Default: 3.
              embed_dim (int): Number of linear projection output channels. Default: 96.
              norm_layer (nn.Module, optional): Normalization layer. Default: None
          """


          def __init__(self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None):
              super().__init__()
              img_size = to_2tuple(img_size)
              # patch_size = to_2tuple(patch_size)
              patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[0]]
              self.img_size = img_size
              self.patch_size = patch_size
              self.patches_resolution = patches_resolution
              self.num_patches = patches_resolution[0] * patches_resolution[1]

              self.in_chans = in_chans
              self.embed_dim = embed_dim

              self.projs = nn.ModuleList()
              for i, ps in enumerate(patch_size):
                  if i == len(patch_size) - 1:
                      dim = embed_dim // 2 ** i
                  else:
                      dim = embed_dim // 2 ** (i + 1)
                  stride = patch_size[0]
                  padding = (ps - patch_size[0]) // 2
                  self.projs.append(nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding))
              if norm_layer is not None:
                  self.norm = norm_layer(embed_dim)
              else:
                  self.norm = None

          def forward(self, x):
              B, C, H, W = x.shape
              # FIXME look at relaxing size constraints
              assert H == self.img_size[0] and W == self.img_size[1], \
                  f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
              xs = []
              for i in range(len(self.projs)):
                  tx = self.projs[i](x).flatten(2).transpose(1, 2)
                  xs.append(tx) # B Ph*Pw C
              x = torch.cat(xs, dim=2)
              if self.norm is not None:
                  x = self.norm(x)
              return x

          def flops(self):
              Ho, Wo = self.patches_resolution
              flops = 0
              for i, ps in enumerate(self.patch_size):
                  if i == len(self.patch_size) - 1:
                      dim = self.embed_dim // 2 ** i
                  else:
                      dim = self.embed_dim // 2 ** (i + 1)
                  flops += Ho * Wo * dim * self.in_chans * (self.patch_size[i] * self.patch_size[i])
              if self.norm is not None:
                  flops += Ho * Wo * self.embed_dim
              return flops


      class CrossFormer(nn.Module):
          r""" CrossFormer
              A PyTorch impl of : `CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention` -

          Args:
              img_size (int | tuple(int)): Input image size. Default 224
              patch_size (int | tuple(int)): Patch size. Default: 4
              in_chans (int): Number of input image channels. Default: 3
              num_classes (int): Number of classes for classification head. Default: 1000
              embed_dim (int): Patch embedding dimension. Default: 96
              depths (tuple(int)): Depth of each stage.
              num_heads (tuple(int)): Number of attention heads in different layers.
              group_size (int): Group size. Default: 7
              mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
              qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
              qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
              drop_rate (float): Dropout rate. Default: 0
              attn_drop_rate (float): Attention dropout rate. Default: 0
              drop_path_rate (float): Stochastic depth rate. Default: 0.1
              norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
              ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
              patch_norm (bool): If True, add normalization after patch embedding. Default: True
              use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
          """


          def __init__(self, img_size=224, patch_size=[4], in_chans=3, num_classes=1000,
                       embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                       group_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                       drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                       norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                       use_checkpoint=False, merge_size=[[2], [2], [2]], **kwargs):
              super().__init__()

              self.num_classes = num_classes
              self.num_layers = len(depths)
              self.embed_dim = embed_dim
              self.ape = ape
              self.patch_norm = patch_norm
              self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
              self.mlp_ratio = mlp_ratio

              # split image into non-overlapping patches
              self.patch_embed = PatchEmbed(
                  img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
                  norm_layer=norm_layer if self.patch_norm else None)
              num_patches = self.patch_embed.num_patches
              patches_resolution = self.patch_embed.patches_resolution
              self.patches_resolution = patches_resolution

              # absolute position embedding
              if self.ape:
                  self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
                  trunc_normal_(self.absolute_pos_embed, std=.02)

              self.pos_drop = nn.Dropout(p=drop_rate)

              # stochastic depth
              dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule

              # build layers
              self.layers = nn.ModuleList()

              num_patch_sizes = [len(patch_size)] + [len(m) for m in merge_size]
              for i_layer in range(self.num_layers):
                  patch_size_end = merge_size[i_layer] if i_layer < self.num_layers - 1 else None
                  num_patch_size = num_patch_sizes[i_layer]
                  layer = Stage(dim=int(embed_dim * 2 ** i_layer),
                                     input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                       patches_resolution[1] // (2 ** i_layer)),
                                     depth=depths[i_layer],
                                     num_heads=num_heads[i_layer],
                                     group_size=group_size[i_layer],
                                     mlp_ratio=self.mlp_ratio,
                                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                                     drop=drop_rate, attn_drop=attn_drop_rate,
                                     drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                                     norm_layer=norm_layer,
                                     downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                                     use_checkpoint=use_checkpoint,
                                     patch_size_end=patch_size_end,
                                     num_patch_size=num_patch_size)
                  self.layers.append(layer)

              self.norm = norm_layer(self.num_features)
              self.avgpool = nn.AdaptiveAvgPool1d(1)
              self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

              self.apply(self._init_weights)

          def _init_weights(self, m):
              if isinstance(m, nn.Linear):
                  trunc_normal_(m.weight, std=.02)
                  if isinstance(m, nn.Linear) and m.bias is not None:
                      nn.init.constant_(m.bias, 0)
              elif isinstance(m, nn.LayerNorm):
                  nn.init.constant_(m.bias, 0)
                  nn.init.constant_(m.weight, 1.0)

          @torch.jit.ignore
          def no_weight_decay(self):
              return {'absolute_pos_embed'}

          @torch.jit.ignore
          def no_weight_decay_keywords(self):
              return {'relative_position_bias_table'}

          def forward_features(self, x):
              x = self.patch_embed(x)
              if self.ape:
                  x = x + self.absolute_pos_embed
              x = self.pos_drop(x)

              for layer in self.layers:
                  x = layer(x)

              x = self.norm(x) # B L C
              x = self.avgpool(x.transpose(1, 2)) # B C 1
              x = torch.flatten(x, 1)
              return x

          def forward(self, x):
              x = self.forward_features(x)
              x = self.head(x)
              return x

          def flops(self):
              flops = 0
              flops += self.patch_embed.flops()
              for i, layer in enumerate(self.layers):
                  flops += layer.flops()
              flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
              flops += self.num_features * self.num_classes
              return flops

      if __name__ == '__main__':
          input=torch.randn(1,3,224,224)
          model = CrossFormer(img_size=224,
              patch_size=[4, 8, 16, 32],
              in_chans= 3,
              num_classes=1000,
              embed_dim=48,
              depths=[2, 2, 6, 2],
              num_heads=[3, 6, 12, 24],
              group_size=[7, 7, 7, 7],
              mlp_ratio=4.,
              qkv_bias=True,
              qk_scale=None,
              drop_rate=0.0,
              drop_path_rate=0.1,
              ape=False,
              patch_norm=True,
              use_checkpoint=False,
              merge_size=[[2, 4], [2,4], [2, 4]]
          )
          output = model(input)
          print(output.shape)    # torch.Size([1, 1000])

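Beyond the full classifier above, a single CrossFormerBlock can also be dropped onto an intermediate feature map. The snippet below is a usage sketch under assumed settings (a 96-channel 56×56 feature map, 3 heads, group size 7); it reuses the CrossFormerBlock class defined above and is not part of the original article.

      # Hypothetical standalone use of CrossFormerBlock on a (B, C, H, W) feature map.
      feat = torch.randn(1, 96, 56, 56)                    # e.g. a backbone feature map
      B, C, H, W = feat.shape
      block = CrossFormerBlock(dim=C, input_resolution=(H, W),
                               num_heads=3, group_size=7, lsda_flag=0)   # 0 = SDA, 1 = LDA
      tokens = feat.flatten(2).transpose(1, 2)             # (B, H*W, C) token sequence
      out = block(tokens)                                  # (B, H*W, C)
      out = out.transpose(1, 2).reshape(B, C, H, W)        # back to a feature map
      print(out.shape)                                     # torch.Size([1, 96, 56, 56])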
Download

Browse to: https://github.com/ai-dawang/PlugNPlay-Modules

See the original paper for further analysis.

