用 60 行代码训练/调整 SAM 2 完成

2024/10/10 21:03:18

本文主要是介绍用 60 行代码训练/调整 SAM 2 完成,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

分步教程:如何为自定义分割任务微调SAM2模型

SAM2 (Segment Anything 2) 是一款Meta公司推出的全新模型,旨在对图像中的任何内容进行分割,而不受特定类别或领域的限制。它的独特之处在于其训练数据的规模:1.1亿张图像和110亿个掩膜。这种大规模的训练使SAM2成为一个强大的起点,可以用于新的图像分割任务的训练。

你可能会问,既然SAM能够分割任何东西,为什么我们还需要再训练它?答案是,SAM在处理常见对象时表现出色,但在处理稀有或特定领域任务时可能会表现不佳。
然而,即使在SAM表现不佳的情况下,通过在新数据上进行微调,仍然可以显著提高模型的能力。在很多情况下,这只需要更少的数据,并且比从零开始训练模型获得更好的结果。

本教程将演示如何仅需60行代码(不包括注释和导入)来调整SAM2,并使用新的数据集。

下面可以找到完整的训练脚本

使用60行代码对Segment Anything Model 2 (SAM 2) 进行微调和训练/TRAIN.py 在main分支···提供了用于训练/微调Meta的Segment Anything Model 2 (SAM 2) 的代码……github.com

SAM2 网络结构图来自 SAM2 GIT 页面

Segment Anything 是怎么工作的

SAM 工作的主要方式是通过对一幅图像和图像中的一个点,预测包含该点的区域掩码。这种方法可以实现全自动的完整图像分割,并且对分割的类别或类型没有任何限制,如在这篇文章中所讨论的一样。

使用SAM进行图像分割的步骤:

  1. 选择图像中的一组点
  2. 使用SAM预测出每个点所在的区域
  3. 将得到的区域合并成一个整体

虽然SAM也可以利用其他输入,如掩膜或边界框(bounding box),但这些主要用于涉及人工输入的交互式分割。在本教程里,我们将专注于全自动分割,而只考虑单点输入。

更多关于模型的细节请访问项目网站了解更多详情。

下载SAM2并配置环境

SAM2 可从这里下载。

GitHub - facebookresearch/segment-anything-2:该仓库提供了使用Meta Segment Anything Model 2 (SAM 2) 进行推理的代码和链接…github.com

如果你不想复制训练脚本这部分,你也可以下载我已包含训练脚本的我的分叉版本。

GitHub - sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code: 该项目包含……该项目提供了用于训练/微调Meta的Segment Anything Model 2 (SAM 2) 的代码……

按照GitHub仓库中的安装说明进行操作。

你需要 Python 3.11 及以上版本和 PyTorch,通常来说。

另外,我们将使用OpenCV,可以通过以下方式安装。

pip安装opencv-python

下载预先训练好的模型(下载预训练模型)

您还需要下载预训练的模型:

点击这里下载检查点文件:https://github.com/facebookresearch/segment-anything-2?tab=readme-ov-file#download-checkpoints

你可以从几个模型中选择,它们都适用于本教程。我推荐使用这个较小的模型,它是训练速度最快的。

下载训练用的数据

在本教程中,我们将使用LabPics1数据集,并将材料和液体分开。您可以通过上述链接下载数据集。

点击下载实验图片集文件:https://zenodo.org/records/3697452/files/LabPicsV1.zip?download=1

准备数据读取

我们需要首先写的是数据读取器程序。它会读取并准备数据以供网络使用。

数据读取器需要输出:

  1. 一张图像
  2. 图像中所有区域的蒙版。
  3. 每个蒙版内的一个随机点:训练指针网络以分割物体、部件和材料

我们先启动一下依赖项:

# 导入必要的库
import numpy as np  
import torch  
import cv2  
import os  
from sam2.build_sam import build_sam2  
from sam2.sam2_image_predictor import SAM2ImagePredictor

接下来是该数据集的所有图片。

    data_dir=r"LabPicsV1//" # LabPicsV1 数据集文件夹的路径  
    data=[] # 数据集中的文件列表  
    for ff, name in enumerate(os.listdir(data_dir+"Simple/Train/Image/")):  # 遍历所有标注文件  
        data.append({"图像":data_dir+"Simple/Train/Image/"+name,"注释":data_dir+"Simple/Train/Instance/"+name[:-4]+".png"}) # 注:将图像文件名的最后四位替换为 .png

现在是加载训练批次的主要功能。训练批次包括:一张随机图像,该图像的所有掩模及其相应的掩膜,每个掩模中的一个随机点。

    def read_batch(data): # 从数据集(LabPics)中读取随机图像及其标注

       # 从数据集中随机选择一个条目  

            ent  = data[np.random.randint(len(data))] # 随机选择条目  
            Img = cv2.imread(ent["image"])[...,::-1]  # 读取图像  
            ann_map = cv2.imread(ent["annotation"]) # 读取标注  

       # 调整大小  

            r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]]) # 缩放因子  
            Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))  
            ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),interpolation=cv2.INTER_NEAREST)  

       # 合并材料和管标注  

            mat_map = ann_map[:,:,0] # 材料图  
            ves_map = ann_map[:,:,2] # 管道图  
            mat_map[mat_map==0] = ves_map[mat_map==0]*(mat_map.max()+1) # 合并后的图  

       # 生成二进制掩码和像素点  

            inds = np.unique(mat_map)[1:] # 加载所有索引  
            points= []  
            masks = []   
            for ind in inds:  
                mask=(mat_map == ind).astype(np.uint8) # 生成二进制掩码  
                masks.append(mask)  
                coords = np.argwhere(mask > 0) # 获取掩码坐标  
                yx = np.array(coords[np.random.randint(len(coords))]) # 选择随机像素点/坐标  
                points.append([[yx[1], yx[0]]])  
            return Img,np.array(masks),np.array(points), np.ones([len(masks),1])

函数的第一部分是从中随机选择一张图片,然后加载它。

    ent  = data[np.random.randint(len(data))] # 随机选择一个条目  
    Img = cv2.imread(ent["image"])  # 读入图像  
    ann_map = cv2.imread(ent["annotation"]) # 读入注释

请注意,OpenCV 以 BGR 色彩顺序读取图像,而 SAM 期望 RGB 色彩顺序的图像。通过使用 […, ::-1],图像从 BGR 转换为 RGB。

根据 SAM 的要求,图像大小不应超过 1024 像素,因此我们将调整图像和标注图层的大小到不超过 1024 像素。

    r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]]) # 缩放比例
    Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))  # 将图像调整为最大1024像素
    ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST)  # 最邻近插值法

这里的关键点是,在调整注释图的大小时,我们使用了最近邻模式(最近邻,即_INTER_NEAREST_模式)。在注释图(注释地图,即_annmap)中,每个像素值代表其所属区域的索引。因此,重要的是要使用不会在图中引入新值的缩放方法,以确保地图的准确性。

接下来的部分专门针对LabPics1数据集的格式。注解映射(_annmap)包含一个通道中的血管分割图,以及另一个通道中的材料注解图。我们要将它们合并成单一映射。

      mat_map = ann_map[:,:,0] # 材料标注图(材料注解图)  
      ves_map = ann_map[:,:,2] # 血管标注图(血管注解图)  
      mat_map[mat_map==0] = ves_map[mat_map==0]*(mat_map.max()+1) # 将mat_map中值为0的部分替换为ves_map中对应位置的值乘以mat_map的最大值加1

这给我们提供了一张图像(_matmap),其中每个像素的值是它所属的段的编号(例如:所有值为3的单元格都属于第3段)。我们希望将其转换为一系列二值掩模(0/1),其中每个掩模代表一个不同的段。此外,从每个掩模中,我们希望提取一个单独的点。

inds = np.unique(mat_map)[1:] # 地图中所有索引的列表
points = [] # 所有点的列表(每个掩码一个点)
masks = [] # 所有掩码的列表
for ind in inds:  
    mask = (mat_map == ind).astype(np.uint8) # 将索引 ind 对应的二值掩码创建出来
    masks.append(mask)  
    coords = np.argwhere(mask > 0) # 获取掩码内所有坐标
    yx = np.array(coords[np.random.randint(len(coords))]) # 从这些坐标中随机选取一个点
    points.append([[yx[1], yx[0]]])  
return Img,np.array(masks),np.array(points), np.ones([len(masks),1])

我们得到了图片 (Img),与图片中各个区域对应的二值掩膜(binary masks)列表 (masks),以及每个掩膜内单个点的坐标 (points)。

以下是一组示例训练数据:1)一张图像。2)掩膜列表。3)每个掩膜内标记一个红色点。取自LabPics数据集。

加载SAM模型(首次提及时可考虑加上全称以明确术语)

加载SAM模型

现在让我们加载网络模型。

    sam2_checkpoint = "sam2_hiera_small.pt" # 模型权重的路径  
    model_cfg = "sam2_hiera_s.yaml" # 模型配置  
    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") # 加载模型权重  
    predictor = SAM2ImagePredictor(sam2_model) # 初始化预测器

首先,我们在 _sam2checkpoint 参数中设置权重文件的路径。我们之前从这里下载了权重文件。 “sam2_hiera_small.pt”指的是小型模型,但是代码可以适用于任何类型的模型。无论选择哪种模型,都需要在 _modelcfg 参数中设置相应的配置文件。配置文件位于主仓库中的子文件夹“sam2_configs/”中。

一个可分割的整体结构

在我们开始训练之前,我们先理解一下模型的结构。

SAM由三个部分构成:
1) 图像编码器模块,2) 提示编码器模块,3) 掩码解码器模块。

图像编码器负责处理图像并过程创建图像嵌入。这是其中最大的一个组件,训练它需要强大的GPU。

提示编码器(Prompt Encoder)处理我们的输入,也就是输入提示(input prompt),在我们的情况下是输入点(input point)。

掩码解码器从图像编码器和提示嵌入编码器获取输出,并生成最终的分割掩模。

训练设置如下:

我们可以通过设置来开启掩码解码器和提示编码器的训练。

    predictor.model.sam_mask_decoder.train(True) 
    predictor.model.sam_prompt_encoder.train(True) 

你可以通过使用‘_predictor.model.imageencoder.train(True)’来开启图像编码器的训练。

这需要一个更强大的GPU,但会为网络提供更多的进步空间。如果你选择训练图像编码器,你需要在SAM2代码中找到并移除所有“_no_grad”命令。(“_no_grad”阻止了梯度收集,这虽然节省了内存,但会阻止训练)。

接下来,我们要定义标准的adamW优化算法:

    optimizer=torch.optim.AdamW(params=predictor.model.parameters(),lr=1e-5,weight_decay=4e-5)

优化器使用了AdamW算法,参数包括预测模型的所有参数,学习率为1e-5,权重衰减为4e-5。

我们也会采用混合精度训练,这仅仅是一种更节省内存的训练方法。

    scaler = torch.cuda.amp.GradScaler() # 混合精度设置
主要的训练循环过程

现在,我们开始搭建主要的训练循环部分。第一部分是读取数据并进行准备,

    for itr in range(100000):  
        with torch.cuda.amp.autocast(): # 混合精度模式  
                image,mask,input_point, input_label = read_batch(data) # 加载一批数据  
                if mask.shape[0]==0: continue # 跳过空批次  
                predictor.set_image(image) # 将图像输入到SAM图像编码器中

首先,我们将数据转换为混合精度格式以便高效训练:

with torch.cuda.amp.autocast();  # 使用torch的自动混合精度上下文管理器

接下来,我们来用我们之前创建的读取功能读取训练资料。

image, mask, 输入点, 输入标签 = 读取数据批次(data)

我们将加载的图像输入到图像编码器中(网络的第一个环节)进行编码。

预测器设置图像(image)

接下来,我们使用网络提示编码器模型处理输入数据点。

      mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)  
      sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(非归一化坐标, 标签), boxes=None, masks=None,)

注意,在这部分虽然可以输入框或遮罩效果,但我们不会使用这些选项。

现在我们已经将提示点(points)和图像都编码了,我们现在终于可以预测分割掩码了。

    batched_mode = unnorm_coords.shape[0] > 1 # 多个掩码的预测模式  
    high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]  
    low_res_masks, prd_scores, 忽略的得分, 忽略的掩码 = predictor.model.sam_mask_decoder(image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),sparse_prompt_embeddings=sparse_embeddings,dense_prompt_embeddings=dense_embeddings,multimask_output=True,repeat_image=batched_mode,high_res_features=high_res_features,)  
    prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])# 将低分辨率掩码上采样到原始图像分辨率

这段代码的主要部分是 _model.sam_maskdecoder,它运行掩码解码器部分,并生成低分辨率掩模(_low_resmasks)及其得分(_prdscores)。

这些掩膜的分辨率低于原始输入图像,并在_postprocessmasks函数中调整回原始输入图像的大小。

这为我们提供了网络的最终预测结果:对于每个输入点,网络会给出3个分割掩码(_prdmasks)和相应的掩码评分(_prdscores)。_prd_masks_为每个输入点提供了3个预测掩码,但我们只使用每个点的第一个掩码。_prd_scores_给出了网络对每个掩码有多好的评分(或对预测有多确定)。

损失
分段损耗

现在我们有了网络预测,我们可以计算损失。首先,我们计算分割误差,这表示预测掩膜与真实掩膜的匹配程度如何。为此,我们采用了标准交叉熵损失。

首先,我们将预测掩码(_prdmask)使用 sigmoid 函数来将 logits 转换为概率:

    prd_mask = torch.sigmoid(prd_masks[:, 0]) # 将逻辑图转化为概率图

我们将 ground truth mask(真实掩模)转换为 torch tensor(torch 张量)。

    prd_mask = torch.sigmoid(prd_masks[:, 0]) # 将逻辑图转换为概率图

最后,我们手动计算交叉熵损失值(seg_loss),使用 ground truth (gt_mask)和预测的概率图(prd_mask)来:

    seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean() # 交叉熵损失,这里计算的是预测掩模和真实掩模之间的差异

为了防止对数函数因0而导致结果趋向无穷,我们在数值中加0.0001。

分数扣减(可选的)

除了预测掩膜之外,网络还会为每个掩膜评估一个质量分数。虽然训练这部分不那么重要,但是它仍然非常有用。首先,我们需要确定每个预测掩膜的真实质量评分。换句话说,我们需要找出预测掩膜的实际质量如何。我们通过比较GT掩膜和预测掩膜的IOU来确定这个分数。IOU是两个掩膜重叠部分的面积除以它们总面积的比值。首先,我们要计算这两个掩膜的重叠部分。

    # 计算gt_mask和prd_mask的交集,并求和
    inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)

我们使用这个阈值(即 prd_mask > 0.5)将预测掩码从概率形式转换为二进制掩码。

接下来,我们通过将预测和真实掩码的交集面积除以它们的总面积来计算IOU。

    iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
    # iou:交并比,inter:交集区域,gt_mask:真实掩码,prd_mask:预测掩码

我们将使用 IOU(交并比)作为每个掩膜的真实分数,并得分损失则为预测分数与刚刚计算的 IOU 之间的差值的绝对值。

    score_loss = torch.abs(prd_scores[:, 0] - iou).mean()  # 计算预测得分与IOU的绝对差值的平均值

最后,我们将分词损失和评分损失(给予前者更高的权重)合并在一起:

    loss = seg_loss + 权重score_loss * 0.05  # 损失的混合(将分割损失和评分损失按比例混合)
最后一步:进行反向传播并保存模型

一旦我们得到损失值,所有步骤就完全按照标准流程进行。我们计算反向传播,并利用之前制作的优化器来更新权重。

    predictor.model.zero_grad() # 清零梯度  
    scaler.scale(loss).backward()  # 计算反向传播  
    scaler.step(optimizer) # 执行优化器步骤  
    scaler.update() # 混合精度更新  

我们也想每隔1000步保存一次训练模型:

    if itr%1000==0: torch.save(predictor.model状态, "model.torch") # 保存模型到文件

既然已经计算了IOU,我们可以将其当作移动平均值显示出来,来看看模型预测随着时间推移是否有所改进:

    如果 itr == 0: mean_iou = 0  
    mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())  
    print("步数:", itr, "准确率(IoU)=", mean_iou)

就这样了,我们用不到60行代码(不包括注释和导入)训练并微调了Segment-Anything 2模型。经过大约25,000步后,你就会看到明显的改善。

模型将会被保存为“model.torch”。

你可以在这里找到完整的训练代码:

用60行代码训练/微调Segment Anything 2模型/TRAIN.py 在 main ·…该仓库提供了用于训练/微调Meta的Segment Anything 2模型 (SAM 2) 的代码……

本教程使用每批一张图片的方法,更高效的方法是每批使用多张不同图片。相关代码可以在以下链接找到:

_https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code/blob/main/TRAIN_multi_imagebatch.py

推断:加载后使用训练好的模型:

现在模型已经调整好了,让我们用它来分割一幅图片吧。

咱们打算按照以下步骤来做这件事:

  1. 加载我们刚刚训练的模型。
  2. 给模型一张图片和一堆随机点。对于每个点,网络会预测包含它的分割掩码和一个分数。
  3. 将这些掩码拼接在一起,形成一张分割图。

完整的代码可以在以下位置找到:

用60行代码训练/微调Segment Anything Model 2 (SAM 2)/TEST_Net.py 文件 at main ·…该仓库提供了一套用于训练/微调Meta Segment Anything模型2 (SAM 2) 的代码……github.com

首先,我们加载依赖并将权重调整为float16,这让模型运行更快(仅在进行推理时有效)。

    import numpy as np  
    import torch  
    import cv2  
    from sam2.build_sam import build_sam2  
    from sam2.sam2_image_predictor import SAM2ImagePredictor  

    # 使用bfloat16来进行整个脚本的运算,这样更省内存。  
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

接下来,我们加载一个样本图像。下载图像和掩膜的链接如下:图像/掩膜。

    image_path = r"sample_image.jpg" # 图像文件路径  
    mask_path = r"sample_mask.png" # 掩膜路径,定义要分割的图像区域  
    def read_image(image_path, mask_path): # 读取并调整图像和掩膜大小  
            img = cv2.imread(image_path)[...,::-1]  # 以RGB读取图像  
            mask = cv2.imread(mask_path,0) # 要分割区域的掩码  

            # 调整图像到最大尺寸1024  

            r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])  
            img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))  
            mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)),interpolation=cv2.INTER_NEAREST)  
            return img, mask  
    image,mask = read_image(image_path, mask_path)

从我们想要分割的区域内随机抽取30个点:

    num_samples = 30 # 采样点的数量  
    def get_points(mask, num_points): # 在输入掩膜内获取点  
            points=[]  
            for i in range(num_points):  
                coords = np.argwhere(mask > 0)  
                yx = np.array(coords[np.random.randint(len(coords))])  
                points.append([yx[1], yx[0]])  
            return np.array(points)  
    input_points = get_points(mask, num_samples)

加载标准的SAM模型(和训练时所用的一样)

    # 先加载你需要的模型,确保你已经有预训练好的模型
    sam2_checkpoint = "sam2_hiera_small.pt"   # the path to the checkpoint file
    model_cfg = "sam2_hiera_s.yaml"   # the configuration file path for the model
    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")  # 使用配置文件和检查点文件构建SAM2模型
    predictor = SAM2ImagePredictor(sam2_model)  # 创建一个预测器对象来处理图像

接下来,加载我们刚刚训练的模型的权重(model.torch)。

预测器模型加载来自文件 "model.torch" 的状态字典。

让我们运行微调模型,为之前选择的每个点预测其分割掩码。

    在不计算梯度的情况下  # 禁用网络计算梯度(更高效的推断)
        predictor.set_image(image)  # 图像编码器
        masks, scores, logits = predictor.predict(  # prompt编码和mask解码
            point_coords=input_points,
            point_labels=np.ones((input_points.shape[0], 1))
        )

我们现在有一系列预测的掩膜及其得分。我们希望将它们拼合成一个单一且一致的分割结果图。然而,许多掩膜重叠,可能相互不一致。
接下来,拼接的方法很简单:

首先,我们将按照预测的分数对掩码结果进行排序。

    masks=masks[:,0].astype(bool)  # 将掩码转换为布尔类型
    shorted_masks = masks[np.argsort(scores[:,0])][::-1].astype(bool)  # 对掩码按得分排序并反转

现在让我们创建一个空的分割图和占用图

    seg_map = np.zeros_like(shorted_masks[0],dtype=np.uint8)  # 分割图初始化为与缩短的掩码相同尺寸的零矩阵
    occupancy_mask = np.zeros_like(shorted_masks[0],dtype=bool)  # 占用掩码初始化为与缩短的掩码相同尺寸的布尔矩阵

接下来,我们依次按得分从高到低将掩码添加到分割图中。我们只会添加那些与已添加的掩码不冲突的掩码,换句话说,只有当我们要添加的掩码与已占用区域的重叠部分不超过15%时才会添加。

    for i in range(shorted_masks.shape[0]):  # 这个循环用于处理缩短的掩码(shorted_masks),并更新分割映射(seg_map)和占用掩码(occupancy_mask).
        mask = shorted_masks[i]  
        if (mask*occupancy_mask).sum()/mask.sum()>0.15:  # 如果掩码与占用掩码的交集超过15%,则继续下一次迭代.
            continue   
        mask[occupancy_mask]=0  
        seg_map[mask]=i+1  
        occupancy_mask[mask]=1

就这样了。

_segmask 现在包含了预测的分割结果图,每个区域有不同的值,背景值为0。

我们可以将其转换为颜色映射,方法如下:

# 相关代码片段
    rgb_image = np.zeros((seg_map.shape[0], seg_map.shape[1], 3), dtype=np.uint8)  # 创建一个全零的三维数组,用于存储RGB图像
    for id_class in range(1,seg_map.max()+1):  
        rgb_image[seg_map == id_class] = [np.random.randint(255), np.random.randint(255), np.random.randint(255)]  # 为每个类别随机生成RGB颜色值

显示:

    cv2.imshow("annotation", rgb_image)  # 显示注释图像
    cv2.imshow("mix", (rgb_image / 2 + image / 2).astype(np.uint8))  # 展示混合图像(RGB图像和图像的平均值)
    cv2.imshow("image", image)  # 显示原始图像
    cv2.waitKey()  # 等待按键

这是使用微调过的SAM2进行分割操作的一个示例。图像来自LabPics数据集。

完整的推断代码如下所示:

fine-tune-train_segment_anything_2_in_60_lines_of_code/TEST_Net.py 在main分支中…该仓库提供了用于训练和微调Meta Segment Anything Model 2 (SAM 2)的代码……github.com
最后说一下:

就这样,我们已经训练并测试了SAM2模型。除了更改数据读取器之外,这应该对任何数据集都有效。在许多情况下,这应该足以显著提升性能。

最后,SAM2 还能做到对视频中的对象进行分割和跟踪,但微调这部分功能留待以后再说。

版权: 所有发帖所用的图片均来自 SAM2 GIT 仓库(Apache 许可证)和 LabPics 数据集(MIT 许可证)。本教程的代码和模型在 Apache 许可证下。



这篇关于用 60 行代码训练/调整 SAM 2 完成的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程