SAM 2 开源:视频上也能“分割一切物体”

文摘   科技   2024-07-30 11:35   上海  



思源Source报道
编辑:seefun
SAM 2: Segment Anything in Images and Videos

Github: https://github.com/facebookresearch/segment-anything-2

Demo: https://sam2.metademolab.com/

Data: https://ai.meta.com/datasets/segment-anything-video/


简介

时隔一年半,Meta Research推出了Segment Anything Model的重磅更新!把图像上的SAM扩展到视频上(时空维度的交互实例分割)。

Segment Anything Model 2(SAM 2)是用于求解图像和视频中快速视觉分割的基础模型。通过将图像视为具有单一帧的视频,将SAM扩展到视频。模型设计是一种简单的transformer体系结构,具有streaming memory,用于实时视频处理。并且,同样使用数据飞轮,制作了SA-V数据集,这是迄今为止最大的视频分割数据集。其中包括大约51000个真实世界的视频和超过600000个mask(时空mask)

方法

SAM 2架构是SAM在视频领域的扩展,它通过点击、边界框或掩码提示来定义对象范围,并使用轻量级掩码解码器结合图像嵌入和编码提示输出分割掩码。SAM 2利用记忆机制,包括记忆编码器、记忆库和记忆注意力模块,来存储会话中对象和用户交互信息,实现在视频帧中生成和细化masklet预测。该架构支持实时处理长视频,适用于视频注释和实际应用如机器人领域。同时,SAM 2能够处理图像分割中的模糊性,通过创建多个有效掩码并选择最高置信度的掩码进行传播,以应对视频中的不确定性。

图像编码器(Image Encoder):

使用特征金字塔网络(Feature Pyramid Network)来融合来自Hiera图像编码器Stage 3和Stage 4的stride 16和32特征,生成每帧的图像嵌入。

来自Stage 1和Stage 2的stride 4和8特征不用于记忆注意力,而是添加到掩码解码器的上采样层中,以帮助产生高分辨率的分割细节。(类似HQ-SAM的思路)

采用窗口化的绝对位置嵌入(Windowed Absolute Positional Embeddings),并采用简单的全局位置嵌入插值方法代替RPB(Relative Positional Bias)。

记忆注意力(Memory Attention):

在自注意力和交叉注意力层中,还使用了苏神的旋转位置编码(RoPE)

对象指针标记(Object Pointer Tokens)不包含RoPE,因为它们没有特定的空间对应关系。

默认情况下,记忆注意力使用4层图像嵌入。

提示编码器和掩码解码器(Prompt Encoder and Mask Decoder):

提示编码器的设计遵循SAM,掩码解码器的设计有所变化。

使用与输出掩码对应的掩码标记作为帧的对象指针标记,并将其放置在记忆库中。

引入了遮挡预测头(Occlusion Prediction Head),通过在掩码和IoU输出标记之外增加一个额外的标记,并应用一个额外的MLP头来生成一个分数,指示当前帧中感兴趣对象的可见性。

记忆编码器和记忆库(Memory Encoder and Memory Bank):

记忆编码器不使用额外的图像编码器,而是重用Hiera编码器生成的图像嵌入,并与预测的掩码信息融合以产生记忆特征。

将记忆特征在记忆库中投影到64维,并把256维的对象指针分割成4个64维的标记,用于与记忆库的交叉注意力。

处理模糊性(Handling Ambiguity):

当图像中被分割的对象存在模糊性时,SAM 2会在视频中的每个步骤预测多个掩码。

如果进一步的提示不能解决模糊性,模型会选择当前帧预测IoU最高的掩码进行进一步的视频传播。

SAM 2 在 17 个zero-shot视频数据集的交互式视频分割方面表现明显优于以前的方法,并且所需的人机交互减少了大约三倍。

SAM 2 在 23 个数据集zero-shot基准测试套件上的表现优于 SAM,而且SAM2比SAM速度快了六倍

数据飞轮构建:

使用

安装

git clone git@github.com:facebookresearch/segment-anything-2.gitcd segment-anything-2; pip install -e .

图像上推理:

import torchfrom sam2.build_sam import build_sam2from sam2.sam2_image_predictor import SAM2ImagePredictor
checkpoint = "./checkpoints/sam2_hiera_large.pt"model_cfg = "sam2_hiera_l.yaml"predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): predictor.set_image(<your_image>) masks, _, _ = predictor.predict(<input_prompts>)

视频上推理:

import torchfrom sam2.build_sam import build_sam2_video_predictor
checkpoint = "./checkpoints/sam2_hiera_large.pt"model_cfg = "sam2_hiera_l.yaml"predictor = build_sam2_video_predictor(model_cfg, checkpoint)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): state = predictor.init_state(<your_video>)
# add new prompts and instantly get the output on the same frame frame_idx, object_ids, masks = predictor.add_new_points(state, <your prompts>):
# propagate the prompts to get masklets throughout the video for frame_idx, object_ids, masks in predictor.propagate_in_video(state): ...


点击👇关注 “思源Source”

👇点个“赞”和“在看”吧

思源数据科学
Towards AGI
 最新文章