NeurIPS CoAtNet:结合卷积与注意力的高效视觉网络,全面提升图像处理性能

文摘   2025-01-02 17:20   中国香港  

论文介绍

题目:CoAtNet: Marrying Convolution and Attention for All Data Sizes

论文地址:https://arxiv.org/pdf/2106.04803

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 统一卷积和注意力的特性: CoAtNet 提出了将深度卷积(Depthwise Convolution)和自注意力(Self-Attention)相结合的架构。它利用了相对注意力(Relative Attention)的方法,将卷积的归纳偏置(Inductive Bias)和注意力的高模型容量(Model Capacity)结合,弥补了 Transformer 在小数据集上的泛化问题。

  • 多阶段网络设计: CoAtNet 采用分阶段的网络结构,逐渐减少空间分辨率并增加通道数。在设计中,通过将卷积阶段(用于提取低级特征)和注意力阶段(用于处理高级特征)按照特定顺序组合,优化了泛化能力和模型容量的平衡。

  • 泛化性与模型容量的系统优化: 实验表明,CoAtNet 在小数据集(如 ImageNet-1K)上表现优异,与同样计算资源下的 SOTA 卷积网络(如 EfficientNet-V2 和 NFNet)相媲美;在大数据集(如 JFT-300M 和 JFT-3B)上,利用 Transformer 的扩展性,显著提高了性能。

  • 效率与准确性的平衡: CoAtNet 在消耗更少计算资源的情况下实现了更高的准确性。例如,在 JFT-3B 上预训练后,其变体 CoAtNet-7 达到了 ImageNet 数据集上的 90.88% 的准确率,超越了此前的 SOTA 模型(如 ViT-G/14)。

  • 相对注意力机制的应用: 相较于传统的全局注意力机制,CoAtNet 的相对注意力通过预归一化形式显著减少了计算复杂度,同时在泛化和转移学习上表现更优。

方法

整体架构

     CoAtNet 是一种结合卷积和自注意力的多阶段混合架构,通过前两阶段的卷积(MBConv 块)提取局部特征,后两阶段的自注意力(相对注意力)捕获全局特征,实现了卷积的高效泛化能力和 Transformer 的强模型容量的有机结合。网络从高分辨率到低分辨率逐步下采样,通道数逐步增加,最终通过全局池化和全连接层完成分类任务,适用于从小规模到大规模数据的图像处理场景。

1. 多阶段设计

CoAtNet 将网络分为多个阶段(Stage 0 到 Stage 4),每个阶段在空间分辨率和通道数上逐步调整:

  • Stage 0(Stem阶段)

    • 采用卷积操作作为初始特征提取层。

    • 分辨率:112×112112 \times 112

    • 主要作用:对输入图像进行基本特征提取和下采样。

  • Stage 1(卷积阶段)

    • 使用 MBConv 块(MobileNet 中提出的卷积块)。

    • 分辨率进一步缩小至56×5656 \times 56

    • 作用:提取低级局部特征,利用卷积的归纳偏置提高泛化能力。

  • Stage 2(卷积阶段)

    • 继续使用 MBConv 块,并逐步减少分辨率至28×2828 \times 28

    • 提供卷积的优势以处理局部模式。

  • Stage 3(Transformer阶段)

    • 转换为使用带有相对注意力的 Transformer 块(Rel-Attention)。

    • 分辨率缩小至14×1414 \times 14

    • 作用:捕获全局特征和长距离依赖。

  • Stage 4(Transformer阶段)

    • 再次堆叠更多的 Transformer 块。

    • 分辨率缩小至7×77 \times 7

    • 作用:进一步捕获全局上下文信息,形成高层语义特征。


2. 网络设计原则

  • 卷积与注意力的顺序

    • 卷积阶段(Stage 1 和 Stage 2)用于提取局部特征。

    • 注意力阶段(Stage 3 和 Stage 4)用于建模全局上下文。

    • 这种分阶段的设计遵循了卷积对低级特征的优势和注意力对高级特征的优势。

  • 下采样(Down-sampling)

    • 每个阶段开始时,分辨率缩减(例如通过步长为2的卷积或池化操作),通道数增加。

    • 下采样设计对卷积和 Transformer 分别进行了优化。


3. 模型的模块化实现

  • MBConv 块(卷积阶段核心模块):

    • 深度卷积(Depthwise Convolution)用于捕获局部特征。

    • 倒置瓶颈(Inverted Bottleneck)结构将通道数扩展后再压缩,便于残差连接。

  • Transformer 块(注意力阶段核心模块):

    • 使用相对注意力(Relative Attention),结合位置偏置和自注意力。

    • 包含前馈网络(FFN),进行特征投影。


4. 模型配置

论文中提供了多个 CoAtNet 模型的具体配置(CoAtNet-0 到 CoAtNet-7),主要变化在于不同阶段的模块数量和通道宽度。例如:

  • CoAtNet-0:用于小规模任务,参数量较少。

  • CoAtNet-7:大规模任务,如 JFT-3B 数据集预训练,参数量和 FLOPs 明显增加。

具体参数和层次设置如下(部分配置,详细见论文附录):

  • S0(Conv):L=2,D=64L=2, D=64(层数2,通道数64)。

  • S1(MBConv):L=2,D=96L=2, D=96

  • S2(MBConv):L=6,D=192L=6, D=192

  • S3(Rel-Attention):L=14,D=384L=14, D=384

  • S4(Rel-Attention):L=2,D=768L=2, D=768

即插即用模块作用

CoAtNet 作为一个即插即用模块

  • 小数据集任务

    • 例如医学影像分析或工业视觉检测,数据量有限但需要高泛化能力的场景。

    • 作用:利用卷积的归纳偏置,有效提取局部特征并提高模型泛化能力。

  • 大数据集任务

    • 例如大规模图像分类任务(ImageNet)、预训练任务(如 JFT-3B)。

    • 作用:通过自注意力建模全局上下文,在大数据量下充分利用其高模型容量。

  • 多尺度特征建模

    • 例如目标检测或语义分割,需要从不同分辨率提取特征。

    • 作用:利用多阶段架构逐步降低分辨率,保证从局部到全局的特征有效聚合。

  • 计算效率要求高的任务

    • 例如边缘设备(如移动设备)上的图像识别。

      作用:通过卷积的高效性和注意力的扩展性,减少计算量和内存占用,实现高效推理

消融实验结果

  • 内容:对比了是否使用相对注意力(Relative Attention)的 CoAtNet 模型在 ImageNet-1K 和 ImageNet-21K 的表现。

  • 结论

    • 使用相对注意力的模型在小数据集(ImageNet-1K)上具有更好的泛化能力。

    • 在大数据集(ImageNet-21K)上的迁移学习中,使用相对注意力的模型表现出更好的准确率,证明了该设计在特征建模中的重要性。



  • 内容:对比了不同阶段的块数分布(S2 和 S3)对性能的影响。

  • 结论

    • S2(MBConv)和 S3(Transformer)之间的块数分配对性能有显著影响。

    • 在 ImageNet-21K 上迁移学习的实验表明,保持适当数量的卷积块(S2)对模型的泛化能力和迁移性能至关重要。

即插即用模块

from torch import nn, sqrt
import torch
import sys
from math import sqrt
sys.path.append('.')
import math
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init
import numpy as np

# 论文地址:https://arxiv.org/pdf/2106.04803
# 论文:CoAtNet: Marrying Convolution and Attention for All Data Sizes


class ScaledDotProductAttention(nn.Module):
    '''
    Scaled dot-product attention
    '''


    def __init__(self, d_model, d_k, d_v, h,dropout=.1):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        '''

        super(ScaledDotProductAttention, self).__init__()
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout=nn.Dropout(dropout)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        '''
        Computes
        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return:
        '''

        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        att=self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
        out = self.fc_o(out) # (b_s, nq, d_model)
        return out
    
class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_variables[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)


def drop_connect(inputs, p, training):
    """ Drop connect. """
    if not training: return inputs
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)
    output = inputs / keep_prob * binary_tensor
    return output


def get_same_padding_conv2d(image_size=None):
     return partial(Conv2dStaticSamePadding, image_size=image_size)

def get_width_and_height_from_size(x):
    """ Obtains width and height from a int or tuple """
    if isinstance(x, int): return x, x
    if isinstance(x, list) or isinstance(x, tuple): return x
    else: raise TypeError()

def calculate_output_image_size(input_image_size, stride):
    """
    计算出 Conv2dSamePadding with a stride.
    """

    if input_image_size is None: return None
    image_height, image_width = get_width_and_height_from_size(input_image_size)
    stride = stride if isinstance(stride, int) else stride[0]
    image_height = int(math.ceil(image_height / stride))
    image_width = int(math.ceil(image_width / stride))
    return [image_height, image_width]



class Conv2dStaticSamePadding(nn.Conv2d):
    """ 2D Convolutions like TensorFlow, for a fixed image size"""

    def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x

class Identity(nn.Module):
    def __init__(self, ):
        super(Identity, self).__init__()

    def forward(self, input):
        return input


# MBConvBlock
class MBConvBlock(nn.Module):
    '''
    层 ksize3*3 输入32 输出16 conv1 stride步长1
    '''

    def __init__(self, ksize, input_filters, output_filters, expand_ratio=1, stride=1, image_size=224):
        super().__init__()
        self._bn_mom = 0.1
        self._bn_eps = 0.01
        self._se_ratio = 0.25
        self._input_filters = input_filters
        self._output_filters = output_filters
        self._expand_ratio = expand_ratio
        self._kernel_size = ksize
        self._stride = stride

        inp = self._input_filters
        oup = self._input_filters * self._expand_ratio
        if self._expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)


        # Depthwise convolution
        k = self._kernel_size
        s = self._stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired
        Conv2d = get_same_padding_conv2d(image_size=(1,1))
        num_squeezed_channels = max(1, int(self._input_filters * self._se_ratio))
        self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
        self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Output phase
        final_oup = self._output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """
        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output of block
        """


        # Expansion and Depthwise Convolution
        x = inputs
        if self._expand_ratio != 1:
            expand = self._expand_conv(inputs)
            bn0 = self._bn0(expand)
            x = self._swish(bn0)
        depthwise = self._depthwise_conv(x)
        bn1 = self._bn1(depthwise)
        x = self._swish(bn1)

        # Squeeze and Excitation
        x_squeezed = F.adaptive_avg_pool2d(x, 1)
        x_squeezed = self._se_reduce(x_squeezed)
        x_squeezed = self._swish(x_squeezed)
        x_squeezed = self._se_expand(x_squeezed)
        x = torch.sigmoid(x_squeezed) * x

        x = self._bn2(self._project_conv(x))

        # Skip connection and drop connect
        input_filters, output_filters = self._input_filters, self._output_filters
        if self._stride == 1 and input_filters == output_filters:
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs # skip connection
        return x
    
class CoAtNet(nn.Module):
    def __init__(self,in_ch,image_size,out_chs=[64,96,192,384,768]):
        super().__init__()
        self.out_chs=out_chs
        self.maxpool2d=nn.MaxPool2d(kernel_size=2,stride=2)
        self.maxpool1d = nn.MaxPool1d(kernel_size=2, stride=2)

        self.s0=nn.Sequential(
            nn.Conv2d(in_ch,in_ch,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(in_ch,in_ch,kernel_size=3,padding=1)
        )
        self.mlp0=nn.Sequential(
            nn.Conv2d(in_ch,out_chs[0],kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(out_chs[0],out_chs[0],kernel_size=1)
        )
        
        self.s1=MBConvBlock(ksize=3,input_filters=out_chs[0],output_filters=out_chs[0],image_size=image_size//2)
        self.mlp1=nn.Sequential(
            nn.Conv2d(out_chs[0],out_chs[1],kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(out_chs[1],out_chs[1],kernel_size=1)
        )

        self.s2=MBConvBlock(ksize=3,input_filters=out_chs[1],output_filters=out_chs[1],image_size=image_size//4)
        self.mlp2=nn.Sequential(
            nn.Conv2d(out_chs[1],out_chs[2],kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(out_chs[2],out_chs[2],kernel_size=1)
        )

        self.s3=ScaledDotProductAttention(out_chs[2],out_chs[2]//8,out_chs[2]//8,8)
        self.mlp3=nn.Sequential(
            nn.Linear(out_chs[2],out_chs[3]),
            nn.ReLU(),
            nn.Linear(out_chs[3],out_chs[3])
        )

        self.s4=ScaledDotProductAttention(out_chs[3],out_chs[3]//8,out_chs[3]//8,8)
        self.mlp4=nn.Sequential(
            nn.Linear(out_chs[3],out_chs[4]),
            nn.ReLU(),
            nn.Linear(out_chs[4],out_chs[4])
        )


    def forward(self, x) :
        B,C,H,W=x.shape
        #stage0
        y=self.mlp0(self.s0(x))
        y=self.maxpool2d(y)
        #stage1
        y=self.mlp1(self.s1(y))
        y=self.maxpool2d(y)
        #stage2
        y=self.mlp2(self.s2(y))
        y=self.maxpool2d(y)
        #stage3
        y=y.reshape(B,self.out_chs[2],-1).permute(0,2,1) #B,N,C
        y=self.mlp3(self.s3(y,y,y))
        y=self.maxpool1d(y.permute(0,2,1)).permute(0,2,1)
        #stage4
        y=self.mlp4(self.s4(y,y,y))
        y=self.maxpool1d(y.permute(0,2,1))
        N=y.shape[-1]
        y=y.reshape(B,self.out_chs[4],int(sqrt(N)),int(sqrt(N)))

        return y

if __name__ == '__main__':
    input=torch.randn(1,3,224,224)
    block=CoAtNet(3,224)
    output=block(input)    print(output.shape)

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章