Training Faster Region-based Convolutional Neural Network with Custom Dataset

需要使用半精度。


#

代码架构与逻辑几乎与之前\(\text{SSD}\)那篇文章一致。然而由于\(\text{Faster R-CNN}\)虚拟内存使用量较大,不仅在训练时需减小\(batchsize\),而且在推理时需要使用半精度。

# 数据来源与预处理

此处。其中包括数据迭代器构建、实例化迭代器、对数据集的\(\text{transform}\)等。

# 训练准备

实例化模型

实例化\(\text{PyTorch}\)提供的\(\text{fasterrcnn_resnet50_fpn_v2}\),并采用预训练好的\(\textbf{backbone}\)。需要注意的是这里仅有大香蕉需要检测,因为\(\text{label.csv}\)里边缘框标签仅含大香蕉不含背景类,参数 num_classes 应设为\(\text{实际类别数}+1\),即想要识别的目标类别和背景。

# 定义模型
pyt_frcnn = models.detection.fasterrcnn_resnet50_fpn_v2(weights_backbone=ResNet50_Weights.IMAGENET1K_V2, num_classes=2)
# 丢到显卡上
pyt_frcnn.to(device='cuda:0')

注:如果训练自己的数据集,则只在\(\text{backbone}\)上使用预训练权重。如果想使用基于其他数据集训练好的整个\(\text{Faster R-CNN}\),则需要设置参数\(\text{weights=FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1}\)。

定义优化器

接下来使用\(\text{Adam}\)优化器基于大香蕉数据训练\(\text{FasterRCNN_ResNet50_FPN_V2}\)。

optimizer = optim.Adam(pyt_frcnn.parameters(), lr=0.0001, weight_decay=5e-4)

# 训练

注:训练前没有定义损失函数。因为当给 PyTorch 的\(\text{Faster R-CNN}\)分别传入\(\textbf{栅格图像}\)和\(\textbf{边缘框及其标签类别字典}\)时,模型将返回含有两个部分损失的字典。

定义参数

num_epoch = 10
best_model_loss = None  # 初始化最佳损失
valid_loss_list = []  # 记录所有epoch的验证集损失
valid_loss = torch.Tensor().to(device='cuda:0')  # 记录一个epoch内每次迭代的验证集batch损失。用于求平均,得到每个epoch的验证集损失

开始训练

每轮\(\text{epoch}\)都以学习到的现有参数进行一次评估,若得到更好的结果则保存模型。

# 开始训练
num_epoch = 10
best_model_loss = None  # 初始化最佳损失
valid_loss_list = []  # 记录所有epoch的验证集损失
valid_loss = torch.Tensor().to(device='cuda:0')  # 记录一个epoch内每次迭代的验证集batch损失。用于求平均,得到每个epoch的验证集损失
for epoch in range(num_epoch):
    epoch_loss = 0
    pyt_frcnn.train()
    for images, targets in train_iter:
        optimizer.zero_grad()
        images = images.to('cuda:0')
        # 把目标转换为[{'boxes': 一个或多个边缘框, 'labels': 与边缘框对应的标签}, {'boxes': 一个或多个边缘框, 'labels': 与边缘框对应的标签}, ..., {'boxes': 一个或多个边缘框, 'labels': 与边缘框对应的标签}]
        targets = [{f'{k}': v[idx].to('cuda:0') for k, v in targets.items()} for idx in range(len(targets.get('boxes')))]
        loss_dict = pyt_frcnn(images, targets)
        l = sum(loss for loss in loss_dict.values())
        l.backward()
        optimizer.step()
        epoch_loss += l.cpu().detach().numpy()
        print(l.cpu().detach().numpy())
    print(f"Epoch [{epoch + 1}/{num_epoch}], Loss: {epoch_loss:.4f}")

    valid_loss = torch.Tensor().to(device='cuda:0')
    for images, targets in valid_iter:
        images = images.to('cuda:0')
        targets = [{f'{k}': v[idx].to('cuda:0') for k, v in targets.items()} for idx in range(len(targets.get('boxes')))]
        with torch.no_grad():
            loss_valid = sum(loss for loss in pyt_frcnn(images, targets).values())
            valid_loss = torch.cat([valid_loss, loss_valid.reshape(-1)])
    valid_loss_list.append(valid_loss.to(device="cpu").numpy().flatten().mean())
    if best_model_loss is None:
        best_model_loss = valid_loss_list[-1]
        torch.save(pyt_frcnn.state_dict(), '/banana-detection/rcnn_best_model.ckpt')
        print('到达世界最高城——理塘!')
    if valid_loss_list[-1] < best_model_loss:
        best_model_loss = valid_loss_list[-1]
        torch.save(pyt_frcnn.state_dict(), '/banana-detection/rcnn_best_model.ckpt')
        print('到达世界最高城——理塘!')
    print(f'epoch_{epoch}的验证精度为{valid_loss.to(device="cpu").numpy().flatten().mean():.3f}')

# 推理

设置模型

推理时,先实例化模型并加载训练好的权重。然后将模型转为评估与半精度模式。

# 初始化模型
pyt_frcnn = models.detection.fasterrcnn_resnet50_fpn_v2(weights_backbone=ResNet50_Weights.IMAGENET1K_V2, num_classes=2).to('cuda:0')
pyt_frcnn.load_state_dict(torch.load('/banana-detection/rcnn_best_model.ckpt'))
# 转半精度
pyt_frcnn.eval().half()

实例化迭代器并得到\(16\)张香蕉图

这\(16\)张香蕉图也需要转为半精度,并且在出图时转为\(\text{FP32}\),以满足 Matplotlib 的要求。

test_df = pandas.read_csv('/banana-detection/bananas_val/label.csv', header=0)
test_dataset = test_dateset(df=test_df, image_dir='/banana-detection/bananas_train/images', transform=trans)
test_iter = data.DataLoader(test_dataset, batch_size=16, shuffle=True)
images = next(iter(test_iter))  # 来笔数据
images = [img.to('cuda:0').half() for img in images]  # 转半精度

进行推理

imgs_pred_list = pyt_frcnn(images)

# 绘图

初始化\(\text{Figure}\)和\(\text{Axis}\)

fig, ax = pyplot.subplots(2, 8, figsize=(16, 4))

根据每一张图片来绘制坐标系

for idx, img in enumerate(images):
    # 转为FP32
    img = img.float().permute(1, 2, 0).to('cpu')
    if idx >= 8:
        ax[1, idx - 8].imshow(img)
    else:
        ax[0, idx].imshow(img)
    img_pred = imgs_pred_list[idx]
    bbox, scores, labels = img_pred['boxes'], img_pred['scores'], img_pred['labels']
    nums = torch.argwhere(scores > 0.5).shape[0]
    bbox = bbox.to('cpu')
    if idx >= 8:
        for i in range(nums):
            x1, y1, x2, y2 = bbox[i].detach().numpy().astype('int')
            ax[1, idx - 8].add_patch(patches.Rectangle(xy=(x1, y1), width=x2 - x1, height=y2 - y1, edgecolor='red', facecolor='None'))
    else:
        for i in range(nums):
            x1, y1, x2, y2 = bbox[i].detach().numpy().astype('int')
            ax[0, idx].add_patch(patches.Rectangle(xy=(x1, y1), width=x2 - x1, height=y2 - y1, edgecolor='red', facecolor='None'))
pyplot.show()
Subscribe
Notify of
guest
0 Comments
Inline Feedbacks
View all comments