1.感知模块(TrackFormer 和 MapFormer)这部分模块作者直接采用已有的框架,因此,我们不再重点介绍这部分的算法实现!感兴趣的铁子请持续关注端到端自动驾驶!
2.运动预测模块(MotionFormer)
MotionFormer是由三层 Transformer 构成的解码器,每个 Transformer 层有三个并行的交叉注意力模块,分别为对象-对象交互、对象-地图交互,以及对象-目标交互注意力模块。它们分别将上游跟踪模块的查询 ,建图模块的查询 和鸟瞰特征 作为键值,公式如下:
motion_head = dict(
type='MotionHead',
transformerlayers = dict(
# 解码器层,用于处理运动预测
type='MotionTransformerDecoder',
# 意图交互层,处理agent与目标的交互
intention_interaction_layers = IntentionInteraction(),
# agent-agent交互层,用于处理多agent之间的交互,使用多层结构
track_agent_interaction_layers = nn.ModuleList(
[TrackAgentInteraction() for i in range(self.num_layers)]
),
# agent-地图交互层,用于处理agent与地图元素的交互,使用多层结构
map_interaction_layers = nn.ModuleList(
[MapInteraction() for i in range(self.num_layers)]
),
# agent-BEV交互层,用于处理agent与鸟瞰视角特征的交互,使用多层结构
bev_interaction_layers = nn.ModuleList(
[build_transformer_layer(transformerlayers) for i in range(self.num_layers)]
),
# 查询转换的MLP层,用于融合和处理静态与动态特征
static_dynamic_fuser = nn.Sequential(
nn.Linear(self.embed_dims * 2, self.embed_dims * 2),
nn.ReLU(),
nn.Linear(self.embed_dims * 2, self.embed_dims),
),
# 动态嵌入特征融合的MLP层
dynamic_embed_fuser = nn.Sequential(
nn.Linear(self.embed_dims * 3, self.embed_dims * 2),
nn.ReLU(),
nn.Linear(self.embed_dims * 2, self.embed_dims),
),
# 输入查询融合的MLP层
in_query_fuser = nn.Sequential(
nn.Linear(self.embed_dims * 2, self.embed_dims * 2),
nn.ReLU(),
nn.Linear(self.embed_dims * 2, self.embed_dims),
),
# 输出查询融合的MLP层
out_query_fuser = nn.Sequential(
nn.Linear(self.embed_dims * 4, self.embed_dims * 2),
nn.ReLU(),
nn.Linear(self.embed_dims * 2, self.embed_dims),
),
transformerlayers = dict(
# agent-目标交互层,用于处理agent与目标的交互
type='MotionTransformerAttentionLayer',
# 注意力机制的配置,使用可变形注意力机制
attn_cfgs = [
dict(
type='MotionDeformableAttention',
),
],
# 操作顺序:交叉注意力、归一化、前馈网络、归一化
operation_order = ('cross_attn', 'norm', 'ffn', 'norm')
),
),
)
此外,对核心代码进行进一轮的深挖:
class MotionHead(BaseMotionHead):
def forward(self,
bev_embed,
track_query,
lane_query,
lane_query_pos,
track_bbox_results):
"""
该函数执行模型的前向传播,用于基于鸟瞰图(BEV)嵌入、轨迹查询、车道查询和轨迹边界框结果进行运动预测。
参数:
bev_embed (torch.Tensor):形状为 (h*w, B, D) 的张量,表示鸟瞰图嵌入。
track_query (torch.Tensor):形状为 (B, num_dec, A_track, D) 的张量,表示轨迹查询。
lane_query (torch.Tensor):形状为 (N, M_thing, D) 的张量,表示车道查询。
lane_query_pos (torch.Tensor):形状为 (N, M_thing, D) 的张量,表示车道查询的位置。
track_bbox_results (List[torch.Tensor]):包含批次中每个图像的跟踪边界框结果的张量列表。
返回值:
dict:包含以下键和值的字典:
'all_traj_scores':形状为 (num_levels, B, A_track, num_points) 的张量,包含每个级别的轨迹分数。
'all_traj_preds':形状为 (num_levels, B, A_track, num_points, num_future_steps, 2) 的张量,包含每个级别的预测轨迹。
'valid_traj_masks':形状为 (B, A_track) 的张量,指示轨迹掩码的有效性。
'traj_query':包含轨迹查询中间状态的张量。
'track_query':包含输入轨迹查询的张量。
'track_query_pos':包含轨迹查询位置嵌入的张量。
"""
# 构造agent级别/场景级别的查询位置嵌入
# (num_groups, num_anchor, 12, 2)
# 以融入不同组和坐标的信息,并嵌入方向和位置信息
agent_level_anchors = self.kmeans_anchors.to(dtype).to(device).view(num_groups, self.num_anchor, self.predict_steps, 2).detach()
scene_level_ego_anchors = anchor_coordinate_transform(agent_level_anchors, track_bbox_results, with_translation_transform=True) # B, A, G, P ,12 ,2
scene_level_offset_anchors = anchor_coordinate_transform(agent_level_anchors, track_bbox_results, with_translation_transform=False) # B, A, G, P ,12 ,2
# 对锚点进行归一化
agent_level_norm = norm_points(agent_level_anchors, self.pc_range)
scene_level_ego_norm = norm_points(scene_level_ego_anchors, self.pc_range)
scene_level_offset_norm = norm_points(scene_level_offset_anchors, self.pc_range)
# 仅使用锚点的最后一个点
agent_level_embedding = self.agent_level_embedding_layer(
pos2posemb2d(agent_level_norm[..., -1, :])) # G, P, D
scene_level_ego_embedding = self.scene_level_ego_embedding_layer(
pos2posemb2d(scene_level_ego_norm[..., -1, :])) # B, A, G, P , D
scene_level_offset_embedding = self.scene_level_offset_embedding_layer(
pos2posemb2d(scene_level_offset_norm[..., -1, :])) # B, A, G, P , D
outputs_traj_scores = []
outputs_trajs = []
# 通过MotionFormer模型进行前向传播
# 输入各种查询、位置、边界框结果、BEV嵌入、初始参考轨迹等
# 以及锚点嵌入和锚点位置嵌入层
inter_states, inter_references = self.motionformer(
track_query, # B, A_track, D
lane_query, # B, M, D
track_query_pos=track_query_pos,
lane_query_pos=lane_query_pos,
track_bbox_results=track_bbox_results,
bev_embed=bev_embed,
reference_trajs=init_reference,
traj_reg_branches=self.traj_reg_branches,
traj_cls_branches=self.traj_cls_branches,
# 锚点嵌入
agent_level_embedding=agent_level_embedding,
scene_level_ego_embedding=scene_level_ego_embedding,
scene_level_offset_embedding=scene_level_offset_embedding,
learnable_embed=learnable_embed,
# 锚点位置嵌入层
agent_level_embedding_layer=self.agent_level_embedding_layer,
scene_level_ego_embedding_layer=self.scene_level_ego_embedding_layer,
scene_level_offset_embedding_layer=self.scene_level_offset_embedding_layer,
spatial_shapes=torch.tensor(
[[self.bev_h, self.bev_w]], device=device),
level_start_index=torch.tensor([0], device=device))
# 遍历每个级别,计算轨迹分数和轨迹
for lvl in range(inter_states.shape[0]):
outputs_class = self.traj_cls_branches[lvl](inter_states[lvl])
tmp = self.traj_reg_branches[lvl](inter_states[lvl])
tmp = self.unflatten_traj(tmp)
# 使用累积和技巧来获取轨迹
tmp[..., :2] = torch.cumsum(tmp[..., :2], dim=3)
outputs_class = self.log_softmax(outputs_class.squeeze(3))
outputs_traj_scores.append(outputs_class)
# 对每个批次应用双变量高斯激活
for bs in range(tmp.shape[0]):
tmp[bs] = bivariate_gaussian_activation(tmp[bs])
outputs_trajs.append(tmp)
# 堆叠并输出轨迹分数和轨迹
outputs_traj_scores = torch.stack(outputs_traj_scores)
outputs_trajs = torch.stack(outputs_trajs)
# 获取轨迹查询的有效性掩码
B, A_track, D = track_query.shape
valid_traj_masks = track_query.new_ones((B, A_track)) > 0
# 构造输出字典
outs = {
'all_traj_scores': outputs_traj_scores,
'all_traj_preds': outputs_trajs,
'valid_traj_masks': valid_traj_masks,
'traj_query': inter_states,
'track_query': track_query,
'track_query_pos': track_query_pos,
}
return outs
MotionTransformerDecoder:实现了基于轨迹查询、车道查询、位置嵌入和鸟瞰图(BEV)嵌入等输入的运动预测。该函数首先对各种查询和嵌入进行处理,包括静态和动态意图嵌入,然后通过多层的代理间交互、代理与地图间交互、代理与目标(BEV)间交互,逐层融合和更新查询嵌入。最后,将各层次的查询嵌入进行融合,输出中间结果,用于后续的轨迹预测和行为分析。
class MotionTransformerDecoder(BaseModule):
def forward(self,
track_query,
lane_query,
track_query_pos=None,
lane_query_pos=None,
track_bbox_results=None,
bev_embed=None,
reference_trajs=None,
traj_reg_branches=None,
agent_level_embedding=None,
scene_level_ego_embedding=None,
scene_level_offset_embedding=None,
learnable_embed=None,
agent_level_embedding_layer=None,
scene_level_ego_embedding_layer=None,
scene_level_offset_embedding_layer=None,
:
"""
的前向传播函数。
:
track_query (torch.Tensor): 形状为 (B, A, D),表示代理查询,其中 B 为批次大小,A 为代理数,D 为特征维度。
lane_query (torch.Tensor): 形状为 (B, M, D),表示地图查询,其中 M 为地图对象数。
track_query_pos (torch.Tensor, optional): 轨迹查询位置。
lane_query_pos (torch.Tensor, optional): 车道查询位置。
track_bbox_results (List[torch.Tensor], optional): 跟踪边界框结果。
bev_embed (torch.Tensor, optional): 鸟瞰图嵌入。
reference_trajs (torch.Tensor, optional): 参考轨迹。
traj_reg_branches (List[torch.nn.Module], optional): 轨迹回归分支。
agent_level_embedding (torch.Tensor, optional): 代理级别嵌入。
scene_level_ego_embedding (torch.Tensor, optional): 场景级别自我嵌入。
scene_level_offset_embedding (torch.Tensor, optional): 场景级别偏移嵌入。
learnable_embed (torch.Tensor, optional): 可学习嵌入。
agent_level_embedding_layer (torch.nn.Module, optional): 代理级别嵌入层。
scene_level_ego_embedding_layer (torch.nn.Module, optional): 场景级别自我嵌入层。
scene_level_offset_embedding_layer (torch.nn.Module, optional): 场景级别偏移嵌入层。
kwargs: 其他额外参数。
:
None
"""
intermediate = [] # 用于存储中间输出的列表
intermediate_reference_trajs = [] # 用于存储中间参考轨迹的列表
# 对输入进行广播和扩展,以匹配所需的形状
_, P, D = agent_level_embedding.shape
track_query_bc = track_query.unsqueeze(2).expand(-1, -1, P, -1) # (B, A, P, D)
track_query_pos_bc = track_query_pos.unsqueeze(2).expand(-1, -1, P, -1) # (B, A, P, D)
# 计算静态意图嵌入,它在所有层中都是不变的
agent_level_embedding = self.intention_interaction_layers(agent_level_embedding)
static_intention_embed = agent_level_embedding + scene_level_offset_embedding + learnable_embed
reference_trajs_input = reference_trajs.unsqueeze(4).detach()
# 初始化查询嵌入,其形状与静态意图嵌入相同
query_embed = torch.zeros_like(static_intention_embed)
for lid in range(self.num_layers):
# 融合动态意图嵌入
# 动态意图嵌入是前一层的输出,初始化为锚点嵌入(anchor embedding)
dynamic_query_embed = self.dynamic_embed_fuser(torch.cat(
scene_level_offset_embedding, scene_level_ego_embedding], dim=-1))
# 融合静态和动态意图嵌入
query_embed_intention = self.static_dynamic_fuser(torch.cat(
dynamic_query_embed], dim=-1)) # (B, A, P, D)
# 将意图嵌入与查询嵌入融合
query_embed = self.in_query_fuser(torch.cat([query_embed, query_embed_intention], dim=-1))
# 代理之间的交互
track_query_embed = self.track_agent_interaction_layers[lid](
track_query, query_pos=track_query_pos_bc, key_pos=track_query_pos)
# 代理与地图之间的交互
map_query_embed = self.map_interaction_layers[lid](
lane_query, query_pos=track_query_pos_bc, key_pos=lane_query_pos)
# 代理与目标(BEV,即鸟瞰图)之间的交互,使用可变形Transformer实现
bev_query_embed = self.bev_interaction_layers[lid](
query_embed,
value=bev_embed,
query_pos=track_query_pos_bc,
bbox_results=track_bbox_results,
reference_trajs=reference_trajs_input,
**kwargs)
# 融合来自不同交互层的嵌入
query_embed = [track_query_embed, map_query_embed, bev_query_embed, track_query_bc+track_query_pos_bc]
query_embed = torch.cat(query_embed, dim=-1)
query_embed = self.out_query_fuser(query_embed)
3. 占用预测模块(OccFormer)
class OccHead(BaseModule):
def forward(self, x, ins_query):
# 重新排列输入特征图以匹配预期的形状
base_state = rearrange(x, '(h w) b d -> b d h w', h=self.bev_size[0])
# 对特征图进行采样、投影和下采样处理
base_state = self.bev_sampler(base_state)
base_state = self.bev_light_proj(base_state)
base_state = self.base_downscale(base_state)
# 初始化查询和状态变量
last_state = base_state
last_ins_query = ins_query
future_states = []
mask_preds = []
temporal_query = []
temporal_embed_for_mask_attn = []
# 确定每个块的 Transformer 层数
n_trans_layer_each_block = self.num_trans_layers // self.n_future_blocks
assert n_trans_layer_each_block >= 1
# 遍历未来的块
for i in range(self.n_future_blocks):
# 下采样当前状态
cur_state = self.downscale_convs[i](last_state)
# 处理时间感知查询
cur_ins_query = self.temporal_mlps[i](last_ins_query)
temporal_query.append(cur_ins_query)
# 生成注意力掩码和掩码预测
mask_pred, cur_ins_emb_for_mask_attn = self.get_attn_mask(cur_state, cur_ins_query)
attn_masks = [None, attn_mask]
mask_preds.append(mask_pred)
temporal_embed_for_mask_attn.append(cur_ins_emb_for_mask_attn)
# 重新排列状态和查询以适应 Transformer 输入
cur_state = rearrange(cur_state, 'b c h w -> (h w) b c')
cur_ins_query = rearrange(cur_ins_query, 'b q c -> q b c')
# 遍历每层 Transformer
for j in range(n_trans_layer_each_block):
trans_layer_ind = i * n_trans_layer_each_block + j
trans_layer = self.transformer_decoder.layers[trans_layer_ind]
cur_state = trans_layer(
query=cur_state,
key=cur_ins_query,
value=cur_ins_query,
query_pos=None,
key_pos=None,
attn_masks=attn_masks,
query_key_padding_mask=None,
key_padding_mask=None
)
# 重新排列状态并进行上采样
cur_state = rearrange(cur_state, '(h w) b c -> b c h w', h=self.bev_size[0] // 8)
cur_state = self.upsample_adds[i](cur_state, last_state)
future_states.append(cur_state)
last_state = cur_state
# 堆叠未来的状态、时间查询、掩码预测和查询嵌入
future_states = torch.stack(future_states, dim=1)
temporal_query = torch.stack(temporal_query, dim=1)
mask_preds = torch.stack(mask_preds, dim=2)
ins_query = torch.stack(temporal_embed_for_mask_attn, dim=1)
# 将未来状态解码到更大的分辨率
future_states = self.dense_decoder(future_states)
ins_occ_query = self.query_to_occ_feat(ins_query)
# 生成最终输出
ins_occ_logits = torch.einsum("btqc,btchw->bqthw", ins_occ_query, future_states)
return mask_preds, ins_occ_logits
4. 规划模块(Planner)
首先,Planner 模块将输入的轨迹查询、车道查询等通过多头自注意力机制 (MHSA) 和交叉注意力机制 (MHCA) 进行处理, 公式如下:
class PlanningHeadSingleMode(nn.Module):
def forward(self,
bev_embed, # BEV(鸟瞰图)特征嵌入
occ_mask, # 占用实例掩码
bev_pos, # BEV位置
sdc_traj_query, # SDC轨迹查询
sdc_track_query, # SDC轨迹追踪查询
command): # 驾驶命令
"""
前向传播过程。
参数:
bev_embed (torch.Tensor): 鸟瞰图特征嵌入。
occ_mask (torch.Tensor): 占用实例掩码。
bev_pos (torch.Tensor): BEV位置。
sdc_traj_query (torch.Tensor): SDC轨迹查询。
sdc_track_query (torch.Tensor): SDC轨迹追踪查询。
command (int): 驾驶命令。
返回:
dict: 包含SDC轨迹和所有SDC轨迹的字典。
"""
# 根据驾驶命令获取导航嵌入
navi_embed = self.navi_embed.weight[command]
navi_embed = navi_embed[None].expand(-1, P, -1)
# 融合SDC轨迹查询、SDC轨迹追踪查询和导航嵌入
plan_query = torch.cat([sdc_traj_query, sdc_track_query, navi_embed], dim=-1)
# 使用多层感知机(MLP)融合查询,并取最大值
plan_query = self.mlp_fuser(plan_query).max(1, keepdim=True)[0]
# 重排plan_query的形状
plan_query = rearrange(plan_query, 'b p c -> p b c')
# 重排bev_pos的形状
bev_pos = rearrange(bev_pos, 'b c h w -> (h w) b c')
bev_feat = bev_embed + bev_pos
# 插件适配器
if self.with_adapter:
bev_feat = rearrange(bev_feat, '(h w) b c -> b c h w', h=self.bev_h, w=self.bev_w)
bev_feat = bev_feat + self.bev_adapter(bev_feat) # 残差连接
bev_feat = rearrange(bev_feat, 'b c h w -> (h w) b c')
# 添加位置嵌入
pos_embed = self.pos_embed.weight
plan_query = plan_query + pos_embed[None]
# 使用注意力模块处理plan_query和bev_feat
plan_query = self.attn_module(plan_query, bev_feat)
# 回归分支,生成SDC轨迹
sdc_traj_all = self.reg_branch(plan_query).view((-1, self.planning_steps, 2))
# 累计求和,生成轨迹点
sdc_traj_all[...,:2] = torch.cumsum(sdc_traj_all[...,:2], dim=1)
# 对第一条轨迹应用双变量高斯激活
sdc_traj_all[0] = bivariate_gaussian_activation(sdc_traj_all[0])
# 如果使用碰撞优化且非训练模式,进行后处理
if self.use_col_optim and not self.training:
assert occ_mask is not None
sdc_traj_all = self.collision_optimization(sdc_traj_all, occ_mask)
# 返回SDC轨迹和所有SDC轨迹
return dict(
sdc_traj=sdc_traj_all,
sdc_traj_all=sdc_traj_all,
)
结合两期内容,我们重温了UniAD作为端到端自动驾驶的开山级研究。第一期进行UniAD模型的框架,数据流,训练以及实验部分进行方法论层面的解读,第二期对具体策略进一步深挖,同步对源码进行逐行解读辅助理解!本篇论文完结撒花~