用 60 行代码训练/调整 SAM 2 完成
2024/10/10 21:03:18
本文主要是介绍用 60 行代码训练/调整 SAM 2 完成,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
SAM2 (Segment Anything 2) 是一款Meta公司推出的全新模型,旨在对图像中的任何内容进行分割,而不受特定类别或领域的限制。它的独特之处在于其训练数据的规模:1.1亿张图像和110亿个掩膜。这种大规模的训练使SAM2成为一个强大的起点,可以用于新的图像分割任务的训练。
你可能会问,既然SAM能够分割任何东西,为什么我们还需要再训练它?答案是,SAM在处理常见对象时表现出色,但在处理稀有或特定领域任务时可能会表现不佳。
然而,即使在SAM表现不佳的情况下,通过在新数据上进行微调,仍然可以显著提高模型的能力。在很多情况下,这只需要更少的数据,并且比从零开始训练模型获得更好的结果。
本教程将演示如何仅需60行代码(不包括注释和导入)来调整SAM2,并使用新的数据集。
下面可以找到完整的训练脚本
使用60行代码对Segment Anything Model 2 (SAM 2) 进行微调和训练/TRAIN.py 在main分支···提供了用于训练/微调Meta的Segment Anything Model 2 (SAM 2) 的代码……github.comSAM2 网络结构图来自 SAM2 GIT 页面
Segment Anything 是怎么工作的
SAM 工作的主要方式是通过对一幅图像和图像中的一个点,预测包含该点的区域掩码。这种方法可以实现全自动的完整图像分割,并且对分割的类别或类型没有任何限制,如在这篇文章中所讨论的一样。
使用SAM进行图像分割的步骤:
- 选择图像中的一组点
- 使用SAM预测出每个点所在的区域
- 将得到的区域合并成一个整体
虽然SAM也可以利用其他输入,如掩膜或边界框(bounding box),但这些主要用于涉及人工输入的交互式分割。在本教程里,我们将专注于全自动分割,而只考虑单点输入。
更多关于模型的细节请访问项目网站了解更多详情。
SAM2 可从这里下载。
如果你不想复制训练脚本这部分,你也可以下载我已包含训练脚本的我的分叉版本。
按照GitHub仓库中的安装说明进行操作。
你需要 Python 3.11 及以上版本和 PyTorch,通常来说。
另外,我们将使用OpenCV,可以通过以下方式安装。
pip安装opencv-python
您还需要下载预训练的模型:
点击这里下载检查点文件:https://github.com/facebookresearch/segment-anything-2?tab=readme-ov-file#download-checkpoints
你可以从几个模型中选择,它们都适用于本教程。我推荐使用这个较小的模型,它是训练速度最快的。
在本教程中,我们将使用LabPics1数据集,并将材料和液体分开。您可以通过上述链接下载数据集。
点击下载实验图片集文件:https://zenodo.org/records/3697452/files/LabPicsV1.zip?download=1
我们需要首先写的是数据读取器程序。它会读取并准备数据以供网络使用。
数据读取器需要输出:
- 一张图像
- 图像中所有区域的蒙版。
- 每个蒙版内的一个随机点:训练指针网络以分割物体、部件和材料
我们先启动一下依赖项:
# 导入必要的库 import numpy as np import torch import cv2 import os from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor
接下来是该数据集的所有图片。
data_dir=r"LabPicsV1//" # LabPicsV1 数据集文件夹的路径 data=[] # 数据集中的文件列表 for ff, name in enumerate(os.listdir(data_dir+"Simple/Train/Image/")): # 遍历所有标注文件 data.append({"图像":data_dir+"Simple/Train/Image/"+name,"注释":data_dir+"Simple/Train/Instance/"+name[:-4]+".png"}) # 注:将图像文件名的最后四位替换为 .png
现在是加载训练批次的主要功能。训练批次包括:一张随机图像,该图像的所有掩模及其相应的掩膜,每个掩模中的一个随机点。
def read_batch(data): # 从数据集(LabPics)中读取随机图像及其标注 # 从数据集中随机选择一个条目 ent = data[np.random.randint(len(data))] # 随机选择条目 Img = cv2.imread(ent["image"])[...,::-1] # 读取图像 ann_map = cv2.imread(ent["annotation"]) # 读取标注 # 调整大小 r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]]) # 缩放因子 Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r))) ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),interpolation=cv2.INTER_NEAREST) # 合并材料和管标注 mat_map = ann_map[:,:,0] # 材料图 ves_map = ann_map[:,:,2] # 管道图 mat_map[mat_map==0] = ves_map[mat_map==0]*(mat_map.max()+1) # 合并后的图 # 生成二进制掩码和像素点 inds = np.unique(mat_map)[1:] # 加载所有索引 points= [] masks = [] for ind in inds: mask=(mat_map == ind).astype(np.uint8) # 生成二进制掩码 masks.append(mask) coords = np.argwhere(mask > 0) # 获取掩码坐标 yx = np.array(coords[np.random.randint(len(coords))]) # 选择随机像素点/坐标 points.append([[yx[1], yx[0]]]) return Img,np.array(masks),np.array(points), np.ones([len(masks),1])
函数的第一部分是从中随机选择一张图片,然后加载它。
ent = data[np.random.randint(len(data))] # 随机选择一个条目 Img = cv2.imread(ent["image"]) # 读入图像 ann_map = cv2.imread(ent["annotation"]) # 读入注释
请注意,OpenCV 以 BGR 色彩顺序读取图像,而 SAM 期望 RGB 色彩顺序的图像。通过使用 […, ::-1],图像从 BGR 转换为 RGB。
根据 SAM 的要求,图像大小不应超过 1024 像素,因此我们将调整图像和标注图层的大小到不超过 1024 像素。
r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]]) # 缩放比例 Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r))) # 将图像调整为最大1024像素 ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST) # 最邻近插值法
这里的关键点是,在调整注释图的大小时,我们使用了最近邻模式(最近邻,即_INTER_NEAREST_模式)。在注释图(注释地图,即_annmap)中,每个像素值代表其所属区域的索引。因此,重要的是要使用不会在图中引入新值的缩放方法,以确保地图的准确性。
接下来的部分专门针对LabPics1数据集的格式。注解映射(_annmap)包含一个通道中的血管分割图,以及另一个通道中的材料注解图。我们要将它们合并成单一映射。
mat_map = ann_map[:,:,0] # 材料标注图(材料注解图) ves_map = ann_map[:,:,2] # 血管标注图(血管注解图) mat_map[mat_map==0] = ves_map[mat_map==0]*(mat_map.max()+1) # 将mat_map中值为0的部分替换为ves_map中对应位置的值乘以mat_map的最大值加1
这给我们提供了一张图像(_matmap),其中每个像素的值是它所属的段的编号(例如:所有值为3的单元格都属于第3段)。我们希望将其转换为一系列二值掩模(0/1),其中每个掩模代表一个不同的段。此外,从每个掩模中,我们希望提取一个单独的点。
inds = np.unique(mat_map)[1:] # 地图中所有索引的列表 points = [] # 所有点的列表(每个掩码一个点) masks = [] # 所有掩码的列表 for ind in inds: mask = (mat_map == ind).astype(np.uint8) # 将索引 ind 对应的二值掩码创建出来 masks.append(mask) coords = np.argwhere(mask > 0) # 获取掩码内所有坐标 yx = np.array(coords[np.random.randint(len(coords))]) # 从这些坐标中随机选取一个点 points.append([[yx[1], yx[0]]]) return Img,np.array(masks),np.array(points), np.ones([len(masks),1])
我们得到了图片 (Img),与图片中各个区域对应的二值掩膜(binary masks)列表 (masks),以及每个掩膜内单个点的坐标 (points)。
以下是一组示例训练数据:1)一张图像。2)掩膜列表。3)每个掩膜内标记一个红色点。取自LabPics数据集。
加载SAM模型
现在让我们加载网络模型。
sam2_checkpoint = "sam2_hiera_small.pt" # 模型权重的路径 model_cfg = "sam2_hiera_s.yaml" # 模型配置 sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") # 加载模型权重 predictor = SAM2ImagePredictor(sam2_model) # 初始化预测器
首先,我们在 _sam2checkpoint 参数中设置权重文件的路径。我们之前从这里下载了权重文件。 “sam2_hiera_small.pt”指的是小型模型,但是代码可以适用于任何类型的模型。无论选择哪种模型,都需要在 _modelcfg 参数中设置相应的配置文件。配置文件位于主仓库中的子文件夹“sam2_configs/”中。
在我们开始训练之前,我们先理解一下模型的结构。
SAM由三个部分构成:
1) 图像编码器模块,2) 提示编码器模块,3) 掩码解码器模块。
图像编码器负责处理图像并过程创建图像嵌入。这是其中最大的一个组件,训练它需要强大的GPU。
提示编码器(Prompt Encoder)处理我们的输入,也就是输入提示(input prompt),在我们的情况下是输入点(input point)。
掩码解码器从图像编码器和提示嵌入编码器获取输出,并生成最终的分割掩模。
我们可以通过设置来开启掩码解码器和提示编码器的训练。
predictor.model.sam_mask_decoder.train(True) predictor.model.sam_prompt_encoder.train(True)
你可以通过使用‘_predictor.model.imageencoder.train(True)’来开启图像编码器的训练。
这需要一个更强大的GPU,但会为网络提供更多的进步空间。如果你选择训练图像编码器,你需要在SAM2代码中找到并移除所有“_no_grad”命令。(“_no_grad”阻止了梯度收集,这虽然节省了内存,但会阻止训练)。
接下来,我们要定义标准的adamW优化算法:
optimizer=torch.optim.AdamW(params=predictor.model.parameters(),lr=1e-5,weight_decay=4e-5)
优化器使用了AdamW算法,参数包括预测模型的所有参数,学习率为1e-5,权重衰减为4e-5。
我们也会采用混合精度训练,这仅仅是一种更节省内存的训练方法。
scaler = torch.cuda.amp.GradScaler() # 混合精度设置
现在,我们开始搭建主要的训练循环部分。第一部分是读取数据并进行准备,
for itr in range(100000): with torch.cuda.amp.autocast(): # 混合精度模式 image,mask,input_point, input_label = read_batch(data) # 加载一批数据 if mask.shape[0]==0: continue # 跳过空批次 predictor.set_image(image) # 将图像输入到SAM图像编码器中
首先,我们将数据转换为混合精度格式以便高效训练:
with torch.cuda.amp.autocast(); # 使用torch的自动混合精度上下文管理器
接下来,我们来用我们之前创建的读取功能读取训练资料。
image, mask, 输入点, 输入标签 = 读取数据批次(data)
我们将加载的图像输入到图像编码器中(网络的第一个环节)进行编码。
预测器设置图像(image)
接下来,我们使用网络提示编码器模型处理输入数据点。
mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True) sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(非归一化坐标, 标签), boxes=None, masks=None,)
注意,在这部分虽然可以输入框或遮罩效果,但我们不会使用这些选项。
现在我们已经将提示点(points)和图像都编码了,我们现在终于可以预测分割掩码了。
batched_mode = unnorm_coords.shape[0] > 1 # 多个掩码的预测模式 high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]] low_res_masks, prd_scores, 忽略的得分, 忽略的掩码 = predictor.model.sam_mask_decoder(image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),sparse_prompt_embeddings=sparse_embeddings,dense_prompt_embeddings=dense_embeddings,multimask_output=True,repeat_image=batched_mode,high_res_features=high_res_features,) prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])# 将低分辨率掩码上采样到原始图像分辨率
这段代码的主要部分是 _model.sam_maskdecoder,它运行掩码解码器部分,并生成低分辨率掩模(_low_resmasks)及其得分(_prdscores)。
这些掩膜的分辨率低于原始输入图像,并在_postprocessmasks函数中调整回原始输入图像的大小。
这为我们提供了网络的最终预测结果:对于每个输入点,网络会给出3个分割掩码(_prdmasks)和相应的掩码评分(_prdscores)。_prd_masks_为每个输入点提供了3个预测掩码,但我们只使用每个点的第一个掩码。_prd_scores_给出了网络对每个掩码有多好的评分(或对预测有多确定)。
现在我们有了网络预测,我们可以计算损失。首先,我们计算分割误差,这表示预测掩膜与真实掩膜的匹配程度如何。为此,我们采用了标准交叉熵损失。
首先,我们将预测掩码(_prdmask)使用 sigmoid 函数来将 logits 转换为概率:
prd_mask = torch.sigmoid(prd_masks[:, 0]) # 将逻辑图转化为概率图
我们将 ground truth mask(真实掩模)转换为 torch tensor(torch 张量)。
prd_mask = torch.sigmoid(prd_masks[:, 0]) # 将逻辑图转换为概率图
最后,我们手动计算交叉熵损失值(seg_loss),使用 ground truth (gt_mask)和预测的概率图(prd_mask)来:
seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean() # 交叉熵损失,这里计算的是预测掩模和真实掩模之间的差异
为了防止对数函数因0而导致结果趋向无穷,我们在数值中加0.0001。
除了预测掩膜之外,网络还会为每个掩膜评估一个质量分数。虽然训练这部分不那么重要,但是它仍然非常有用。首先,我们需要确定每个预测掩膜的真实质量评分。换句话说,我们需要找出预测掩膜的实际质量如何。我们通过比较GT掩膜和预测掩膜的IOU来确定这个分数。IOU是两个掩膜重叠部分的面积除以它们总面积的比值。首先,我们要计算这两个掩膜的重叠部分。
# 计算gt_mask和prd_mask的交集,并求和 inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
我们使用这个阈值(即 prd_mask > 0.5)将预测掩码从概率形式转换为二进制掩码。
接下来,我们通过将预测和真实掩码的交集面积除以它们的总面积来计算IOU。
iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter) # iou:交并比,inter:交集区域,gt_mask:真实掩码,prd_mask:预测掩码
我们将使用 IOU(交并比)作为每个掩膜的真实分数,并得分损失则为预测分数与刚刚计算的 IOU 之间的差值的绝对值。
score_loss = torch.abs(prd_scores[:, 0] - iou).mean() # 计算预测得分与IOU的绝对差值的平均值
最后,我们将分词损失和评分损失(给予前者更高的权重)合并在一起:
loss = seg_loss + 权重score_loss * 0.05 # 损失的混合(将分割损失和评分损失按比例混合)
一旦我们得到损失值,所有步骤就完全按照标准流程进行。我们计算反向传播,并利用之前制作的优化器来更新权重。
predictor.model.zero_grad() # 清零梯度 scaler.scale(loss).backward() # 计算反向传播 scaler.step(optimizer) # 执行优化器步骤 scaler.update() # 混合精度更新
我们也想每隔1000步保存一次训练模型:
if itr%1000==0: torch.save(predictor.model状态, "model.torch") # 保存模型到文件
既然已经计算了IOU,我们可以将其当作移动平均值显示出来,来看看模型预测随着时间推移是否有所改进:
如果 itr == 0: mean_iou = 0 mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy()) print("步数:", itr, "准确率(IoU)=", mean_iou)
就这样了,我们用不到60行代码(不包括注释和导入)训练并微调了Segment-Anything 2模型。经过大约25,000步后,你就会看到明显的改善。
模型将会被保存为“model.torch”。
你可以在这里找到完整的训练代码:
本教程使用每批一张图片的方法,更高效的方法是每批使用多张不同图片。相关代码可以在以下链接找到:
_https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code/blob/main/TRAIN_multi_imagebatch.py
现在模型已经调整好了,让我们用它来分割一幅图片吧。
咱们打算按照以下步骤来做这件事:
- 加载我们刚刚训练的模型。
- 给模型一张图片和一堆随机点。对于每个点,网络会预测包含它的分割掩码和一个分数。
- 将这些掩码拼接在一起,形成一张分割图。
完整的代码可以在以下位置找到:
首先,我们加载依赖并将权重调整为float16,这让模型运行更快(仅在进行推理时有效)。
import numpy as np import torch import cv2 from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor # 使用bfloat16来进行整个脚本的运算,这样更省内存。 torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
接下来,我们加载一个样本图像。下载图像和掩膜的链接如下:图像/掩膜。
image_path = r"sample_image.jpg" # 图像文件路径 mask_path = r"sample_mask.png" # 掩膜路径,定义要分割的图像区域 def read_image(image_path, mask_path): # 读取并调整图像和掩膜大小 img = cv2.imread(image_path)[...,::-1] # 以RGB读取图像 mask = cv2.imread(mask_path,0) # 要分割区域的掩码 # 调整图像到最大尺寸1024 r = np.min([1024 / img.shape[1], 1024 / img.shape[0]]) img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r))) mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)),interpolation=cv2.INTER_NEAREST) return img, mask image,mask = read_image(image_path, mask_path)
从我们想要分割的区域内随机抽取30个点:
num_samples = 30 # 采样点的数量 def get_points(mask, num_points): # 在输入掩膜内获取点 points=[] for i in range(num_points): coords = np.argwhere(mask > 0) yx = np.array(coords[np.random.randint(len(coords))]) points.append([yx[1], yx[0]]) return np.array(points) input_points = get_points(mask, num_samples)
加载标准的SAM模型(和训练时所用的一样)
# 先加载你需要的模型,确保你已经有预训练好的模型 sam2_checkpoint = "sam2_hiera_small.pt" # the path to the checkpoint file model_cfg = "sam2_hiera_s.yaml" # the configuration file path for the model sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") # 使用配置文件和检查点文件构建SAM2模型 predictor = SAM2ImagePredictor(sam2_model) # 创建一个预测器对象来处理图像
接下来,加载我们刚刚训练的模型的权重(model.torch)。
预测器模型加载来自文件 "model.torch" 的状态字典。
让我们运行微调模型,为之前选择的每个点预测其分割掩码。
在不计算梯度的情况下 # 禁用网络计算梯度(更高效的推断) predictor.set_image(image) # 图像编码器 masks, scores, logits = predictor.predict( # prompt编码和mask解码 point_coords=input_points, point_labels=np.ones((input_points.shape[0], 1)) )
我们现在有一系列预测的掩膜及其得分。我们希望将它们拼合成一个单一且一致的分割结果图。然而,许多掩膜重叠,可能相互不一致。
接下来,拼接的方法很简单:
首先,我们将按照预测的分数对掩码结果进行排序。
masks=masks[:,0].astype(bool) # 将掩码转换为布尔类型 shorted_masks = masks[np.argsort(scores[:,0])][::-1].astype(bool) # 对掩码按得分排序并反转
现在让我们创建一个空的分割图和占用图
seg_map = np.zeros_like(shorted_masks[0],dtype=np.uint8) # 分割图初始化为与缩短的掩码相同尺寸的零矩阵 occupancy_mask = np.zeros_like(shorted_masks[0],dtype=bool) # 占用掩码初始化为与缩短的掩码相同尺寸的布尔矩阵
接下来,我们依次按得分从高到低将掩码添加到分割图中。我们只会添加那些与已添加的掩码不冲突的掩码,换句话说,只有当我们要添加的掩码与已占用区域的重叠部分不超过15%时才会添加。
for i in range(shorted_masks.shape[0]): # 这个循环用于处理缩短的掩码(shorted_masks),并更新分割映射(seg_map)和占用掩码(occupancy_mask). mask = shorted_masks[i] if (mask*occupancy_mask).sum()/mask.sum()>0.15: # 如果掩码与占用掩码的交集超过15%,则继续下一次迭代. continue mask[occupancy_mask]=0 seg_map[mask]=i+1 occupancy_mask[mask]=1
就这样了。
_segmask 现在包含了预测的分割结果图,每个区域有不同的值,背景值为0。
我们可以将其转换为颜色映射,方法如下:
# 相关代码片段
rgb_image = np.zeros((seg_map.shape[0], seg_map.shape[1], 3), dtype=np.uint8) # 创建一个全零的三维数组,用于存储RGB图像 for id_class in range(1,seg_map.max()+1): rgb_image[seg_map == id_class] = [np.random.randint(255), np.random.randint(255), np.random.randint(255)] # 为每个类别随机生成RGB颜色值
显示:
cv2.imshow("annotation", rgb_image) # 显示注释图像 cv2.imshow("mix", (rgb_image / 2 + image / 2).astype(np.uint8)) # 展示混合图像(RGB图像和图像的平均值) cv2.imshow("image", image) # 显示原始图像 cv2.waitKey() # 等待按键
这是使用微调过的SAM2进行分割操作的一个示例。图像来自LabPics数据集。
完整的推断代码如下所示:
就这样,我们已经训练并测试了SAM2模型。除了更改数据读取器之外,这应该对任何数据集都有效。在许多情况下,这应该足以显著提升性能。
最后,SAM2 还能做到对视频中的对象进行分割和跟踪,但微调这部分功能留待以后再说。
版权: 所有发帖所用的图片均来自 SAM2 GIT 仓库(Apache 许可证)和 LabPics 数据集(MIT 许可证)。本教程的代码和模型在 Apache 许可证下。
这篇关于用 60 行代码训练/调整 SAM 2 完成的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-22程序员出海做 AI 工具:如何用 similarweb 找到最佳流量渠道?
- 2024-12-20自建AI入门:生成模型介绍——GAN和VAE浅析
- 2024-12-20游戏引擎的进化史——从手工编码到超真实画面和人工智能
- 2024-12-20利用大型语言模型构建文本中的知识图谱:从文本到结构化数据的转换指南
- 2024-12-20揭秘百年人工智能:从深度学习到可解释AI
- 2024-12-20复杂RAG(检索增强生成)的入门介绍
- 2024-12-20基于大型语言模型的积木堆叠任务研究
- 2024-12-20从原型到生产:提升大型语言模型准确性的实战经验
- 2024-12-20啥是大模型1
- 2024-12-20英特尔的 Lunar Lake 计划:一场未竟的承诺