即插即用多模态条件注意力模块TabAttention,即用即涨点

文摘   2025-01-19 17:20   上海  

论文介绍

题目:TabAttention: Learning Attention Conditionally on Tabular Data

论文地址:https://arxiv.org/pdf/2310.18129

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 提出TabAttention模块

    • 设计了一种新的注意力模块TabAttention,它通过结合表格数据条件学习注意力权重来增强卷积神经网络(CNN)的性能。

    • 将经典的CBAM(Convolutional Block Attention Module)扩展到3D数据处理,并添加了时间注意力模块(TAM),利用多头自注意力(MHSA)学习时间注意力映射。

  • 融合表格数据嵌入

    • TabAttention通过表格数据嵌入将表格数据的信息与影像特征图结合,用于生成通道、空间和时间注意力映射。

    • 这种方法使模型能够更好地判断“关注什么”、“在哪里关注”以及“何时关注”,从而改进预测性能。

  • 在胎儿出生体重(FBW)预测任务上的验证

    • 使用腹部超声视频和胎儿生物测量数据(如腹围、头围等)进行了实验。

    • 实验结果表明,TabAttention在所有测试方法中取得了最低的误差(MAE、RMSE和MAPE指标),性能优于当前的人工方法和基于表格或影像数据的机器学习方法。

  • 模块设计的通用性

    • TabAttention可以无缝集成到任何CNN架构中,例如本文中与3D ResNet-18的集成。

    • 模块设计对表格数据的依赖使其能够在许多需要影像和表格数据结合的临床工作流中应用。

  • 对现有方法的改进

    • 将表格数据作为提示信息,指导网络学习注意力映射,显著提升了基线方法的性能。

    • 解决了影像特征与表格数据之间交互不足的问题。

方法

整体架构

       这篇论文中提出的模型整体结构是基于3D卷积神经网络(3D ResNet-18),通过集成TabAttention模块增强其性能。TabAttention模块包括通道注意力模块(CAM)、空间注意力模块(SAM)和时间注意力模块(TAM),这些模块通过结合影像特征和表格数据嵌入来生成多维度的注意力映射。影像特征首先通过3D卷积处理,生成的中间特征映射依次经过CAM、SAM和TAM的加权调整,最终生成优化的输出特征,用于预测任务(如胎儿出生体重)。模型的设计旨在高效融合影像和表格数据,实现对目标特征的更精准捕捉。

  • 输入数据

    • 影像数据:二维超声视频序列,经过预处理后形成3D特征输入。

    • 表格数据:胎儿生物测量参数(如腹围、头围等)及母体年龄等数值特征。

  • 主干网络

    • 使用3D ResNet-18作为主干网络,提取影像数据的时空特征。

    • 在每个残差块中嵌入TabAttention模块以增强特征学习。

  • TabAttention模块


    • 包括通道注意力模块(CAM)空间注意力模块(SAM)时间注意力模块(TAM)

    • 这些模块通过结合表格数据嵌入,对影像特征进行多维度的注意力加权。

    • 表格数据通过嵌入网络转化为特征向量,与影像特征联合计算注意力映射。

  • 输出层

    • 在3D ResNet-18的最终全连接层输出预测值(如胎儿出生体重)。

    • 使用回归损失函数(如均方误差)优化模型。

即插即用模块作用

TabAttention 作为一个即插即用模块

  • 多模态数据的融合

    • 当任务中同时存在影像数据(如医学影像、视频)和表格数据(如数值特征、统计数据)时,TabAttention能够高效融合这两种数据,提高预测性能。

    • 典型场景:医学诊断(结合影像和患者信息)、工业检测(结合视觉和传感器数据)、遥感影像分析(结合影像和环境参数)。

  • 时空特征分析

    • 涉及时间维度的序列数据(如视频或动态信号)的场景,可以利用TabAttention中的时间注意力模块(TAM)增强对时间维度变化的捕捉。

    • 典型场景:行为识别、视频分类、时序事件检测。

  • 复杂场景中特征的关注优化

    • 当数据具有高维特征(例如3D影像、复杂表格数据)且需要对关键特征进行选择和权重分配时,TabAttention可以通过通道注意力(CAM)空间注意力(SAM)来优化模型关注的内容。

    • 典型场景:高维数据分析(如基因组学、金融数据分析)。

消融实验结果

  • 内容:展示了不同模块组合对模型性能的影响,指标包括:

    • 平均绝对误差(MAE)。

    • 均方根误差(RMSE)。

    • 平均绝对百分比误差(MAPE)。

  • 实验设置

    • 基线模型:3D ResNet-18。

    • 逐步加入关键模块(TAM、CBAM、表格数据嵌入等),最终形成完整的TabAttention模块。

  • 结果说明

    • 仅加入TAM:性能提升有限,说明仅时间注意力不足以显著提高结果。

    • 加入CBAM和表格数据嵌入:性能明显提升,表明表格数据对注意力机制的重要作用。

    • 完整TabAttention模块:实现最佳性能,验证了模块设计的有效性和各部分的协同作用。

即插即用模块

import torch
from torch import nn
from torch.functional import F
import math
class TabAttention(nn.Module):
    def __init__(self, input_dim, tab_dim=6, tabattention=True, cam_sam=True, temporal_attention=True):

        super(TabAttention, self).__init__()

        channel_dim, h, w, frame_dim = input_dim
        hw_size = (h, w)
        self.input_dim = input_dim
        self.tabattention = tabattention
        self.temporal_attention = temporal_attention
        self.cam_sam = cam_sam
        if self.cam_sam:
            self.channel_gate = ChannelGate(channel_dim, tabattention=tabattention, tab_dim=tab_dim)
            self.spatial_gate = SpatialGate(tabattention=tabattention, tab_dim=tab_dim, input_size=hw_size)
        if temporal_attention:
            self.temporal_gate = TemporalGate(frame_dim, tabattention=tabattention, tab_dim=tab_dim)

    def forward(self, x, tab=None):
        b, c, h, w, f = x.shape
        x_in = torch.permute(x, (0, 4, 1, 2, 3))
        x_in = torch.reshape(x_in, (b * f, c, h, w))
        if self.tabattention:
            tab_rep = tab.repeat(f, 1, 1)
        else:
            tab_rep = None

        if self.cam_sam:
            x_out = self.channel_gate(x_in, tab_rep)
            x_out = self.spatial_gate(x_out, tab_rep)
        else:
            x_out = x_in

        x_out = torch.reshape(x_out, (b, f, c, h, w))

        if self.temporal_attention:
            x_out = self.temporal_gate(x_out, tab)

        x_out = torch.permute(x_out, (0, 2, 3, 4, 1)) # b,c,h,w,f

        return x_out


class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
                 bn=True, bias=False)
:

        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    def __init__(self, gate_channels, tabattention=True, tab_dim=6, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.tabattention = tabattention
        self.tab_dim = tab_dim
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types
        if tabattention:
            self.pool_types = ['avg', 'max', 'tab']
            self.tab_embedding = nn.Sequential(
                nn.Linear(tab_dim, gate_channels // reduction_ratio),
                nn.ReLU(),
                nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )

    def forward(self, x, tab=None):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(max_pool)
            elif pool_type == 'lp':
                lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(lp_pool)
            elif pool_type == 'lse':
                # LSE pool only
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp(lse_pool)
            elif pool_type == 'tab':
                embedded = self.tab_embedding(tab)
                embedded = torch.reshape(embedded, (-1, self.gate_channels))
                pool = self.mlp(embedded)
                channel_att_raw = pool

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale


class TemporalMHSA(nn.Module):
    def __init__(self, input_dim=2, seq_len=16, heads=2):
        super(TemporalMHSA, self).__init__()

        self.input_dim = input_dim
        self.seq_len = seq_len
        self.embedding_dim = 4
        self.head_dim = self.embedding_dim // heads
        self.heads = heads
        self.qkv = nn.Linear(self.input_dim, self.embedding_dim * 3)
        self.rel = nn.Parameter(torch.randn([1, 1, seq_len, 1]), requires_grad=True)
        self.o_proj = nn.Linear(self.embedding_dim, 1)

    def forward(self, x):
        batch_size, seq_length, _ = x.size()
        qkv = self.qkv(x)
        qkv = qkv.reshape(batch_size, seq_length, self.heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        d_k = q.size()[-1]
        k = k + self.rel.expand_as(k)
        attn_logits = torch.matmul(q, k.transpose(-2, -1))
        attn_logits = attn_logits / math.sqrt(d_k)
        attention = F.softmax(attn_logits, dim=-1)
        values = torch.matmul(attention, v)
        values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
        values = values.reshape(batch_size, seq_length, self.embedding_dim) # [Batch, SeqLen, EmbeddingDim]
        x_out = self.o_proj(values)

        return x_out


class TemporalGate(nn.Module):
    def __init__(self, gate_frames, pool_types=['avg', 'max'], tabattention=True, tab_dim=6):
        super(TemporalGate, self).__init__()
        self.tabattention = tabattention
        self.tab_dim = tab_dim
        self.gate_frames = gate_frames
        self.pool_types = pool_types
        if tabattention:
            self.pool_types = ['avg', 'max', 'tab']
            self.tab_embedding = nn.Sequential(
                nn.Linear(tab_dim, gate_frames // 2),
                nn.ReLU(),
                nn.Linear(gate_frames // 2, gate_frames)
            )
        if tabattention:
            self.mhsa = TemporalMHSA(input_dim=3, seq_len=self.gate_frames)
        else:
            self.mhsa = TemporalMHSA(input_dim=2, seq_len=self.gate_frames)

    def forward(self, x, tab=None):
        avg_pool = F.avg_pool3d(x, (x.size(2), x.size(3), x.size(4))).reshape(-1, self.gate_frames, 1)
        max_pool = F.max_pool3d(x, (x.size(2), x.size(3), x.size(4))).reshape(-1, self.gate_frames, 1)

        if self.tabattention:
            embedded = self.tab_embedding(tab)
            tab_embedded = torch.reshape(embedded, (-1, self.gate_frames, 1))
            concatenated = torch.cat((avg_pool, max_pool, tab_embedded), dim=2)
        else:
            concatenated = torch.cat((avg_pool, max_pool), dim=2)

        scale = torch.sigmoid(self.mhsa(concatenated)).unsqueeze(2).unsqueeze(3).expand_as(x)

        return x * scale


def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs


class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class SpatialGate(nn.Module):
    def __init__(self, tabattention=True, tab_dim=6, input_size=(8, 8)):
        super(SpatialGate, self).__init__()
        self.tabattention = tabattention
        self.tab_dim = tab_dim
        self.input_size = input_size
        kernel_size = 7
        self.compress = ChannelPool()
        in_planes = 3 if tabattention else 2
        self.spatial = BasicConv(in_planes, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)
        if self.tabattention:
            self.tab_embedding = nn.Sequential(
                nn.Linear(tab_dim, input_size[0] * input_size[1] // 2),
                nn.ReLU(),
                nn.Linear(input_size[0] * input_size[1] // 2, input_size[0] * input_size[1])
            )

    def forward(self, x, tab=None):
        x_compress = self.compress(x)
        if self.tabattention:
            embedded = self.tab_embedding(tab)
            embedded = torch.reshape(embedded, (-1, 1, self.input_size[0], self.input_size[1]))
            x_compress = torch.cat((x_compress, embedded), dim=1)

        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out) # broadcasting
        return x * scale



if __name__ == '__main__':
    x_input = torch.randn(1, 64, 16, 16, 4)
    tab_input = torch.randn(1, 1, 6)
    input_dim = (64, 16, 16, 4)
    block = TabAttention(input_dim=input_dim, tab_dim=6)
    output = block(x_input, tab_input)
    print(x_input.size())    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章