SAM 2: Segment Anything in Images and Videos
Github: https://github.com/facebookresearch/segment-anything-2
Demo: https://sam2.metademolab.com/
Data: https://ai.meta.com/datasets/segment-anything-video/
简介
时隔一年半,Meta Research推出了Segment Anything Model的重磅更新!把图像上的SAM扩展到视频上(时空维度的交互实例分割)。
Segment Anything Model 2(SAM 2)是用于求解图像和视频中快速视觉分割的基础模型。通过将图像视为具有单一帧的视频,将SAM扩展到视频。模型设计是一种简单的transformer体系结构,具有streaming memory,用于实时视频处理。并且,同样使用数据飞轮,制作了SA-V数据集,这是迄今为止最大的视频分割数据集。其中包括大约51000个真实世界的视频和超过600000个mask(时空mask)
方法
SAM 2架构是SAM在视频领域的扩展,它通过点击、边界框或掩码提示来定义对象范围,并使用轻量级掩码解码器结合图像嵌入和编码提示输出分割掩码。SAM 2利用记忆机制,包括记忆编码器、记忆库和记忆注意力模块,来存储会话中对象和用户交互信息,实现在视频帧中生成和细化masklet预测。该架构支持实时处理长视频,适用于视频注释和实际应用如机器人领域。同时,SAM 2能够处理图像分割中的模糊性,通过创建多个有效掩码并选择最高置信度的掩码进行传播,以应对视频中的不确定性。
图像编码器(Image Encoder):
使用特征金字塔网络(Feature Pyramid Network)来融合来自Hiera图像编码器Stage 3和Stage 4的stride 16和32特征,生成每帧的图像嵌入。
来自Stage 1和Stage 2的stride 4和8特征不用于记忆注意力,而是添加到掩码解码器的上采样层中,以帮助产生高分辨率的分割细节。(类似HQ-SAM的思路)
采用窗口化的绝对位置嵌入(Windowed Absolute Positional Embeddings),并采用简单的全局位置嵌入插值方法代替RPB(Relative Positional Bias)。
记忆注意力(Memory Attention):
在自注意力和交叉注意力层中,还使用了苏神的旋转位置编码(RoPE)。
对象指针标记(Object Pointer Tokens)不包含RoPE,因为它们没有特定的空间对应关系。
默认情况下,记忆注意力使用4层图像嵌入。
提示编码器和掩码解码器(Prompt Encoder and Mask Decoder):
提示编码器的设计遵循SAM,掩码解码器的设计有所变化。
使用与输出掩码对应的掩码标记作为帧的对象指针标记,并将其放置在记忆库中。
引入了遮挡预测头(Occlusion Prediction Head),通过在掩码和IoU输出标记之外增加一个额外的标记,并应用一个额外的MLP头来生成一个分数,指示当前帧中感兴趣对象的可见性。
记忆编码器和记忆库(Memory Encoder and Memory Bank):
记忆编码器不使用额外的图像编码器,而是重用Hiera编码器生成的图像嵌入,并与预测的掩码信息融合以产生记忆特征。
将记忆特征在记忆库中投影到64维,并把256维的对象指针分割成4个64维的标记,用于与记忆库的交叉注意力。
处理模糊性(Handling Ambiguity):
当图像中被分割的对象存在模糊性时,SAM 2会在视频中的每个步骤预测多个掩码。
如果进一步的提示不能解决模糊性,模型会选择当前帧预测IoU最高的掩码进行进一步的视频传播。
SAM 2 在 17 个zero-shot视频数据集的交互式视频分割方面表现明显优于以前的方法,并且所需的人机交互减少了大约三倍。
SAM 2 在 23 个数据集zero-shot基准测试套件上的表现优于 SAM,而且SAM2比SAM速度快了六倍。
数据飞轮构建:
使用
安装
git clone git@github.com:facebookresearch/segment-anything-2.git
cd segment-anything-2; pip install -e .
图像上推理:
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
predictor.set_image(<your_image>)
masks, _, _ = predictor.predict(<input_prompts>)
视频上推理:
import torch
from sam2.build_sam import build_sam2_video_predictor
checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = build_sam2_video_predictor(model_cfg, checkpoint)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
state = predictor.init_state(<your_video>)
# add new prompts and instantly get the output on the same frame
frame_idx, object_ids, masks = predictor.add_new_points(state, <your prompts>):
# propagate the prompts to get masklets throughout the video
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
...
点击👇关注 “思源Source”
👇点个“赞”和“在看”吧