60行代码训练/微调 Segment Anything 2 (SAM 2)

文摘   2024-10-22 07:45   重庆  

点击下方卡片,关注“OpenCV与AI深度学习

视觉/图像重干货,第一时间送达!

    SAM2(Segment Anything 2)是 Meta 推出的一款新模型,旨在分割图像中的任何内容,而不局限于特定的类别或领域。该模型的独特之处在于其训练数据规模:1100 万张图像和 110 亿个掩码。这种广泛的训练使 SAM2 成为训练新图像分割任务的强大起点。

    你可能会问,如果 SAM 可以分割任何东西,为什么我们还需要重新训练它?答案是,SAM 非常擅长处理常见物体,但在罕见或特定领域的任务上表现不佳。
    然而,即使在 SAM 给出的结果不足的情况下,仍然可以通过对新数据进行微调来显著提高模型的能力。在许多情况下,这将比从头开始训练模型需要更少的训练数据并获得更好的结果。
    本教程演示了如何仅用 60 行代码(不包括注释和导入)对新数据进行 SAM2 微调。
    完整的训练脚本可以在以下位置找到:
https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code/blob/main/TRAIN.py?source=post_page-----928dd29a63b3--------------------------------

任何内容的细分工作原理
    SAM 的主要工作方式是获取图像和图像中的点,然后预测包含该点的片段的掩码。这种方法无需人工干预即可实现完整的图像分割,并且对片段的类别或类型没有任何限制。
    使用SAM进行全图像分割的流程:
    1. 选择图像中的一组点
    2. 使用 SAM 预测包含每个点的线段
    3, 将结果片段组合成一张地图
    虽然 SAM 也可以利用其他输入,例如蒙版或边界框,但这些输入主要与涉及人工输入的交互式分割相关。在本教程中,我们将重点介绍全自动分割,并且仅考虑单点输入。
下载SAM2并设置环境
    SAM2 可以从以下位置下载:
https://github.com/facebookresearch/segment-anything-2?source=post_page-----928dd29a63b3--------------------------------
    如果您不想复制训练代码,您也可以下载我分叉的版本,其中已经包含TRAIN.py脚本。
https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code?source=post_page-----928dd29a63b3--------------------------------
    按照 github 存储库上的安装说明进行操作。
    一般来说,你需要 Python >=3.11 和PyTorch。
    此外,我们将使用 OpenCV,可以使用以下命令安装:
pip install opencv-python
下载预先训练的模型
    您还需要从以下位置下载预先训练的模型:
https://github.com/facebookresearch/segment-anything-2?tab=readme-ov-file#download-checkpoints
    有几种模型可供您选择,所有模型均与本教程兼容。我建议使用训练速度最快的小型模型。
下载训练数据
    在本教程中,我们将使用LabPics1 数据集来分割材料和液体。您可以从此 URL 下载数据集:
https://zenodo.org/records/3697452/files/LabPicsV1.zip?download=1
准备数据读取器
    我们需要编写的第一件事是数据读取器。它将读取并准备网络的数据。
    数据读取者需要生成:
    1. 一张图片
    2. 图像中所有片段的蒙版。
    3. 每个蒙版内的一个随机点
    让我们开始加载依赖项:
import numpy as npimport torchimport cv2import osfrom sam2.build_sam import build_sam2from sam2.sam2_image_predictor import SAM2ImagePredictor
    接下来我们列出数据集中的所有图像:
data_dir=r"LabPicsV1//" # Path to LabPics1 dataset folderdata=[] # list of files in datasetfor ff, name in enumerate(os.listdir(data_dir+"Simple/Train/Image/")):  # go over all folder annotation    data.append({"image":data_dir+"Simple/Train/Image/"+name,"annotation":data_dir+"Simple/Train/Instance/"+name[:-4]+".png"})
    现在介绍加载训练批次的主要函数。训练批次包括:一张随机图像、属于该图像的所有分割掩码以及每个掩码中的一个随机点:
def read_batch(data): # read random image and its annotaion from  the dataset (LabPics)
# select image
ent = data[np.random.randint(len(data))] # choose random entry Img = cv2.imread(ent["image"])[...,::-1] # read image ann_map = cv2.imread(ent["annotation"]) # read annotation
# resize image
r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]]) # scalling factor Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r))) ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),interpolation=cv2.INTER_NEAREST)
# merge vessels and materials annotations
mat_map = ann_map[:,:,0] # material annotation map ves_map = ann_map[:,:,2] # vessel annotaion map mat_map[mat_map==0] = ves_map[mat_map==0]*(mat_map.max()+1) # merged map
# Get binary masks and points
inds = np.unique(mat_map)[1:] # load all indices points= [] masks = [] for ind in inds: mask=(mat_map == ind).astype(np.uint8) # make binary mask masks.append(mask) coords = np.argwhere(mask > 0) # get all coordinates in mask yx = np.array(coords[np.random.randint(len(coords))]) # choose random point/coordinate points.append([[yx[1], yx[0]]]) return Img,np.array(masks),np.array(points), np.ones([len(masks),1])
    该函数的第一部分是选择一个随机图像并加载它:
ent  = data[np.random.randint(len(data))] # choose random entryImg = cv2.imread(ent["image"])[...,::-1]  # read imageann_map = cv2.imread(ent["annotation"]) # read annotation
    请注意,OpenCV 将图像读取为 BGR,而 SAM 需要 RGB 图像。通过使用[…,::-1],我们将图像从 BGR 更改为 RGB。
    SAM 预计图像大小不超过 1024,因此我们将把图像和注释图的大小调整为此大小。
r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]]) # scalling factorImg = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),interpolation=cv2.INTER_NEAREST)

    这里的一个重点是,在调整注释地图 ( ann_map ) 的大小时,我们使用INTER_NEAREST模式(最近邻)。在注释地图中,每个像素值都是其所属段的索引。因此,使用不会向地图引入新值的调整大小方法非常重要。

    下一个块特定于 LabPics1 数据集的格式。注释图 ( ann_map ) 包含一个通道中图像中血管的分割图,以及另一个通道中材料注释的图。我们将把它们合并为一张图。

mat_map = ann_map[:,:,0] # material annotation mapves_map = ann_map[:,:,2] # vessel  annotaion mapmat_map[mat_map==0] = ves_map[mat_map==0]*(mat_map.max()+1) # merged map
    这为我们提供了一个映射 ( mat_map ),其中每个像素的值是其所属段的索引(例如:所有值为 3 的单元格都属于段 3)。我们希望将其转换为一组二进制掩码 (0/1),其中每个掩码对应不同的段。此外,我们希望从每个掩码中提取一个点。
inds = np.unique(mat_map)[1:] # list of all indices in mappoints= [] # list of all points (one for each mask)masks = [] # list of all masksfor ind in inds:            mask = (mat_map == ind).astype(np.uint8) # make binary mask for index ind            masks.append(mask)            coords = np.argwhere(mask > 0) # get all coordinates in mask            yx = np.array(coords[np.random.randint(len(coords))]) # choose random point/coordinate            points.append([[yx[1], yx[0]]])return Img,np.array(masks),np.array(points), np.ones([len(masks),1])
    我们得到了图像(Img)、与图像中的片段相对应的二进制掩码列表(mask),以及每个掩码内单个点的坐标(points)。
加载 SAM 模型
    现在让我们加载网络:
sam2_checkpoint = "sam2_hiera_small.pt" # path to model weightmodel_cfg = "sam2_hiera_s.yaml" # model configsam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") # load modelpredictor = SAM2ImagePredictor(sam2_model) # load net
    首先,我们在sam2_checkpoint参数中设置模型权重的路径。我们之前从这里 下载了权重。“sam2_hiera_small.pt”指的是小模型,但代码适用于任何模型。无论你选择哪种模型,都需要在model_cfg参数中设置相应的配置文件。配置文件位于主存储库的子文件夹“ sam2_configs/”中。
细分任何事物 总体结构
    在开始训练之前,我们需要了解模型的结构。
    SAM由三部分组成:
    1)图像编码器,2)提示编码器,3)掩码解码器。
    图像编码器负责处理图像并创建图像嵌入。这是最大的组件,训练它需要强大的 GPU。
    提示编码器处理输入提示,在我们的例子中是输入点。
    掩码解码器获取图像编码器和提示编码器的输出并生成最终的分割掩码。
设置训练参数:
    我们可以通过设置来启用掩码解码器和提示编码器的训练:
predictor.model.sam_mask_decoder.train(True) # enable training of mask decoder predictor.model.sam_prompt_encoder.train(True) # enable training of prompt encoder
    您可以使用以下方式启用图像编码器的训练:
“ predictor.model.image_encoder.train(True)”
    这将需要更强大的 GPU,但会为网络提供更多的改进空间。如果您选择训练图像编码器,则必须扫描 SAM2 代码以查找“ no_grad”命令并将其删除。(no_grad会阻止梯度收集,这可以节省内存但会阻止训练)。
    接下来,我们定义标准的adamW优化器:
optimizer=torch.optim.AdamW(params=predictor.model.parameters(),lr=1e-5,weight_decay=4e-5)
    我们还将使用混合精度训练,这是一种更节省内存的训练策略:
scaler = torch.cuda.amp.GradScaler() # set mixed precision
主训练循环
    现在让我们构建主要的训练循环。第一部分是读取和准备数据:
for itr in range(100000):    with torch.cuda.amp.autocast(): # cast to mix precision            image,mask,input_point, input_label = read_batch(data) # load data batch            if mask.shape[0]==0: continue # ignore empty batches            predictor.set_image(image) # apply SAM image encoder to the image

    首先,我们将数据转换为混合精度以实现高效训练:

with torch.cuda.amp.autocast():

    接下来,我们使用之前创建的读取器函数来读取训练数据:

image,mask,input_point, input_label = read_batch(data)

    我们将加载的图像传递给图像编码器(网络的第一部分):

predictor.set_image(image)

    接下来,我们使用网络提示编码器处理输入点:

mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)  sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(unnorm_coords, labels),boxes=None,masks=None,)
    请注意,在此部分我们也可以输入框或掩码,但我们不会使用这些选项。
    现在我们对提示(点)和图像进行了编码,最终我们可以预测分割掩码:
batched_mode = unnorm_coords.shape[0] > 1 # multi mask predictionhigh_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),sparse_prompt_embeddings=sparse_embeddings,dense_prompt_embeddings=dense_embeddings,multimask_output=True,repeat_image=batched_mode,high_res_features=high_res_features,)prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])# Upscale the masks to the original image resolution
    该代码的主要部分是model.sam_mask_decoder,它运行网络的 mask_decoder 部分并生成分割掩码(low_res_masks)及其分数(prd_scores)。
    这些掩码的分辨率低于原始输入图像,并在postprocess_masks 函数中调整为原始输入大小。
    这给了我们网络的最终预测:我们使用的每个输入点的3 个分割掩码( prd_masks )和掩码分数( prd_scores)。prd_masks包含每个输入点的 3 个预测掩码,但我们只会使用每个点的第一个掩码。prd_scores包含网络认为每个掩码有多好(或它在预测中的确定性)的分数。
损失函数
    分割损失
    现在我们有了净预测,我们可以计算损失了。首先,我们计算分割损失,这意味着预测的掩模与地面真实掩模相比有多好。为此,我们使用标准交叉熵损失。
    首先,我们需要使用 sigmoid 函数将预测掩码( prd_mask )从 logit 转换为概率:
prd_mask = torch.sigmoid(prd_masks[:, 0])# Turn logit map to probability map
    接下来我们将真实情况掩码转换为 torch 张量:
prd_mask = torch.sigmoid(prd_masks[:, 0 ]) # 将logit图转为概率图

    最后,我们使用基本事实(gt_mask)和预测概率图(prd_mask )手动计算交叉熵损失(seg_loss ) :

seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001 ) - ( 1 - gt_mask) * torch.log(( 1 - prd_mask) + 0.00001 )).mean() # 交叉熵损失
(我们添加 0.0001 以防止对数函数因零值而爆炸)。
Score loss(可选)
    除了掩码之外,网络还会预测每个预测掩码的好坏得分。训练这部分不太重要,但很有用。要训练这部分,我们首先需要知道每个预测掩码的真实得分是多少。也就是说,预测掩码实际上有多好。我们将通过使用交并比 (IOU) 指标比较 GT 掩码和相应的预测掩码来做到这一点。IOU 只是两个掩码之间的重叠部分除以两个掩码的总面积。首先,我们计算预测掩码和 GT 掩码之间的交集(它们重叠的面积):
inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
    我们使用阈值(prd_mask > 0.5)将预测掩码从概率转换为二元掩码。
    接下来,我们通过将交点除以预测和 gt 掩码的组合面积(并集)来获得 IOU:
iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)

    我们将使用 IOU 作为每个 mask 的真实分数,并将预测分数与我们刚刚计算的 IOU 之间的绝对差作为分数损失。

score_loss = torch.abs ( prd_scores[:, 0 ] - iou).mean()

    最后,我们合并分割损失和分数损失(给予第一个更高的权重):

loss = seg_loss+score_loss* 0.05   # 混合损失
最后一步:反向传播和保存模型
    一旦我们得到了损失,一切都完全标准化了。我们使用之前制作的优化器计算反向传播并更新权重:
predictor.model.zero_grad() # 空梯度scaler.scale(loss).backward()   # 反向传播scaler.step(optimizer) scaler.update() # 混合精度

    我们还希望每 1000 步保存一次训练好的模型:

if itr%1000==0: torch.save(predictor.model.state_dict(), "model.torch") # save model

    由于我们已经计算了 IOU,我们可以将其显示为移动平均值,以查看模型预测随时间的改善程度:

if itr==0: mean_iou=0mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())print("step)",itr, "Accuracy(IOU)=",mean_iou)
    就这样,我们用不到 60 行代码(不包括注释和导入)训练/微调了 Segment-Anything 2。经过大约 25,000 步后,您应该会看到显著的改进。
    该模型将保存到“model.torch”。
    你可以在此处找到完整的训练代码:
https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code/blob/main/TRAIN.py?source=post_page-----928dd29a63b3--------------------------------
    本教程每批使用单个图像,更有效的方法是每批使用多个不同的图像,执行此操作的代码位于:
https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code/blob/main/TRAIN_multi_image_batch.py
推理:加载并使用训练好的模型:
    现在模型已经微调,让我们用它来分割图像。
    我们将按照以下步骤进行操作:
    加载我们刚刚训练的模型。
    给模型一张图像和一堆随机点。对于每个点,网络将预测包含该点的片段掩码和一个分数。
    将这些蒙版拼接起来形成一张分割图。
    完整代码可从以下位置获取:
https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code/blob/main/TEST_Net.py?source=post_page-----928dd29a63b3--------------------------------

    首先,我们加载依赖项并将权重转换为 float16,这使得模型运行速度更快(仅适用于推理)。

import numpy as np import torch import cv2 from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor 
# 对整个脚本使用 bfloat16(节省内存) torch.autocast(device_type= "cuda" , dtype=torch.bfloat16).__enter__()

    接下来,我们加载示例图像和想要分割的图像区域的蒙版(下载图像/蒙版):

image_path = r"sample_image.jpg" # path to imagemask_path = r"sample_mask.png" # path to mask, the mask will define the image region to segmentdef read_image(image_path, mask_path): # read and resize image and mask        img = cv2.imread(image_path)[...,::-1]  # read image as rgb        mask = cv2.imread(mask_path,0) # mask of the region we want to segment
# Resize image to maximum size of 1024
r = np.min([1024 / img.shape[1], 1024 / img.shape[0]]) img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r))) mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)),interpolation=cv2.INTER_NEAREST) return img, maskimage,mask = read_image(image_path, mask_path)

    在我们要分割的区域内随机取样30个点:

num_samples = 30  # 要采样的点/段的数量def  get_points ( mask, num_points ): # 输入掩码内的采样点 points         =[]         for i in  range (num_points):             coords = np.argwhere(mask > 0 )             yx = np.array(coords[np.random.randint( len (coords))])             points.append([[yx[ 1 ], yx[ 0 ]]])         return np.array(points) input_points = get_points(mask, num_samples)

    加载标准 SAM 模型(与训练中相同)

# 加载模型您需要已经制作好预训练模型sam2_checkpoint = "sam2_hiera_small.pt"  model_cfg = "sam2_hiera_s.yaml"  sam2_model = build_sam2(model_cfg, sam2_checkpoint, device= "cuda" ) predictor = SAM2ImagePredictor(sam2_model)

    接下来,加载我们刚刚训练的模型的权重(model.torch):

predictor.model.load_state_dict(torch.load("model.torch"))
    运行微调模型来预测我们之前选择的每个点的分割掩模:
with torch.no_grad(): # prevent the net from caclulate gradient (more efficient inference)        predictor.set_image(image) # image encoder        masks, scores, logits = predictor.predict(  # prompt encoder + mask decoder            point_coords=input_points,            point_labels=np.ones([input_points.shape[0],1])        )

    现在我们有了预测的掩码及其分数的列表。我们希望以某种方式将它们拼接成一个一致的分割图。但是,许多掩码重叠并且可能彼此不一致。
    拼接方法很简单:
    首先,我们将根据预测分数对预测的掩码进行排序:
masks=masks[:,0].astype(bool)shorted_masks = masks[np.argsort(scores[:,0])][::-1].astype(bool)
    现在让我们创建一个空的分割图和占用图:
seg_map = np.zeros_like(shorted_masks[ 0 ],dtype=np.uint8) occupancy_mask = np.zeros_like(shorted_masks[ 0 ],dtype= bool )

    接下来,我们将蒙版逐一(从高分到低分)添加到分割图中。我们只会在蒙版与之前添加的蒙版一致时才添加蒙版,也就是说,只有当我们要添加的蒙版与已占用区域的重叠度小于 15% 时,我们才会添加蒙版。

for i in range(shorted_masks.shape[0]):    mask = shorted_masks[i]    if (mask*occupancy_mask).sum()/mask.sum()>0.15: continue     mask[occupancy_mask]=0    seg_map[mask]=i+1    occupancy_mask[mask]=1
    seg_mask现在包含预测的分割图,每个分割图有不同的值,背景为 0。
    我们可以使用以下方法将其转换为彩色图:
rgb_image = np.zeros((seg_map.shape[0], seg_map.shape[1], 3), dtype=np.uint8)for id_class in range(1,seg_map.max()+1):    rgb_image[seg_map == id_class] = [np.random.randint(255), np.random.randint(255), np.random.randint(255)]
cv2.imshow("annotation",rgb_image)cv2.imshow("mix",(rgb_image/2+image/2).astype(np.uint8))cv2.imshow("image",image)cv2.waitKey()
    完整的推理代码可在此处获得:
https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code/blob/main/TEST_Net.py?source=post_page-----928dd29a63b3--------------------------------

—THE END—

下载1:Pytorch常用函数手册

在「OpenCV与AI深度学习公众号后台回复:Pytorch函数手册即可下载学习全网第一份Pytorch函数常用手册,包括Tensors介绍、基础函数介绍、数据处理函数、优化函数、CUDA编程、多处理等十四章内容。

下载2:145个OpenCV实例应用代码
在「OpenCV与AI深度学习」公众号后台回复:OpenCV145即可下载学习145个OpenCV实例应用代码(Python和C++双语言实现)。

欢迎加入CV学习交流微信

觉得有用,记得点个赞和在看 

OpenCV与AI深度学习
专注计算机视觉、深度学习和人工智能领域干货、应用、行业资讯的分享交流!
 最新文章