Torchvision框架学习之FCOS模型及其训练

科技   2024-11-20 22:35   江苏  

点击上方蓝字关注我们

微信公众号:OpenCV学堂

关注获取更多计算机视觉与深度学习知识

FCOS介绍

FCOS(Fully Convolutional One Stage)是于2019年由Zhi Tian等人提出的一种全卷积单阶段无锚框检测模型。该模型无需在训练时计算锚框的IoU值,极大的简化了模型的训练,使得训练方便,并且在推理时只有NMS(None Maximum Suppression,非及大值抑制),使得推理速度极快。在torchvision模型库中包含了以Resnet50作为骨干网络的FCOS,这样只需要创建模型进行训练即可。

模型介绍

FCOS模型的结构如图10.8所示,整个模型从左向右,可以分为三个部分:Backbone(骨干网络),Feature Pyramid(特征金字塔),以及由Classification、Center-ness和Regression构成的多尺度目标检测头作为输出。


Backbone主要用于提取提取不同尺度的特征,一般使用预训练的分类网络的特征提取部分作为BackBone,FCOS的Backbone,作为模型的入口,接受输入的图像,从Backbone中输出C3,C4和C5特征图。各特征图的高和宽标记在各特征图的左侧,分别为输入图像的1/8,1/16和1/32,对于输入为800×1024的图像,C3的尺寸为100×128,C4的尺寸为50×64,C5的尺寸为25×32。

Feature Pyramid接收来自骨干网络的C3,C4和C5特征图,一方面以C5特征图为基础构造尺寸为13×16的P6特征图和尺寸为7×8的P7特征图,另一方面将C5上采样和C4合并构造与C4同尺寸的P4,并将P4上采样和C3构造出与C4同尺寸的P3。这样通过Feature Pyramid总共构造出了P3,P4,P5,P6和P7共5个尺寸的特征图,目标检测就在这5个不同尺寸的特征图上完成。


多尺度目标检测头负责对从Feature Pyramid传来P3-7共5个特征图进行检测。在每个特征图上进行目标检测的头都具有相同的结构。每个头包含两个独立的部分,Classification结构和Center-ness结构共享一个部分,其中Classification结构输出大小为H×W×C形状的张量表示在H×W大小的特征图上检测总共C个类别的结果,Center-ness结构输出大小为H×W×1形状的张量表示在H×W的特征图上,是否为目标的中心;Regression独自为一个部分,输出大小为H×W×4形状的张量,表示以该元素为中心的目标到该元素的4个距离,如图10.9所示。图10.9显示了FCOS模型在Regression部分学习的4个距离,分别是中心到左边界的距离,中心到右边界的距离,中心到上边界的距离和中心到下边界的距离。这样FCOS模型就彻底抛弃了锚框IoU的计算,直接以这四个距离进行训练。

图10.8  FCOS目标检测模型

图10.9  FCOS目标外接矩形的表示


由于在torchvision中包含了以ResNet50为骨干网络的FCOS模型,并且还可以加载COCO上预训练的模型,这样创建一个FCOS模型就非常方便,创建方法与上一节创建预训练模型方法相同:
from torchvision.models import detectionmodel = detection.fcos_resnet50_fpn(progress=True,num_classes=3,          pretrained_backbone=True,          trainable_backbone_layers=4)
以上就可以创建一个具有检测3个类别,以带有预训练参数的ResNet50为骨干网络的FCOS模型,对于创建模型时其他参数及其含义可以参考API文档。

数据集制作

由于目标检测模型多样,因此,在训练前对于数据集的构建方法会有所差异。对于torchvision包中提供的所有目标检测模型已经对训练数据的格式进行了统一,因此,只需要把数据按照统一的方式进行构建后,torchvision包内的其它目标检测模型也可以使用。

由于通用目标检测数据集通常较大,不便于进行原理的演示。在这里使用一个样本量较小,类别数较小的目标检测数据集——螺丝螺母检测数据集。螺丝螺母检测数据集是一个开源目标检测数据集,下载地址为:
https://aistudio.baidu.com/aistudio/datasetdetail/6045
同时,该数据集也附于本书的电子资料中。

图10.10  螺丝螺母检测数据集中的样本

螺丝螺母数据集包括413张训练集和10张测试集两部分。图10.10显示了一个带有标注的训练集中的样本,在样本中螺丝和螺母放置于一个白色托盘中,托盘放置在一灰色平台上,螺丝螺母使用矩形框进行标注。以下就以该数据集为例构造用于训练torchvision中模型的数据集。

使用torchvision中的目标检测模型训练自定义数据同样需要对数据集进行封装,并在得到样本的__getitem__()方法中返回一个表示样本元组的数据和标签,以(x, y)表示,其中x是一个范围为0-1的3×H×W的图像张量,y表示图像x的标签,是一个包含‘label’和‘boxes’两个键的字典,‘label’键里以整数张量的形式存储了图像中K个目标的标签值,‘boxes’键里存储了图像中K个对应目标外边矩形框的左上和右下共4个坐标值组成的一个K×4的数字张量,格式如下所示:
#样本标签y的格式:
{'labels': tensor([1, 1, 1, 2, 2, 2, 2, 2]),  'boxes': tensor([[ 711,  233,  844,  506],          [1036,  194, 1206,  459],          [ 958,  406, 1239,  573],          [1142,  194, 1275,  320],          [ 780,  478,  908,  614],          [ 766,  612,  914,  742],          [ 972,  542, 1120,  678],          [ 986,  684, 1120,  820]])}
#以上表明样本中包含8个目标,3个目标的类别为1,5个的目标的类别为2。按照上述要求,通过继承torch.utils.data.Dataset类创建一个自定义的数据集,实现螺丝螺母数据集的构造:
from pathlib import Pathfrom torchvision.io import read_image,ImageReadModeimport jsonimport torchclass BNDataset(torch.utils.data.Dataset):    def __init__(self, istrain=True,datapath='D:/data/lslm'):   #注意修改数据集路径        self.datadir=Path(datapath)/('train' if istrain else 'test')        self.idxfile=self.datadir/('train.txt' if istrain else 'test.txt')        self.labelnames=['background','bolt','nut']        self.data=self.parseidxfile()    def parseidxfile(self):        lines=open(self.idxfile).readlines()        return [line for line in lines if len(line)>5]

def __getitem__(self, idx): data=self.data[idx].split('\t') x = read_image(str(self.datadir/data[0]),ImageReadMode.RGB)/255.0 labels=[] boxes=[] for i in data[2:]: if len(i)<5: continue r=json.loads(i) labels.append(self.labelnames.index( r['value'])) cords=r['coordinate'] xyxy=cords[0][0],cords[0][1],cords[1][0],cords[1][1] boxes.append(xyxy) y = { 'labels': torch.LongTensor(labels), 'boxes': torch.tensor(boxes).long() } return x, y

def __len__(self): return len(self.data)
以上代码对螺丝螺母数据集以BNDataset为类名进行封装,主要涉及的难点就是标签文件的解析,具体解析过程要结合上述代码和标签文件进行理解。在模型进行训练时,还需要使用把数据集进一步使用DataLoader封装:
def collate_fn(data):    x = [i[0] for i in data]    y = [i[1] for i in data]    return x, y

train_loader = torch.utils.data.DataLoader(dataset=BNDataset(istrain=True), batch_size=4, shuffle=True, drop_last=True, collate_fn=collate_fn)test_loader = torch.utils.data.DataLoader(dataset=BNDataset(istrain=False), batch_size=1, shuffle=True, drop_last=True, collate_fn=collate_fn)
在封装完成后,就可以使用上一节提到的可视化方法,进行样本和标签的可视化,以检查数据集构造的正确性,得到如图10.10所示的结果:
for i, (x, y) in enumerate(train_loader):    labels=[loader.dataset.labelnames[i] for i in y[0]['labels']]colors=[ ('red' if i=='nut' else 'blue') for i in labels]image=draw_bounding_boxes(x[0], y[0]['boxes'],labels=labels,colors=colors,width=5,font_size=50,outtype='CHW')vis.image(image)
    #图10.10所示
以上完成了螺丝螺母数据集的构造,能够用于FCOS模型的训练。下面介绍FCOS模型在该数据集上的训练以及模型的评估。

训练与预测

由于torchvision对FCOS模型进行了很好的封装,在准备好数据集后,训练方法与分类和分割网络的训练模式并无太大差异:创建优化器,构造损失函数,对数据集进行多次循环并根据反向传播的梯度进行参数的修正。将训练过程封装为train()函数,调用train()函数进行模型的训练,代码如下:
def train():    model.train()    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001,momentum=0.98)    for epoch in range(5):        for i, (x, y) in enumerate(train_loader):            outs = model(x, y)            loss = outs['classification']+ outs['bbox_ctrness']+outs['bbox_regression']            loss.backward()            optimizer.step()            optimizer.zero_grad()            if i % 10 == 0:                print(epoch, i, loss.item())        torch.save(model, f'./models/tvs{epoch}.model')
train()

#输出结果:

0 0 2.1866644620895386......4 100 0.9854792356491089
以上就是FCOS模型的训练代码,其中train()函数实现了模型的训练,在该函数中,将模型切换到训练模式,创建SGD优化器,总共训练5轮(可根据情况训练更多轮数),每10批打印损失值,经过4轮训练后,损失值从2.18降为了0.98,并且在每轮训练完成后都保存模型。


在模型训练完成后,就可以在测试集上查看和评估模型的检测效果。评估方法实质上与之前介绍的模型的使用方法是相同的,可以参考上一节的内容以便于理解。对模型在测试集上进行运行,并可视化结果,测试代码如下:
def test():    model_load = torch.load('./models/tvs4.model')    model_load.eval()    loader_test = torch.utils.data.DataLoader(dataset=BNDataset(istrain=False), batch_size=1, shuffle=False, drop_last=True, collate_fn=collate_fn)

for i, (x, y) in enumerate(loader_test): with torch.no_grad(): outs = model_load(x) res=outs[0] boxes=res['boxes'] scores=res['scores'] labels=res['labels'] #阈值过滤 threhold=0.5 #保留类别概率大于0.5的检测结果 mask=scores>threhold scores=scores[mask] labels=labels[mask] boxes=boxes[mask] labelnames=[loader_test.dataset.labelnames[i]+f'{scores[idx]:.2f}' for idx,i in enumerate(labels)] colors=[ ('red' if i==1 else 'blue') for i in labels] img=draw_bounding_boxes(x[0],boxes,labels=labelnames,colors=colors,width=3,font_size=50,outtype=’CHW’)vis.image(img)

#使用visdom可视化,图10.11


图10.11显示了经过5轮的训练,FCOS模型在螺丝螺母数据集上的测试结果,其中(a),(b)中的螺丝和螺母均被正确的检出,特别是(a)中右侧两个螺母十分靠近也正确的得到检出,而(c),(d)中均有部分螺丝被错误检测。从结果可以看出,虽然仅仅进行了5轮的训练,但模型就已经能够较好地检出螺丝和螺母,可以预测经过更多轮数的训练能够取得更好的检测效果。

以上就是使用FCOS在螺丝螺母数据集上的训练和评估,对于torchvision中的其他模型的训练只需要修改模型,并在train()函数中根据模型的输出,修改计算损失各部分的构成和各部分损失的权重即可。



如何学习计算机视觉之Pytorch数字图像处理,实现FCOS模型的自定义数据从标注到训练部署,更多内容请参考《计算机视觉之Pytorch数字图像处理》一书。


福利时间


在本文下方留言,至下周一(11月25日晚22:00)

点赞最高的第1到4名,可获赠书一本


免费图书等你来取!

快快留言+点赞+转发, 你就有会免费获取本书!


OpenCV学堂
三本书《Java数字图像处理-编程技巧与应用实践》、《OpenCV Android开发实战》、《OpenCV4应用开发-入门、进阶与工程化实践》作者。OpenCV实验大师平台 软件作者,OpenCV开发专家、OpenCV研习社创始人。
 最新文章