从ResNet到Transformer:手把手图解DETR模型结构(附PyTorch关键代码解析)

张开发
2026/6/9 13:45:27 15 分钟阅读
从ResNet到Transformer:手把手图解DETR模型结构(附PyTorch关键代码解析)
从ResNet到Transformer手把手图解DETR模型结构附PyTorch关键代码解析在计算机视觉领域目标检测一直是一个核心任务。传统方法如Faster R-CNN和YOLO系列虽然效果不错但都依赖于手工设计的组件如锚框anchor boxes和非极大值抑制NMS。DETRDetection Transformer的出现彻底改变了这一局面。本文将带你深入理解DETR的每个模块并通过PyTorch代码实现关键部分。1. DETR整体架构概览DETR的核心思想是将目标检测视为一个集合预测问题。它由以下几个主要组件构成Backbone通常是ResNet用于提取图像特征Transformer Encoder处理空间特征Transformer Decoder使用可学习的object queries来生成预测Prediction Heads输出最终的类别和边界框import torch import torch.nn as nn from torchvision.models import resnet50 class DETR(nn.Module): def __init__(self, num_classes, hidden_dim256, nheads8, num_encoder_layers6, num_decoder_layers6): super().__init__() self.backbone resnet50(pretrainedTrue) self.conv nn.Conv2d(2048, hidden_dim, 1) self.transformer nn.Transformer( hidden_dim, nheads, num_encoder_layers, num_decoder_layers) self.linear_class nn.Linear(hidden_dim, num_classes 1) self.linear_bbox nn.Linear(hidden_dim, 4) self.query_pos nn.Parameter(torch.rand(100, hidden_dim))2. Backbone特征提取DETR使用ResNet作为backbone来提取图像特征。标准的ResNet-50会将输入图像下采样32倍输出2048通道的特征图。特征提取过程的关键参数参数输入尺寸输出尺寸说明输入图像3×H×W-原始RGB图像ResNet输出-2048×(H/32)×(W/32)32倍下采样1×1卷积2048×(H/32)×(W/32)256×(H/32)×(W/32)降维到Transformer所需维度def forward(self, inputs): # 特征提取 features self.backbone.conv1(inputs) features self.backbone.bn1(features) features self.backbone.relu(features) features self.backbone.maxpool(features) features self.backbone.layer1(features) features self.backbone.layer2(features) features self.backbone.layer3(features) features self.backbone.layer4(features) # 降维 features self.conv(features) return features3. Transformer Encoder详解Encoder的作用是增强特征的空间关系理解。它将2D特征图展平为1D序列并添加位置编码。Encoder处理流程展平空间维度从256×(H/32)×(W/32)变为256×HW添加可学习的位置编码通过多层Transformer Encoder处理# 位置编码实现 class PositionEmbeddingLearned(nn.Module): def __init__(self, num_pos_feats256): super().__init__() self.row_embed nn.Parameter(torch.rand(50, num_pos_feats // 2)) self.col_embed nn.Parameter(torch.rand(50, num_pos_feats // 2)) def forward(self, tensor): h, w tensor.shape[-2:] i torch.arange(w, devicetensor.device) j torch.arange(h, devicetensor.device) x_emb self.col_embed[i] y_emb self.row_embed[j] pos torch.cat([ x_emb.unsqueeze(0).repeat(h, 1, 1), y_emb.unsqueeze(1).repeat(1, w, 1), ], dim-1).permute(2, 0, 1) return pos4. Transformer Decoder与Object QueriesDecoder的核心是可学习的object queries它们像问题一样向Encoder特征提问最终得到预测结果。Object Queries的关键特性数量固定通常为100可学习的位置编码通过自注意力和交叉注意力与Encoder交互# Decoder处理流程 def forward_decoder(self, memory): bs memory.shape[1] tgt torch.zeros_like(self.query_pos) hs self.transformer.decoder(tgt.unsqueeze(1).repeat(1, bs, 1), memory.unsqueeze(0)) return hs.transpose(1, 2)5. 预测头与匈牙利损失DETR使用两个独立的FFN分别预测类别和边界框。损失函数采用匈牙利算法进行二分匹配。预测头结构预测头输入维度输出维度激活函数分类头256num_classes1Softmax回归头2564Linear# 匈牙利损失实现 def hungarian_loss(pred_logits, pred_boxes, targets): # 计算类别损失 cost_class -pred_logits[:, targets[labels]] # 计算框损失(L1 GIoU) cost_bbox torch.cdist(pred_boxes, targets[boxes], p1) cost_giou -generalized_box_iou(pred_boxes, targets[boxes]) # 总成本矩阵 C cost_class cost_bbox cost_giou C C.view(bs * num_queries, -1) # 使用匈牙利算法找到最优匹配 indices linear_sum_assignment(C) return indices6. 训练技巧与优化DETR训练有几个关键技巧学习率调整Backbone使用较小的学习率通常为其他部分的1/10梯度裁剪防止梯度爆炸辅助损失在Decoder的每一层都计算损失长训练周期通常需要500个epoch才能收敛提示在实际训练中可以使用预训练的DETR权重进行微调这比从头训练要高效得多。7. 实际应用中的注意事项在部署DETR模型时有几个实际问题需要考虑计算资源Transformer的计算复杂度与图像尺寸的平方成正比小物体检测DETR对小物体的检测效果相对较弱训练时间相比传统检测器需要更长的训练时间# 推理时的后处理 def postprocess(outputs, target_sizes): out_logits, out_bbox outputs[pred_logits], outputs[pred_boxes] prob F.softmax(out_logits, -1) scores, labels prob[..., :-1].max(-1) # 将框坐标转换为图像尺寸 img_h, img_w target_sizes.unbind(1) scale_fct torch.stack([img_w, img_h, img_w, img_h], dim1) boxes box_cxcywh_to_xyxy(out_bbox) * scale_fct[:, None, :] results [{scores: s, labels: l, boxes: b} for s, l, b in zip(scores, labels, boxes)] return results8. DETR的变体与改进自DETR提出以来研究者们提出了多种改进版本变体主要改进效果提升Deformable DETR可变形注意力机制训练更快小物体检测更好Conditional DETR条件空间查询加速收敛DAB-DETR动态锚框查询更稳定的训练DN-DETR去噪训练解决二分匹配不稳定性在项目中尝试这些改进版本时我发现Deformable DETR在保持精度的同时显著减少了训练时间特别是在处理高分辨率图像时优势更为明显。

更多文章