需要使用半精度。
# 序
代码架构与逻辑几乎与之前\(\text{SSD}\)那篇文章一致。然而由于\(\text{Faster R-CNN}\)虚拟内存使用量较大,不仅在训练时需减小\(batchsize\),而且在推理时需要使用半精度。
# 数据来源与预处理
见此处。其中包括数据迭代器构建、实例化迭代器、对数据集的\(\text{transform}\)等。
# 训练准备
实例化模型
实例化\(\text{PyTorch}\)提供的\(\text{fasterrcnn_resnet50_fpn_v2}\),并采用预训练好的\(\textbf{backbone}\)。需要注意的是这里仅有大香蕉需要检测,因为\(\text{label.csv}\)里边缘框标签仅含大香蕉不含背景类,参数 num_classes 应设为\(\text{实际类别数}+1\),即想要识别的目标类别和背景。
# 定义模型
pyt_frcnn = models.detection.fasterrcnn_resnet50_fpn_v2(weights_backbone=ResNet50_Weights.IMAGENET1K_V2, num_classes=2)
# 丢到显卡上
pyt_frcnn.to(device='cuda:0')
注:如果训练自己的数据集,则只在\(\text{backbone}\)上使用预训练权重。如果想使用基于其他数据集训练好的整个\(\text{Faster R-CNN}\),则需要设置参数\(\text{weights=FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1}\)。
定义优化器
接下来使用\(\text{Adam}\)优化器基于大香蕉数据训练\(\text{FasterRCNN_ResNet50_FPN_V2}\)。
optimizer = optim.Adam(pyt_frcnn.parameters(), lr=0.0001, weight_decay=5e-4)
# 训练
注:训练前没有定义损失函数。因为当给 PyTorch 的\(\text{Faster R-CNN}\)分别传入\(\textbf{栅格图像}\)和\(\textbf{边缘框及其标签类别字典}\)时,模型将返回含有两个部分损失的字典。
定义参数
num_epoch = 10
best_model_loss = None # 初始化最佳损失
valid_loss_list = [] # 记录所有epoch的验证集损失
valid_loss = torch.Tensor().to(device='cuda:0') # 记录一个epoch内每次迭代的验证集batch损失。用于求平均,得到每个epoch的验证集损失
开始训练
每轮\(\text{epoch}\)都以学习到的现有参数进行一次评估,若得到更好的结果则保存模型。
# 开始训练
num_epoch = 10
best_model_loss = None # 初始化最佳损失
valid_loss_list = [] # 记录所有epoch的验证集损失
valid_loss = torch.Tensor().to(device='cuda:0') # 记录一个epoch内每次迭代的验证集batch损失。用于求平均,得到每个epoch的验证集损失
for epoch in range(num_epoch):
epoch_loss = 0
pyt_frcnn.train()
for images, targets in train_iter:
optimizer.zero_grad()
images = images.to('cuda:0')
# 把目标转换为[{'boxes': 一个或多个边缘框, 'labels': 与边缘框对应的标签}, {'boxes': 一个或多个边缘框, 'labels': 与边缘框对应的标签}, ..., {'boxes': 一个或多个边缘框, 'labels': 与边缘框对应的标签}]
targets = [{f'{k}': v[idx].to('cuda:0') for k, v in targets.items()} for idx in range(len(targets.get('boxes')))]
loss_dict = pyt_frcnn(images, targets)
l = sum(loss for loss in loss_dict.values())
l.backward()
optimizer.step()
epoch_loss += l.cpu().detach().numpy()
print(l.cpu().detach().numpy())
print(f"Epoch [{epoch + 1}/{num_epoch}], Loss: {epoch_loss:.4f}")
valid_loss = torch.Tensor().to(device='cuda:0')
for images, targets in valid_iter:
images = images.to('cuda:0')
targets = [{f'{k}': v[idx].to('cuda:0') for k, v in targets.items()} for idx in range(len(targets.get('boxes')))]
with torch.no_grad():
loss_valid = sum(loss for loss in pyt_frcnn(images, targets).values())
valid_loss = torch.cat([valid_loss, loss_valid.reshape(-1)])
valid_loss_list.append(valid_loss.to(device="cpu").numpy().flatten().mean())
if best_model_loss is None:
best_model_loss = valid_loss_list[-1]
torch.save(pyt_frcnn.state_dict(), '/banana-detection/rcnn_best_model.ckpt')
print('到达世界最高城——理塘!')
if valid_loss_list[-1] < best_model_loss:
best_model_loss = valid_loss_list[-1]
torch.save(pyt_frcnn.state_dict(), '/banana-detection/rcnn_best_model.ckpt')
print('到达世界最高城——理塘!')
print(f'epoch_{epoch}的验证精度为{valid_loss.to(device="cpu").numpy().flatten().mean():.3f}')
# 推理
设置模型
推理时,先实例化模型并加载训练好的权重。然后将模型转为评估与半精度模式。
# 初始化模型
pyt_frcnn = models.detection.fasterrcnn_resnet50_fpn_v2(weights_backbone=ResNet50_Weights.IMAGENET1K_V2, num_classes=2).to('cuda:0')
pyt_frcnn.load_state_dict(torch.load('/banana-detection/rcnn_best_model.ckpt'))
# 转半精度
pyt_frcnn.eval().half()
实例化迭代器并得到\(16\)张香蕉图
这\(16\)张香蕉图也需要转为半精度,并且在出图时转为\(\text{FP32}\),以满足 Matplotlib 的要求。
test_df = pandas.read_csv('/banana-detection/bananas_val/label.csv', header=0)
test_dataset = test_dateset(df=test_df, image_dir='/banana-detection/bananas_train/images', transform=trans)
test_iter = data.DataLoader(test_dataset, batch_size=16, shuffle=True)
images = next(iter(test_iter)) # 来笔数据
images = [img.to('cuda:0').half() for img in images] # 转半精度
进行推理
imgs_pred_list = pyt_frcnn(images)
# 绘图
初始化\(\text{Figure}\)和\(\text{Axis}\)
fig, ax = pyplot.subplots(2, 8, figsize=(16, 4))
根据每一张图片来绘制坐标系
for idx, img in enumerate(images):
# 转为FP32
img = img.float().permute(1, 2, 0).to('cpu')
if idx >= 8:
ax[1, idx - 8].imshow(img)
else:
ax[0, idx].imshow(img)
img_pred = imgs_pred_list[idx]
bbox, scores, labels = img_pred['boxes'], img_pred['scores'], img_pred['labels']
nums = torch.argwhere(scores > 0.5).shape[0]
bbox = bbox.to('cpu')
if idx >= 8:
for i in range(nums):
x1, y1, x2, y2 = bbox[i].detach().numpy().astype('int')
ax[1, idx - 8].add_patch(patches.Rectangle(xy=(x1, y1), width=x2 - x1, height=y2 - y1, edgecolor='red', facecolor='None'))
else:
for i in range(nums):
x1, y1, x2, y2 = bbox[i].detach().numpy().astype('int')
ax[0, idx].add_patch(patches.Rectangle(xy=(x1, y1), width=x2 - x1, height=y2 - y1, edgecolor='red', facecolor='None'))
pyplot.show()