得克萨斯大学提出一种解码器,以极低计算成本达成医学图像分割SOTA新性能!

文摘   2025-01-18 11:42   安徽  

点击下方卡片,关注“AI前沿速递”公众号

各种重磅干货,第一时间送达


【论文标题】EMCAD: Efficient Multi-scale Convolutional Attention Decoding for Medical Image Segmentation

【论文链接】https://arxiv.org/abs/2405.06880

【代码链接】https://github.com/SLDGroup/EMCAD

【论文单位/作者】得克萨斯大学

【论文出处】CVPR 2024

【摘要】✨

这篇文章介绍了一种名为 EMCAD 的高效多尺度卷积注意力解码器,旨在解决医学图像分割中解码机制通常伴随高昂计算成本的问题,尤其适用于计算资源有限的场景,通过该解码器可实现性能与计算效率的优化。EMCAD 利用独特的多尺度深度可分离卷积块,通过多尺度卷积显著增强特征图。它还采用了通道、空间以及分组(大内核)门控注意力机制,这些机制在聚焦显著区域的同时,能极为有效地捕捉复杂的空间关系。通过采用分组卷积和深度可分离卷积,EMCAD 效率极高且扩展性良好(例如,使用标准编码器时仅需 191 万个参数和 3.81 亿次浮点运算)。该文章在属于六项医学图像分割任务的 12 个数据集上进行了严格评估,结果表明,EMCAD 在实现了最先进(SOTA)的性能的同时,参数数量和浮点运算次数分别减少了 79.4% 和 80.3% 。此外,EMCAD 对不同编码器的适应性以及在各种分割任务中的通用性,进一步证明它是一种极具潜力的工具,推动该领域朝着更高效、更准确的医学图像分析方向发展。

【技术亮点/创新点】🎉

1.多尺度卷积解码器:在解码器中使用多尺度深度卷积块(MSDC),能有效捕捉特征图内不同尺度信息,这是提升医学图像分割精度的关键,且其深度卷积的设计显著提高了计算效率。

2.多尺度卷积注意力模块:引入多尺度卷积注意力模块(MSCAM),MSCAM 利用深度可分离卷积在多尺度上精炼特征图,相比传统卷积注意力模块更高效。

3.大核分组注意力门:在解码器中加入了一种大核分组注意力门(LGAG),通过特定的分组卷积设计,以较低计算成本捕获更大空间上下文并融合特征,提升了模型对重要区域的关注度。

【工作原理/方法】🔍

EMCAD 是一种用于医学图像分割的高效多尺度卷积注意力解码器。它首先接收来自预训练分层视觉编码器的多阶段特征,然后利用MSCAM增强特征图,通过LGAG融合跳跃连接的特征,再经高效上卷积块(EUCB)上采样和进一步处理,最后由分割头(SH)生成分割输出。其在多个模块中采用深度卷积、独特的注意力机制及特定的组合策略,有效提升了医学图像分割的性能与效率。

EMCAD解码器,包含多个模块:

1.MSCAM:通过通道注意力块、空间注意力块和高效多尺度卷积块的协同作用,有效精炼特征图,提升特征表达能力。

2.LGAG:利用 3×3 组卷积处理特征和门控信号,以较低计算成本融合特征图与注意力系数,增强分割精度。

3.EUCB:采用上采样和深度卷积等操作,对特征图进行上采样并增强,使其维度和分辨率与后续连接相匹配。

4.SH:对解码器各阶段精炼后的特征图应用 1×1 卷积,生成与目标数据集类别数相对应的分割输出。

【实验结果】📈

1.二元医学图像分割:在 10 个二元医学图像分割数据集上,PVT - EMCAD - B2 表现卓越,以 26.76M 参数和 5.6G FLOPs 取得 91.10% 的最高平均 DICE 分数,在息肉、皮肤病变、细胞、乳腺癌分割等子任务中均超越现有 SOTA 方法,如在息肉分割的五个数据集中超越所有对比方法,在皮肤病变分割的 ISIC17 和 ISIC18 数据集上分别比 DeepLabV3 + 提高 2.11% 和 2.32% 等。

2.腹部器官分割:在 Synapse 多器官数据集上,PVT - EMCAD - B2 平均 DICE 分数达 83.63%,超越所有基于 CNN 和 Transformer 的 SOTA 方法,在器官边界定位上比 PVT - CASCADE 更精准,且在八个器官中的六个分割效果显著优于现有方法。

3.心脏器官分割:在 ACDC 数据集的心脏器官分割任务中,PVT - EMCAD - B2 平均 DICE 分数达到 92.12%,比 Cascaded MERIT 提高约 0.27%,在三个器官的分割 DICE 分数上均表现更优,展现出良好的分割性能。

【总结】📑

  1. 研究背景与动机:医学图像分割意义重大,但现有解码机制计算成本高,限制了其应用,因此需要新的高效方法。
  2. EMCAD 创新点
  • 采用多尺度深度卷积块,能有效增强特征图,捕捉多尺度信息,且计算高效。
  • 融合通道、空间和分组(大核)门控注意力机制,可有效捕捉复杂空间关系并聚焦显著区域。
  • 实验设置与数据集
    • 基于 Pytorch 1.11.0 在 NVIDIA RTX A6000 GPU 上实现,使用 ImageNet 预训练的 PVTv2 - b0 和 PVTv2 - b2 作为编码器。
    • 在 12 个属于 6 种医学图像分割任务的数据集上进行实验,包括息肉、皮肤病变、细胞、腹部器官、心脏器官等分割任务。
  • 实验结果
    • 二元医学图像分割:PVT - EMCAD - B2 以 26.76M 参数和 5.6G FLOPs 取得最高平均 DICE 分数 91.10%,在息肉、皮肤病变等子任务中表现优异。
    • 腹部器官分割:PVT - EMCAD - B2 在 Synapse 多器官数据集上平均 DICE 分数达 83.63%,超越现有 SOTA 方法。
    • 心脏器官分割:PVT - EMCAD - B2 在 ACDC 数据集上平均 DICE 分数达到 92.12%,优于 Cascaded MERIT。
  • 消融研究结果
    • 表明 EMCAD 解码器各组件(如 cascaded 结构、LGAG、MSCAM)对性能提升有积极作用,且 MSCAM 效果更显著。
    • 确定了 MSDC 中多尺度内核的最佳选择为[1, 3, 5]。
    • 显示 EMCAD 解码器相比基线解码器(CASCADE)在计算复杂度和性能上更具优势。
  • 研究结论:EMCAD 在性能上超越近期的 CASCADE 解码器,参数和计算量大幅减少,在多个医学图像分割任务中表现出色,对医学图像分割和语义分割任务有重要价值。
  • 【代码】💻

    编码器代码

    • PVT-B2
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from functools import partial

    from timm.models.layers import DropPath, to_2tuple, trunc_normal_
    from timm.models.registry import register_model

    import math


    class Mlp(nn.Module):
        def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
            super().__init__()
            out_features = out_features or in_features
            hidden_features = hidden_features or in_features
            self.fc1 = nn.Linear(in_features, hidden_features)
            self.dwconv = DWConv(hidden_features)
            self.act = act_layer()
            self.fc2 = nn.Linear(hidden_features, out_features)
            self.drop = nn.Dropout(drop)

            self.apply(self._init_weights)

        def _init_weights(self, m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                fan_out //= m.groups
                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()

        def forward(self, x, H, W):
            x = self.fc1(x)
            x = self.dwconv(x, H, W)
            x = self.act(x)
            x = self.drop(x)
            x = self.fc2(x)
            x = self.drop(x)
            return x


    class Attention(nn.Module):
        def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
            super().__init__()
            assert dim % num_heads == 0f"dim {dim} should be divided by num_heads {num_heads}."

            self.dim = dim
            self.num_heads = num_heads
            head_dim = dim // num_heads
            self.scale = qk_scale or head_dim ** -0.5

            self.q = nn.Linear(dim, dim, bias=qkv_bias)
            self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)

            self.sr_ratio = sr_ratio
            if sr_ratio > 1:
                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
                self.norm = nn.LayerNorm(dim)

            self.apply(self._init_weights)

        def _init_weights(self, m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                fan_out //= m.groups
                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()

        def forward(self, x, H, W):
            B, N, C = x.shape
            q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0213)

            if self.sr_ratio > 1:
                x_ = x.permute(021).reshape(B, C, H, W)
                x_ = self.sr(x_).reshape(B, C, -1).permute(021)
                x_ = self.norm(x_)
                kv = self.kv(x_).reshape(B, -12, self.num_heads, C // self.num_heads).permute(20314)
            else:
                kv = self.kv(x).reshape(B, -12, self.num_heads, C // self.num_heads).permute(20314)
            k, v = kv[0], kv[1]

            attn = (q @ k.transpose(-2-1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = (attn @ v).transpose(12).reshape(B, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)

            return x


    class Block(nn.Module):

        def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                     drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1)
    :

            super().__init__()
            self.norm1 = norm_layer(dim)
            self.attn = Attention(
                dim,
                num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
            NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
            self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

            self.apply(self._init_weights)

        def _init_weights(self, m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                fan_out //= m.groups
                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()

        def forward(self, x, H, W):
            x = x + self.drop_path(self.attn(self.norm1(x), H, W))
            x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

            return x


    class OverlapPatchEmbed(nn.Module):
        """ Image to Patch Embedding
        """


        def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
            super().__init__()
            img_size = to_2tuple(img_size)
            patch_size = to_2tuple(patch_size)

            self.img_size = img_size
            self.patch_size = patch_size
            self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
            self.num_patches = self.H * self.W
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                                  padding=(patch_size[0] // 2, patch_size[1] // 2))
            self.norm = nn.LayerNorm(embed_dim)

            self.apply(self._init_weights)

        def _init_weights(self, m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                fan_out //= m.groups
                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()

        def forward(self, x):
            x = self.proj(x)
            _, _, H, W = x.shape
            x = x.flatten(2).transpose(12)
            x = self.norm(x)

            return x, H, W


    class PyramidVisionTransformerImpr(nn.Module):
        def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64128256512],
                     num_heads=[1248], mlp_ratios=[4444], qkv_bias=False, qk_scale=None, drop_rate=0.,
                     attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                     depths=[3463], sr_ratios=[8421])
    :

            super().__init__()
            self.num_classes = num_classes
            self.depths = depths

            # patch_embed
            self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
                                                  embed_dim=embed_dims[0])
            self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
                                                  embed_dim=embed_dims[1])
            self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
                                                  embed_dim=embed_dims[2])
            self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
                                                  embed_dim=embed_dims[3])

            # transformer encoder
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
            cur = 0
            self.block1 = nn.ModuleList([Block(
                dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
                sr_ratio=sr_ratios[0])
                for i in range(depths[0])])
            self.norm1 = norm_layer(embed_dims[0])

            cur += depths[0]
            self.block2 = nn.ModuleList([Block(
                dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
                sr_ratio=sr_ratios[1])
                for i in range(depths[1])])
            self.norm2 = norm_layer(embed_dims[1])

            cur += depths[1]
            self.block3 = nn.ModuleList([Block(
                dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
                sr_ratio=sr_ratios[2])
                for i in range(depths[2])])
            self.norm3 = norm_layer(embed_dims[2])

            cur += depths[2]
            self.block4 = nn.ModuleList([Block(
                dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
                sr_ratio=sr_ratios[3])
                for i in range(depths[3])])
            self.norm4 = norm_layer(embed_dims[3])

            # classification head
            # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()

            self.apply(self._init_weights)

        def _init_weights(self, m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                fan_out //= m.groups
                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()

        def init_weights(self, pretrained=None):
            if isinstance(pretrained, str):
                logger = 1
                #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)

        def reset_drop_path(self, drop_path_rate):
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
            cur = 0
            for i in range(self.depths[0]):
                self.block1[i].drop_path.drop_prob = dpr[cur + i]

            cur += self.depths[0]
            for i in range(self.depths[1]):
                self.block2[i].drop_path.drop_prob = dpr[cur + i]

            cur += self.depths[1]
            for i in range(self.depths[2]):
                self.block3[i].drop_path.drop_prob = dpr[cur + i]

            cur += self.depths[2]
            for i in range(self.depths[3]):
                self.block4[i].drop_path.drop_prob = dpr[cur + i]

        def freeze_patch_emb(self):
            self.patch_embed1.requires_grad = False

        @torch.jit.ignore
        def no_weight_decay(self):
            return {'pos_embed1''pos_embed2''pos_embed3''pos_embed4''cls_token'}  # has pos_embed may be better

        def get_classifier(self):
            return self.head

        def reset_classifier(self, num_classes, global_pool=''):
            self.num_classes = num_classes
            self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        # def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        #     if H * W == self.patch_embed1.num_patches:
        #         return pos_embed
        #     else:
        #         return F.interpolate(
        #             pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
        #             size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)

        def forward_features(self, x):
            B = x.shape[0]
            outs = []
            
            # stage 1
            x, H, W = self.patch_embed1(x)
            for i, blk in enumerate(self.block1):
                x = blk(x, H, W)
            x = self.norm1(x)
            x = x.reshape(B, H, W, -1).permute(0312).contiguous()
            outs.append(x)

            # stage 2
            x, H, W = self.patch_embed2(x)
            for i, blk in enumerate(self.block2):
                x = blk(x, H, W)
            x = self.norm2(x)
            x = x.reshape(B, H, W, -1).permute(0312).contiguous()
            outs.append(x)

            # stage 3
            x, H, W = self.patch_embed3(x)
            for i, blk in enumerate(self.block3):
                x = blk(x, H, W)
            x = self.norm3(x)
            x = x.reshape(B, H, W, -1).permute(0312).contiguous()
            outs.append(x)

            # stage 4
            x, H, W = self.patch_embed4(x)
            for i, blk in enumerate(self.block4):
                x = blk(x, H, W)
            x = self.norm4(x)
            x = x.reshape(B, H, W, -1).permute(0312).contiguous()
            outs.append(x)

            return outs

            # return x.mean(dim=1)

        def forward(self, x):
            x = self.forward_features(x)
            # x = self.head(x)

            return x


    class DWConv(nn.Module):
        def __init__(self, dim=768):
            super(DWConv, self).__init__()
            self.dwconv = nn.Conv2d(dim, dim, 311, bias=True, groups=dim)

        def forward(self, x, H, W):
            B, N, C = x.shape
            x = x.transpose(12).view(B, C, H, W)
            x = self.dwconv(x)
            x = x.flatten(2).transpose(12)

            return x


    def _conv_filter(state_dict, patch_size=16):
        """ convert patch embedding weight from manual patchify + linear proj to conv"""
        out_dict = {}
        for k, v in state_dict.items():
            if 'patch_embed.proj.weight' in k:
                v = v.reshape((v.shape[0], 3, patch_size, patch_size))
            out_dict[k] = v

        return out_dict


    @register_model
    class pvt_v2_b0(PyramidVisionTransformerImpr):
        def __init__(self, **kwargs):
            super(pvt_v2_b0, self).__init__(
                patch_size=4, embed_dims=[3264160256], num_heads=[1258], mlp_ratios=[8844],
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2222], sr_ratios=[8421],
                drop_rate=0.0, drop_path_rate=0.1)



    @register_model
    class pvt_v2_b1(PyramidVisionTransformerImpr):
        def __init__(self, **kwargs):
            super(pvt_v2_b1, self).__init__(
                patch_size=4, embed_dims=[64128320512], num_heads=[1258], mlp_ratios=[8844],
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2222], sr_ratios=[8421],
                drop_rate=0.0, drop_path_rate=0.1)

    @register_model
    class pvt_v2_b2(PyramidVisionTransformerImpr):
        def __init__(self, **kwargs):
            super(pvt_v2_b2, self).__init__(
                patch_size=4, embed_dims=[64128320512], num_heads=[1258], mlp_ratios=[8844],
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3463], sr_ratios=[8421],
                drop_rate=0.0, drop_path_rate=0.1)

    @register_model
    class pvt_v2_b3(PyramidVisionTransformerImpr):
        def __init__(self, **kwargs):
            super(pvt_v2_b3, self).__init__(
                patch_size=4, embed_dims=[64128320512], num_heads=[1258], mlp_ratios=[8844],
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[34183], sr_ratios=[8421],
                drop_rate=0.0, drop_path_rate=0.1)

    @register_model
    class pvt_v2_b4(PyramidVisionTransformerImpr):
        def __init__(self, **kwargs):
            super(pvt_v2_b4, self).__init__(
                patch_size=4, embed_dims=[64128320512], num_heads=[1258], mlp_ratios=[8844],
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[38273], sr_ratios=[8421],
                drop_rate=0.0, drop_path_rate=0.1)


    @register_model
    class pvt_v2_b5(PyramidVisionTransformerImpr):
        def __init__(self, **kwargs):
            super(pvt_v2_b5, self).__init__(
                patch_size=4, embed_dims=[64128320512], num_heads=[1258], mlp_ratios=[4444],
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[36403], sr_ratios=[8421],
                drop_rate=0.0, drop_path_rate=0.1)

    由于篇幅有限,这里只展示部分代码,如需全部代码请前往GitHub主页自信查看。



    本文内容为论文学习收获分享,受限于知识能力,本文对原文的理解可能存在偏差,最终内容以原论文为准。本文信息旨在传播和学术交流,其内容由作者负责,不代表本号观点。文中作品文字、图片等如涉及内容、版权和其他问题,请及时与我们联系,我们将在第一时间回复并处理。


    欢迎投稿

    想要让高质量的内容更快地触达读者,降低他们寻找优质信息的成本吗?关键在于那些你尚未结识的人。他们可能掌握着你渴望了解的知识。【AI前沿速递】愿意成为这样的一座桥梁,连接不同领域、不同背景的学者,让他们的学术灵感相互碰撞,激发出无限可能。

    【AI前沿速递】欢迎各高校实验室和个人在我们的平台上分享各类精彩内容,无论是最新的论文解读,还是对学术热点的深入分析,或是科研心得和竞赛经验的分享,我们的目标只有一个:让知识自由流动。

    📝 投稿指南

    • 确保文章为个人原创,未在任何公开渠道发布。若文章已在其他平台发表或即将发表,请明确说明。

    • 建议使用Markdown格式撰写稿件,并以附件形式发送清晰、无版权争议的配图。

    • 【AI前沿速递】尊重作者的署名权,并为每篇被采纳的原创首发稿件提供具有市场竞争力的稿酬。稿酬将根据文章的阅读量和质量进行阶梯式结算。

    📬 投稿方式

    • 您可以通过添加我们的小助理微信(aiqysd)进行快速投稿。请在添加时备注“投稿-姓名-学校-研究方向”


      长按添加AI前沿速递小助理



    AI前沿速递
    持续分享最新AI前沿论文成果
     最新文章