pytorch源码解析系列-yolov4最核心技巧代码详解(3)- 训练过程
2021/9/24 17:10:40
本文主要是介绍pytorch源码解析系列-yolov4最核心技巧代码详解(3)- 训练过程,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
补一下源码地址
我们先从简单的开始说起,怎么判断loss?IOU(交并比)
IOU
yolov4用了CIOU_loss 和DIOU_LOSS
简单说一下,有个具体了解,都是从左到右发展来的
IOU | GIOU | DIOU | CIOU | |
---|---|---|---|---|
作用 | 主要考虑检测框和目标框重叠面积 | 在IOU的基础上,解决边界框不重合时的问题 | 在IOU和GIOU的基础上,考虑边界框中心点距离的信息 | 在DIOU的基础上,考虑边界框宽高比的尺度信息 |
具体实现 | 交并比 | 加了一个尺度相交(两个矩形外接最大矩形) | GIOU+欧式距离/中心点距离 | DIOU+长宽比 |
看代码就更直观了解他们的运作方式了
if GIoU or DIoU or CIoU: if GIoU: #area_c 就是外接矩形 area_c = torch.prod(con_br - con_tl, 2) # br tl对应button right和 top left坐标,这个公式就是算最小外接矩形面积 return iou - (area_c - area_u) / area_c # GIoU的公式, if DIoU or CIoU: #c2就是欧式距离 加一个小偏置防止除数为0 c2 = torch.pow(con_br - con_tl, 2).sum(dim=2) + 1e-16 if DIoU: #rho2 就是中心点距离 rho2 = ((bboxes_a[:, None, :2] - bboxes_b[:, :2]) ** 2 / 4).sum(dim=-1) return iou - rho2 / c2 # DIoU 的计算公式 加了个中心点距离/欧氏距离 elif CIoU: #这个V是长宽比 v = (4 / math.pi ** 2) * torch.pow(torch.atan(w1 / h1).unsqueeze(1) - torch.atan(w2 / h2), 2) with torch.no_grad(): alpha = v / (1 - iou + v) return iou - (rho2 / c2 + v * alpha) # CIoU 可以看到比Diou多了个长宽比因素 return iou
如果对上述参数不了解,可以参考一下源代码,这里贴太多反而容易混淆
Loss function
CIOU懂了 那么CIOU loss呢
其实就是CIOU loss = (1-CIOU)
GIOU,CIOU等同理
那么yolo怎么计算loss的呢
偷一下cuijiahua大佬的图
很复杂 看不懂?
没关系 实际上就是 三个loss组成的
如果有物体 就要加上: 坐标框损失,置信度损失,分类类别损失
大概知道什么意思 然后去看代码就可以了:
代码很长 可以只看我注释的地方 方便了解大体作用
class Yolo_loss(nn.Module): def __init__(self, n_classes=80, n_anchors=3, device=None, batch=2): super(Yolo_loss, self).__init__() # 这些老参数了 看我上一章内容都有 self.device = device self.strides = [8, 16, 32] image_size = 608 self.n_classes = n_classes self.n_anchors = n_anchors self.anchors = [[12, 16], [19, 36], [40, 28], [36, 75], [76, 55], [72, 146], [142, 110], [192, 243], [459, 401]] self.anch_masks = [[0, 1, 2], [3, 4, 5], [6, 7, 8]] self.ignore_thre = 0.5 self.masked_anchors, self.ref_anchors, self.grid_x, self.grid_y, self.anchor_w, self.anchor_h = [], [], [], [], [], [] #遍历三个anchor框 这下面代码在之前都出现过 具体就是初始化那些anchor for i in range(3): all_anchors_grid = [(w / self.strides[i], h / self.strides[i]) for w, h in self.anchors] masked_anchors = np.array([all_anchors_grid[j] for j in self.anch_masks[i]], dtype=np.float32) ref_anchors = np.zeros((len(all_anchors_grid), 4), dtype=np.float32) ref_anchors[:, 2:] = np.array(all_anchors_grid, dtype=np.float32) ref_anchors = torch.from_numpy(ref_anchors) # calculate pred - xywh obj cls fsize = image_size // self.strides[i] grid_x = torch.arange(fsize, dtype=torch.float).repeat(batch, 3, fsize, 1).to(device) grid_y = torch.arange(fsize, dtype=torch.float).repeat(batch, 3, fsize, 1).permute(0, 1, 3, 2).to(device) anchor_w = torch.from_numpy(masked_anchors[:, 0]).repeat(batch, fsize, fsize, 1).permute(0, 3, 1, 2).to( device) anchor_h = torch.from_numpy(masked_anchors[:, 1]).repeat(batch, fsize, fsize, 1).permute(0, 3, 1, 2).to( device) self.masked_anchors.append(masked_anchors) self.ref_anchors.append(ref_anchors) self.grid_x.append(grid_x) self.grid_y.append(grid_y) self.anchor_w.append(anchor_w) self.anchor_h.append(anchor_h) def build_target(self, pred, labels, batchsize, fsize, n_ch, output_id): # 目标注册 tgt最后一维是4 对应除p外的标签 # (B,3,f,f,4) tgt_mask = torch.zeros(batchsize, self.n_anchors, fsize, fsize, 4 + self.n_classes).to(device=self.device) # (B,3,f,f) obj_mask = torch.ones(batchsize, self.n_anchors, fsize, fsize).to(device=self.device) tgt_scale = torch.zeros(batchsize, self.n_anchors, fsize, fsize, 2).to(self.device) target = torch.zeros(batchsize, self.n_anchors, fsize, fsize, n_ch).to(self.device) # labels = labels.cpu().data nlabel = (labels.sum(dim=2) > 0).sum(dim=1) #label数量统计 # label对应的是x,y,w,h 所以X=x+w,Y=y+h 下面宽高还要除以步长 truth_x_all = (labels[:, :, 2] + labels[:, :, 0]) / (self.strides[output_id] * 2) truth_y_all = (labels[:, :, 3] + labels[:, :, 1]) / (self.strides[output_id] * 2) truth_w_all = (labels[:, :, 2] - labels[:, :, 0]) / self.strides[output_id] truth_h_all = (labels[:, :, 3] - labels[:, :, 1]) / self.strides[output_id] truth_i_all = truth_x_all.to(torch.int16).cpu().numpy() truth_j_all = truth_y_all.to(torch.int16).cpu().numpy() for b in range(batchsize): n = int(nlabel[b]) if n == 0: continue truth_box = torch.zeros(n, 4).to(self.device) truth_box[:n, 2] = truth_w_all[b, :n] truth_box[:n, 3] = truth_h_all[b, :n] truth_i = truth_i_all[b, :n] truth_j = truth_j_all[b, :n] # calculate iou between truth and reference anchors anchor_ious_all = bboxes_iou(truth_box.cpu(), self.ref_anchors[output_id], CIoU=True) # temp = bbox_iou(truth_box.cpu(), self.ref_anchors[output_id]) best_n_all = anchor_ious_all.argmax(dim=1) best_n = best_n_all % 3 best_n_mask = ((best_n_all == self.anch_masks[output_id][0]) | (best_n_all == self.anch_masks[output_id][1]) | (best_n_all == self.anch_masks[output_id][2])) if sum(best_n_mask) == 0: continue truth_box[:n, 0] = truth_x_all[b, :n] truth_box[:n, 1] = truth_y_all[b, :n] pred_ious = bboxes_iou(pred[b].view(-1, 4), truth_box, xyxy=False) pred_best_iou, _ = pred_ious.max(dim=1) pred_best_iou = (pred_best_iou > self.ignore_thre) pred_best_iou = pred_best_iou.view(pred[b].shape[:3]) # set mask to zero (ignore) if pred matches truth obj_mask[b] = ~ pred_best_iou for ti in range(best_n.shape[0]): if best_n_mask[ti] == 1: i, j = truth_i[ti], truth_j[ti] a = best_n[ti] obj_mask[b, a, j, i] = 1 tgt_mask[b, a, j, i, :] = 1 target[b, a, j, i, 0] = truth_x_all[b, ti] - truth_x_all[b, ti].to(torch.int16).to(torch.float) target[b, a, j, i, 1] = truth_y_all[b, ti] - truth_y_all[b, ti].to(torch.int16).to(torch.float) target[b, a, j, i, 2] = torch.log( truth_w_all[b, ti] / torch.Tensor(self.masked_anchors[output_id])[best_n[ti], 0] + 1e-16) target[b, a, j, i, 3] = torch.log( truth_h_all[b, ti] / torch.Tensor(self.masked_anchors[output_id])[best_n[ti], 1] + 1e-16) target[b, a, j, i, 4] = 1 target[b, a, j, i, 5 + labels[b, ti, 4].to(torch.int16).cpu().numpy()] = 1 tgt_scale[b, a, j, i, :] = torch.sqrt(2 - truth_w_all[b, ti] * truth_h_all[b, ti] / fsize / fsize) return obj_mask, tgt_mask, tgt_scale, target def forward(self, xin, labels=None): loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2 = 0, 0, 0, 0, 0, 0 for output_id, output in enumerate(xin): batchsize = output.shape[0] fsize = output.shape[2] n_ch = 5 + self.n_classes output = output.view(batchsize, self.n_anchors, n_ch, fsize, fsize) output = output.permute(0, 1, 3, 4, 2) # .contiguous() # logistic activation for xy, obj, cls output[..., np.r_[:2, 4:n_ch]] = torch.sigmoid(output[..., np.r_[:2, 4:n_ch]]) pred = output[..., :4].clone() pred[..., 0] += self.grid_x[output_id] pred[..., 1] += self.grid_y[output_id] pred[..., 2] = torch.exp(pred[..., 2]) * self.anchor_w[output_id] pred[..., 3] = torch.exp(pred[..., 3]) * self.anchor_h[output_id] obj_mask, tgt_mask, tgt_scale, target = self.build_target(pred, labels, batchsize, fsize, n_ch, output_id) # loss calculation output[..., 4] *= obj_mask output[..., np.r_[0:4, 5:n_ch]] *= tgt_mask output[..., 2:4] *= tgt_scale target[..., 4] *= obj_mask target[..., np.r_[0:4, 5:n_ch]] *= tgt_mask target[..., 2:4] *= tgt_scale loss_xy += F.binary_cross_entropy(input=output[..., :2], target=target[..., :2], weight=tgt_scale * tgt_scale, reduction='sum') loss_wh += F.mse_loss(input=output[..., 2:4], target=target[..., 2:4], reduction='sum') / 2 loss_obj += F.binary_cross_entropy(input=output[..., 4], target=target[..., 4], reduction='sum') loss_cls += F.binary_cross_entropy(input=output[..., 5:], target=target[..., 5:], reduction='sum') loss_l2 += F.mse_loss(input=output, target=target, reduction='sum') loss = loss_xy + loss_wh + loss_obj + loss_cls return loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2
今天累了 代码写到这 后续补完
这篇关于pytorch源码解析系列-yolov4最核心技巧代码详解(3)- 训练过程的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2025-01-08CCPM如何缩短项目周期并降低风险?
- 2025-01-08Omnivore 替代品 Readeck 安装与使用教程
- 2025-01-07Cursor 收费太贵?3分钟教你接入超低价 DeepSeek-V3,代码质量逼近 Claude 3.5
- 2025-01-06PingCAP 连续两年入选 Gartner 云数据库管理系统魔力象限“荣誉提及”
- 2025-01-05Easysearch 可搜索快照功能,看这篇就够了
- 2025-01-04BOT+EPC模式在基础设施项目中的应用与优势
- 2025-01-03用LangChain构建会检索和搜索的智能聊天机器人指南
- 2025-01-03图像文字理解,OCR、大模型还是多模态模型?PalliGema2在QLoRA技术上的微调与应用
- 2025-01-03混合搜索:用LanceDB实现语义和关键词结合的搜索技术(应用于实际项目)
- 2025-01-03停止思考数据管道,开始构建数据平台:介绍Analytics Engineering Framework