Training Single Shot MultiBox Detector with Custom Data

试一试SSD。


#

李沐老师在课程中详细讲述了如何基于\(\text{PyTorch}\)从〇实现单发多框检测算法\(\text{ (SSD)}\),关于实现该算法的细节在视频中被一一呈现,然而却忽略了基于成熟框架的简便实现方法。实际上,库中提供了名为\(\text{ssd300_vgg16}\)的\(\text{SSD}\)实例化方法,该方法基于最初实现\(\text{SSD}\)的文章\(\text{ (Liu et al., 2015)}\),可以快速复现文章提及的网络架构。通过调用该方法处理目标检测任务,能极大提高用户工作效率并简化工作方式。

然而,可能由于\(\text{SSD}\)多年来缺乏维护,不仅社区中鲜有讨论其基于\(\text{torch}\)的简化实现,连官方文档对其介绍也是寥寥数语。因此,为弥补上述知识缺口,本文将以训练检测香蕉的\(\text{SSD}\)模型为例,简明扼要说明基于\(\text{PyTorch}\)使用自定义数据集训练\(\text{SSD}\)的代码实例及其注意事项,供初学者参考借鉴。

# 数据来源与预处理

数据来源

本示例使用的复杂背景下多模态大香蕉可以通过该直链下载。解压后得到文件目录结构如下所示的目标检测数据集。

├── bananas_train
│   ├── images
│   │   ├── 0.png
│   │   ├── 1.png
│   │   ├── 2.png
│   │   ├── ...
│   │   ├── 999.png
│   └── label.csv
├── bananas_val
│   ├── images
│   │   ├── 0.png
│   │   ├── 1.png
│   │   ├── ...
│   │   ├── 99.png
│   └── label.csv
└── best_model.ckpt

其中:每张图片由\(256\times 256\)个像素组成,每个像素含有\(\text{RGB}\)三通道信息;标签文件详细信息如\(\text{Table 1}\)所示,

列名类型示例描述
img_namestr0.png每张图片的文件名
labelint0每张图片所含目标的标签
xminint104边界框的左上角X坐标
yminint20边界框的左上角Y坐标
xmaxint143边界框的右下角X坐标
ymaxint58边界框的右下角Y坐标
Table 1. label.csv详细信息

构造训练图像数据迭代器

以便每次迭代时仅读入所需小批量数据,进而节省内存资源消耗。该函数仅返回每批量读取的图像数据和目标字典。其中目标字典包含边缘框与标签信息。

class custom_dateset(data.Dataset):
    def __init__(self, df, image_dir, transform=None):
        self.data_frame = df
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.to_list()
        img_name = os.path.join(self.image_dir, str(self.data_frame.iloc[idx, 0]))
        image = Image.open(img_name)
        # unsqueeze(0)的原因是在最前面插入一个维度,以适配模型的输入要求。同时需注意,该方法返回的是每次获取的样本,因此还不涉及批量大小。
        boxes = torch.tensor(self.data_frame.iloc[idx, 2:6].astype(numpy.float16).to_numpy()).unsqueeze(0)  # 等价于boxes = torch.tensor(self.data_frame.iloc[idx, 2:6].astype(numpy.float16).to_numpy()).reshape(1, -1)
        labels = self.data_frame.iloc[idx, 1].reshape(-1)
        target = {'boxes': boxes, 'labels': labels}
        if self.transform:
            image = self.transform(image)
        return image, target

注意,\(\text{ssd300_vgg16}\)要求输入 boxes 的形状为\(\text{(N,4)}\),\(\text{N}\)指边缘框的个数;\(\text{4}\)指从\(\text{DataFrame}\)中得到的四个坐标。代码中的\(\text{.unsqueeze(0)}\)旨在给这四个坐标组成的形状为\(\text{(4)}\)的张量增添一个维度,进而使其形状转为\(\text{(1,4)}\),这样才能对应一个\(\text{label}\)。

读取并转换\(\text{label.csv}\)

\(\text{PyTorch}\)提供的目标检测模型中,\(\text{label}=0\)默认分配给背景类,而在该数据集中\(0\)被分配给了大香蕉。因此需要适当转换数据集信息,以适配现有模型的要求。

# 定义将栅格图片转为张量的方法
trans = transforms.Compose([transforms.ToTensor()])
# 读取label.csv
train_df = pandas.read_csv('/workspace/data/banana-detection/bananas_train/label.csv', header=0)
# 将大香蕉的label变换为1,以避免和pytorch目标检测背景类冲突
train_df['label'] = 1
# 实例化数据迭代器
train_dataset = custom_dateset(df=train_df, image_dir='/workspace/data/banana-detection/bananas_train/images', transform=trans)
train_iter = data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)

同上方式处理验证数据集

valid_df = pandas.read_csv('/workspace/data/banana-detection/bananas_val/label.csv', header=0)
valid_dataset = custom_dateset(df=valid_df, image_dir='/workspace/data/banana-detection/bananas_train/images', transform=trans)
valid_iter = data.DataLoader(valid_dataset, batch_size=16, shuffle=True, num_workers=8)

# 训练准备

实例化模型

实例化\(\text{PyTorch}\)提供的\(\text{ssd300_vgg16}\),并采用预训练好的\(\textbf{backbone}\)。需要注意的是这里仅有大香蕉需要检测,因为\(\text{label.csv}\)里边缘框标签仅含大香蕉不含背景类,参数 num_classes 应设为\(\text{实际类别数}+1\),即想要识别的目标类别和背景。

# 定义模型
pyt_ssd = models.detection.ssd300_vgg16(weights_backbone=VGG16_Weights.IMAGENET1K_V1, num_classes=2)
# 丢到显卡上
pyt_ssd.to(device='cuda:0')

注:如果训练自己的数据集,则只在\(\text{backbone}\)上使用预训练权重。如果想使用基于其他数据集训练好的整个\(\text{SSD}\),则需要设置参数\(\text{weights=SSD300_VGG16_Weights.COCO_V1}\)。

定义优化器

接下来使用\(\text{Adam}\)优化器基于大香蕉数据训练\(\text{ssd300_vgg16}\)。

## 定义优化器
optimizer = optim.Adam(pyt_ssd.parameters(), lr=0.0001, weight_decay=5e-4)

# 训练

注:训练前没有定义损失函数。因为当给 PyTorch 的\(\text{SSD}\)分别传入\(\textbf{栅格图像}\)和\(\textbf{边缘框及其标签类别字典}\)时,模型将返回含有两个部分损失的字典。

定义参数

num_epoch = 2
best_model_loss = None  # 初始化最佳损失
valid_loss_list = []  # 记录所有epoch的验证集损失
valid_loss = torch.Tensor().to(device='cuda:0')  # 记录一个epoch内每次迭代的验证集batch损失。用于求平均,得到每个epoch的验证集损失

开始训练

每轮\(\text{epoch}\)都以学习到的现有参数进行一次评估,若得到更好的结果则保存模型。

## 开训
for epoch in range(num_epoch):
    epoch_loss = 0
    pyt_ssd.train()  # 切换为训练形态
    for images, targets in train_iter:  # 来一笔数据
        optimizer.zero_grad()  # 清空梯度
        images = images.to('cuda:0')  # 把数据丢显卡上
        # 把目标转换为[{'boxes': 一个或多个边缘框, 'labels': 与边缘框对应的标签}, {'boxes': 一个或多个边缘框, 'labels': 与边缘框对应的标签}, ..., {'boxes': 一个或多个边缘框, 'labels': 与边缘框对应的标签}]
        targets = [{f'{k}': v[idx].to('cuda:0') for k, v in targets.items()} for idx in range(len(targets.get('boxes')))]
        loss_dict = pyt_ssd(images, targets)  # 计算损失
        l = sum(loss for loss in loss_dict.values())  # 把分类损失和回归损失求和
        l.backward()  # 再梯度下降
        optimizer.step()  # 更新参数
        epoch_loss += l.cpu().detach().numpy()
        print(l.cpu().detach().numpy())
    print(f"Epoch [{epoch + 1}/{num_epoch}], Loss: {epoch_loss:.4f}")
    # pyt_ssd.eval()  # 不转变为评估形态,因为评估形态不返回损失字典
    valid_loss = torch.Tensor().to(device='cuda:0')
    for images, targets in valid_iter:
        images = images.to('cuda:0')
        targets = [{f'{k}': v[idx].to('cuda:0') for k, v in targets.items()} for idx in range(len(targets.get('boxes')))]
        with torch.no_grad():
            loss_valid = sum(loss for loss in pyt_ssd(images, targets).values())
            valid_loss = torch.cat([valid_loss, loss_valid.reshape(-1)])
    valid_loss_list.append(valid_loss.to(device="cpu").numpy().flatten().mean())
    if best_model_loss is None:
        best_model_loss = valid_loss_list[-1]
        torch.save(pyt_ssd.state_dict(), '/workspace/data/banana-detection/best_model.ckpt')
        print('到达世界最高城——理塘!')
    if valid_loss_list[-1] < best_model_loss:
        best_model_loss = valid_loss_list[-1]
        torch.save(pyt_ssd.state_dict(), '/workspace/data/banana-detection/best_model.ckpt')
        print('到达世界最高城——理塘!')
    print(f'epoch_{epoch}的验证精度为{valid_loss.to(device="cpu").numpy().flatten().mean():.3f}')

# 出图

为渲染两行香蕉图及预测得到的锚框,首先使用\(\text{Matplotlib}\)生成\(2\)行\(8\)列子坐标系,再得到\(16\)张香蕉图并进行遍历,基于训练好的\(\text{ssd300_vgg16}\)进行预测并渲染原始图片以及预测框。

重新构建数据迭代器

做推理时仅需输入图片的数据,因此不再需要数据迭代器返回目标的信息。

class test_dateset(data.Dataset):
    def __init__(self, df, image_dir, transform=None):
        self.data_frame = df
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.to_list()
        img_name = os.path.join(self.image_dir, str(self.data_frame.iloc[idx, 0]))
        image = Image.open(img_name)
        if self.transform:
            image = self.transform(image)
        return image

实例化迭代器并得到\(16\)张香蕉图

test_df = pandas.read_csv('/workspace/data/banana-detection/bananas_val/label.csv', header=0)
test_dataset = test_dateset(df=test_df, image_dir='/workspace/data/banana-detection/bananas_train/images', transform=trans)
test_iter = data.DataLoader(test_dataset, batch_size=16, shuffle=True)
images = next(iter(test_iter))  # 来笔数据
images = [img.to('cuda:0') for img in images]  # 都丢到显卡上,并转为List

初始化\(\text{Figure}\)和\(\text{Axis}\)

fig, ax = pyplot.subplots(2, 8, figsize=(16, 4))

实例化一个\(\text{ssd300_vgg16}\)并加载训练得到的权重

需将丢到显卡上并转为评估形态,这样才能更高效得做推理。

# 实例化ssd300_vgg16
pyt_ssd = models.detection.ssd300_vgg16(weights_backbone=VGG16_Weights.IMAGENET1K_V1, num_classes=2)
# 加载学习得到的参数
pyt_ssd.load_state_dict(torch.load('/workspace/data/banana-detection/best_model.ckpt'))
pyt_ssd.to('cuda:0')  # 丢到显卡上
pyt_ssd.eval()  # 转换为评估形态

推理及绘图

pyt_ssd.eval()  # 转换为评估形态,这样才能做推理
imgs_pred_list = pyt_ssd(images)
for idx, img in enumerate(images):
    img = img.permute(1, 2, 0).to('cpu')
    if idx >= 8:
        ax[1, idx - 8].imshow(img)
    else:
        ax[0, idx].imshow(img)
    img_pred = imgs_pred_list[idx]
    bbox, scores, labels = img_pred['boxes'], img_pred['scores'], img_pred['labels']
    nums = torch.argwhere(scores > 0.95).shape[0]
    bbox = bbox.to('cpu')
    if idx >= 8:
        for i in range(nums):
            x1, y1, x2, y2 = bbox[i].detach().numpy().astype('int')
            ax[1, idx - 8].add_patch(patches.Rectangle(xy=(x1, y1), width=x2 - x1, height=y2 - y1, edgecolor='red', facecolor='None'))
    else:
        for i in range(nums):
            x1, y1, x2, y2 = bbox[i].detach().numpy().astype('int')
            ax[0, idx].add_patch(patches.Rectangle(xy=(x1, y1), width=x2 - x1, height=y2 - y1, edgecolor='red', facecolor='None'))
pyplot.show()

其中:\(\text{imgs_pred_list}\)指推理返回的由一张或多张图片预测结果组成的列表,每张图片的预测结果为字典,键依次为\(\text{boxes}\)、\(\text{scores}\)以及\(\text{labels}\),形状分别是\(\text{(FloatTensor[N, 4])}\)、\(\text{(Int64Tensor[N])}\)以及\(\text{(Tensor[N])}\),\(\text{N}\)指检测到了\(\text{N}\)个目标物体。

Subscribe
Notify of
guest
0 Comments
Inline Feedbacks
View all comments