Full training script: https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code/blob/main/TRAIN.py
SAM2 repository: https://github.com/facebookresearch/segment-anything-2
Fine-tuning code repository: https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code
pip install opencv-python
Model checkpoints: https://github.com/facebookresearch/segment-anything-2?tab=readme-ov-file#download-checkpoints
LabPicsV1 dataset: https://zenodo.org/records/3697452/files/LabPicsV1.zip?download=1
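If you prefer to script the download, here is a minimal sketch using only the Python standard library (assuming the Zenodo URL above; the archive is large, so this takes a while):
import urllib.request
import zipfile

url = "https://zenodo.org/records/3697452/files/LabPicsV1.zip?download=1"
urllib.request.urlretrieve(url, "LabPicsV1.zip")  # large download
with zipfile.ZipFile("LabPicsV1.zip") as zf:
    zf.extractall(".")  # should produce the LabPicsV1/ folder used below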
import numpy as np
import torch
import cv2
import os
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
data_dir=r"LabPicsV1//" # Path to LabPics1 dataset folder
data=[] # list of files in dataset
for ff, name in enumerate(os.listdir(data_dir+"Simple/Train/Image/")): # go over all folder annotation
data.append({"image":data_dir+"Simple/Train/Image/"+name,"annotation":data_dir+"Simple/Train/Instance/"+name[:-4]+".png"})
def read_batch(data):  # read a random image and its annotation from the dataset (LabPics)
    # select a random entry
    ent = data[np.random.randint(len(data))]
    Img = cv2.imread(ent["image"])[..., ::-1]  # read image as RGB
    ann_map = cv2.imread(ent["annotation"])  # read annotation

    # resize image and annotation so the longer side is at most 1024
    r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]])  # scaling factor
    Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
    ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST)

    # merge vessel and material annotations into one map
    mat_map = ann_map[:, :, 0]  # material annotation map
    ves_map = ann_map[:, :, 2]  # vessel annotation map
    mat_map[mat_map == 0] = ves_map[mat_map == 0] * (mat_map.max() + 1)  # merged map

    # get binary masks and one prompt point per mask
    inds = np.unique(mat_map)[1:]  # all segment indices (skip background 0)
    points = []
    masks = []
    for ind in inds:
        mask = (mat_map == ind).astype(np.uint8)  # binary mask for index ind
        masks.append(mask)
        coords = np.argwhere(mask > 0)  # all coordinates inside the mask
        yx = np.array(coords[np.random.randint(len(coords))])  # choose a random point
        points.append([[yx[1], yx[0]]])  # store as (x, y)
    return Img, np.array(masks), np.array(points), np.ones([len(masks), 1])
Let us walk through this function. The first part picks a random entry, loads the image and its annotation map, and resizes both so the longer side is at most 1024 pixels:
ent = data[np.random.randint(len(data))]  # choose a random entry
Img = cv2.imread(ent["image"])[..., ::-1]  # read image as RGB
ann_map = cv2.imread(ent["annotation"])  # read annotation
r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]])  # scaling factor
Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST)
One important point here: when resizing the annotation map (ann_map), we use INTER_NEAREST (nearest-neighbor) interpolation. Each pixel value in the annotation map is the index of the segment it belongs to, so it is essential to use a resizing method that does not introduce new values into the map.
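A quick illustration of why this matters (toy label values, not from LabPics): bilinear resizing blends neighboring labels into new, meaningless indices, while nearest neighbor preserves the original set of values:
import cv2
import numpy as np

labels = np.array([[0, 0, 5, 5],
                   [0, 0, 5, 5],
                   [9, 9, 5, 5],
                   [9, 9, 5, 5]], dtype=np.uint8)  # toy segment-index map
print(np.unique(cv2.resize(labels, (8, 8), interpolation=cv2.INTER_LINEAR)))   # blended in-between values appear
print(np.unique(cv2.resize(labels, (8, 8), interpolation=cv2.INTER_NEAREST)))  # still exactly [0 5 9]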
The next block is specific to the format of the LabPics1 dataset. The annotation map (ann_map) contains the segmentation map of the vessels in one channel and the material annotations in another. We merge them into a single map, offsetting the vessel indices past the largest material index so the two sets of segments cannot collide:
mat_map = ann_map[:, :, 0]  # material annotation map
ves_map = ann_map[:, :, 2]  # vessel annotation map
mat_map[mat_map == 0] = ves_map[mat_map == 0] * (mat_map.max() + 1)  # merged map
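To see what the offset does, here is a toy example (illustrative values only):
import numpy as np

mat_map = np.array([[0, 1],
                    [2, 0]])  # materials use indices 1, 2
ves_map = np.array([[1, 1],
                    [1, 2]])  # vessels also use indices 1, 2
mat_map[mat_map == 0] = ves_map[mat_map == 0] * (mat_map.max() + 1)
print(mat_map)  # [[3 1] [2 6]]: vessel segments become 3 and 6, clear of the material indices 1 and 2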
Finally, we turn the merged map into a list of binary masks and sample one prompt point inside each mask:
inds = np.unique(mat_map)[1:]  # list of all segment indices in the map
points = []  # list of prompt points (one per mask)
masks = []  # list of binary masks
for ind in inds:
    mask = (mat_map == ind).astype(np.uint8)  # binary mask for index ind
    masks.append(mask)
    coords = np.argwhere(mask > 0)  # all coordinates inside the mask
    yx = np.array(coords[np.random.randint(len(coords))])  # choose a random point
    points.append([[yx[1], yx[0]]])  # store as (x, y)
return Img, np.array(masks), np.array(points), np.ones([len(masks), 1])
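A quick sanity check (illustrative) to confirm what the reader returns:
Img, masks, points, labels = read_batch(data)
print(Img.shape)     # (H, W, 3): RGB image
print(masks.shape)   # (N, H, W): one binary mask per segment
print(points.shape)  # (N, 1, 2): one (x, y) prompt point per mask
print(labels.shape)  # (N, 1): all ones, i.e. every point is a positive prompt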
Next, we load the SAM2 model and enable training mode for the mask decoder and prompt encoder (the image encoder can also be trained, but it requires much more GPU memory):
sam2_checkpoint = "sam2_hiera_small.pt"  # path to model weights
model_cfg = "sam2_hiera_s.yaml"  # model config
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")  # load model
predictor = SAM2ImagePredictor(sam2_model)  # wrap the model with the predictor interface
predictor.model.sam_mask_decoder.train(True)  # enable training of the mask decoder
predictor.model.sam_prompt_encoder.train(True)  # enable training of the prompt encoder
We then define the optimizer and a gradient scaler for mixed-precision training:
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
scaler = torch.cuda.amp.GradScaler()  # gradient scaler for mixed precision
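Note that the optimizer above receives all model parameters, including the image encoder. As a hedged variant (an assumption, not the author's script), you can hand the optimizer only the two parts we actually set to train mode, leaving the image encoder untouched:
trainable = list(predictor.model.sam_mask_decoder.parameters()) \
          + list(predictor.model.sam_prompt_encoder.parameters())
optimizer = torch.optim.AdamW(params=trainable, lr=1e-5, weight_decay=4e-5)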
Now we can build the main training loop. Each iteration reads one image, casts the forward pass to mixed precision, and runs the image encoder:
for itr in range(100000):
    with torch.cuda.amp.autocast():  # cast to mixed precision
        image, mask, input_point, input_label = read_batch(data)  # load data batch
        if mask.shape[0] == 0: continue  # ignore empty batches
        predictor.set_image(image)  # apply the SAM image encoder to the image
First, we cast the data to mixed precision for efficient training:
with torch.cuda.amp.autocast():
Next, we use the reader function we created earlier to fetch training data:
image,mask,input_point, input_label = read_batch(data)
We pass the loaded image through the image encoder (the first part of the network):
predictor.set_image(image)
Next, we process the input points with the network's prompt encoder:
mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
    input_point, input_label, box=None, mask_logits=None, normalize_coords=True)
sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
    points=(unnorm_coords, labels), boxes=None, masks=None)
We then run the mask decoder on the image embeddings and prompt embeddings to predict both the segmentation masks and their quality scores:
batched_mode = unnorm_coords.shape[0] > 1  # multi-mask prediction
high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
    image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
    image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
    sparse_prompt_embeddings=sparse_embeddings,
    dense_prompt_embeddings=dense_embeddings,
    multimask_output=True,
    repeat_image=batched_mode,
    high_res_features=high_res_features,
)
prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])  # upscale the masks to the original image resolution
prd_mask = torch.sigmoid(prd_masks[:, 0])  # turn the logit maps into probability maps
Finally, we compute the cross-entropy loss (seg_loss) manually from the ground truth (gt_mask) and the predicted probability map (prd_mask). Here gt_mask is the mask batch from the reader, moved to the GPU as a float tensor:
gt_mask = torch.tensor(mask.astype(np.float32)).cuda()  # ground-truth masks as float tensors on the GPU
seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean()  # cross-entropy loss
Next, we compute the IOU between each predicted mask (thresholded at 0.5) and its ground-truth mask:
inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
We use this IOU as the ground-truth quality score for each mask, and take the absolute difference between the network's predicted score and the IOU we just computed as the score loss:
score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
Finally, we combine the segmentation loss and the score loss (giving the former a much higher weight):
loss = seg_loss + score_loss * 0.05  # combined loss
predictor.model.zero_grad()  # zero the gradients
scaler.scale(loss).backward()  # backpropagate
scaler.step(optimizer)  # optimizer step (unscales gradients first)
scaler.update()  # update the mixed-precision scaler
We also want to save the trained model every 1000 steps:
if itr%1000==0: torch.save(predictor.model.state_dict(), "model.torch") # save model
Since we already computed the IOU, we can track it as an exponential moving average to see how the model's predictions improve over time:
if itr == 0: mean_iou = 0
mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())
print("step:", itr, "Accuracy (IOU) =", mean_iou)
The full training script: https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code/blob/main/TRAIN.py
A variant that trains on multi-image batches: https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code/blob/main/TRAIN_multi_image_batch.py
The inference script used below: https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code/blob/main/TEST_Net.py
First, we load the dependencies and cast the model weights to bfloat16. This makes the model much faster (suitable for inference only):
import numpy as np
import torch
import cv2
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# use bfloat16 for the whole script (saves memory)
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
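Calling __enter__() by hand keeps the entire script inside the autocast context. The more idiomatic equivalent, as a sketch, is a with block scoped around the inference code:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    pass  # run predictor.set_image(...) and predictor.predict(...) here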
Next, we load a sample image and a mask of the image region we want to segment (download the image/mask):
image_path = r"sample_image.jpg" # path to image
mask_path = r"sample_mask.png" # path to mask, the mask will define the image region to segment
def read_image(image_path, mask_path): # read and resize image and mask
img = cv2.imread(image_path)[...,::-1] # read image as rgb
mask = cv2.imread(mask_path,0) # mask of the region we want to segment
# Resize image to maximum size of 1024
r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)),interpolation=cv2.INTER_NEAREST)
return img, mask
image,mask = read_image(image_path, mask_path)
Sample 30 random points inside the region we want to segment:
num_samples = 30  # number of points/segments to sample

def get_points(mask, num_points):  # sample points inside the input mask
    points = []
    for i in range(num_points):
        coords = np.argwhere(mask > 0)
        yx = np.array(coords[np.random.randint(len(coords))])
        points.append([[yx[1], yx[0]]])
    return np.array(points)

input_points = get_points(mask, num_samples)
Load the standard SAM2 model (same as in training):
# load the model; you need the pretrained checkpoint downloaded already
sam2_checkpoint = "sam2_hiera_small.pt"
model_cfg = "sam2_hiera_s.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
Next, load the weights of the model we just trained (model.torch):
predictor.model.load_state_dict(torch.load("model.torch"))
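If the checkpoint was saved on a different device than the one you are loading on, torch.load accepts a map_location argument (a standard PyTorch option, not specific to this script):
predictor.model.load_state_dict(torch.load("model.torch", map_location="cuda"))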
with torch.no_grad():  # prevent the net from calculating gradients (more efficient inference)
    predictor.set_image(image)  # image encoder
    masks, scores, logits = predictor.predict(  # prompt encoder + mask decoder
        point_coords=input_points,
        point_labels=np.ones([input_points.shape[0], 1])
    )
masks = masks[:, 0].astype(bool)  # keep the first of the three multimask outputs per point
sorted_masks = masks[np.argsort(scores[:, 0])][::-1].astype(bool)  # sort masks from highest to lowest score
seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)
Next, we add the masks one by one (from highest to lowest score) to the segmentation map. A mask is added only if it is consistent with the masks already placed, that is, only if it overlaps the already-occupied region by less than 15%:
for i in range(sorted_masks.shape[0]):
    mask = sorted_masks[i]
    if (mask * occupancy_mask).sum() / mask.sum() > 0.15: continue  # skip masks that overlap occupied regions too much
    mask[occupancy_mask] = 0  # trim the mask to unoccupied pixels
    seg_map[mask] = i + 1  # write the mask into the segmentation map
    occupancy_mask[mask] = 1  # mark these pixels as occupied
Finally, we assign each segment a random color and display the resulting annotation map:
rgb_image = np.zeros((seg_map.shape[0], seg_map.shape[1], 3), dtype=np.uint8)
for id_class in range(1, seg_map.max() + 1):
    rgb_image[seg_map == id_class] = [np.random.randint(255), np.random.randint(255), np.random.randint(255)]
cv2.imshow("annotation", rgb_image)
cv2.imshow("mix", (rgb_image / 2 + image / 2).astype(np.uint8))
cv2.imshow("image", image)
cv2.waitKey()
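If you are running on a headless machine where cv2.imshow cannot open a window, one alternative sketch is to write the results to disk instead (OpenCV expects BGR, so we flip the channels of our RGB arrays):
cv2.imwrite("annotation.png", rgb_image[..., ::-1])  # RGB -> BGR
cv2.imwrite("mix.png", (rgb_image / 2 + image / 2).astype(np.uint8)[..., ::-1])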