学习前言
在学会使用一个深度学习库之后,如果想要进一步融入自己的想法,要对深度学习库进行分解,如何分解与学习,是提升效率的关键。
深度学习库的组成
一般来讲,深度学习库的功能包括两个部分,一部分是训练模型,另一部分是利用模型进行预测。
在训练模型时,需要考虑模型本身,训练参数,数据加载与损失函数。
在预测模型时,需要考虑模型本身,数据加载,预测后处理。
综合起来,一个能用的深度学习库需要包含如下5个部分:
模型本身,训练参数,数据加载,损失函数,预测后处理。
一、模型本身
一般来讲,模型本身在仓库中的名字是net(网络)或者model(模型)。
以YoloV5仓库为例,https://github.com/ultralytics/yolov5,网络结构的部分就存储在model文件夹中,基本上网络的组成与构建,都会在这个部分完成。如果想要修改与了解网络结构,一般是在这一部分进行操作。
打开每一个.py文件可知,YoloV5的结构模块保存在common.py文件中,整体模型的构建在yolo.py文件中,下面第一幅图是common.py文件的内容,第二部图是yolo.py文件的内容。
YoloV5的库通过Yaml文件进行模型的构建,看起来不算特别清晰,但也确实挺方便的。
接下来以我提供的YoloV4仓库为例,https://github.com/bubbliiiing/yolov4-pytorch,网络结构的部分就存储在nets文件夹中。基本上网络的组成与构建,都会在这个部分完成。如果想要修改与了解网络结构,一般是在这一部分进行操作。
二、训练参数
训练参数一般伴随着训练文件,因此一般在train.py文件里面,每一个库指定参数的方式不同,有些喜欢通过yaml文件指定,有些喜欢通过cfg文件指定,有些甚至通过py文件指定,都不一样,这个需要参考每一个库的组成去分析。但大多数库都可以在train.py文件中找到蛛丝马迹。
以YoloV5仓库为例,https://github.com/ultralytics/yolov5,我们打开其train.py,寻找对应的训练参数。简单浏览整个train.py文件,找到入口函数,分析可知,yolov5通过argparse指定参数。argparse是python自带的命令行参数解析包,可以用来方便地读取命令行参数。结合yolov5训练的指令可知,yolov5利用argparse通过命令行获取参数。
python train.py --data coco.yaml --cfg yolov5n.yaml --weights '' --batch-size 128
除此之外,yolov5其实还通过hyps的yaml文件指定参数,也是一些训练参数,有关数据增强,学习率,损失函数等。
这种yaml文件其实用起来也蛮方便的,毕竟创建一个新的yaml文件就可以重新设置一套新的参数。有利有弊,对于新手而言,yaml文件与训练文件分离,有些时候会看不过来。对于老手而言,这样用起来就非常方便了。
接下来以我提供的YoloV4仓库为例,https://github.com/bubbliiiing/yolov4-pytorch,我们打开其train.py,寻找对应的训练参数。简单浏览整个train.py文件,找到入口函数,可以很容易发现,该仓库使用指定变量的方式来指定参数,
结合注释比较容易理解,这种方式同样有利有弊,对于新手而言,这种设定方式非常直观,前面写后面用。对于老手而言,如果要配置多套训练参数,非常繁琐。
用到的训练参数我就不一一解析了,无论是哪个库,开发者大多数参数都会给予注释,如果开发者不给予注释的话,和前面所提的一样,建议换一个仓库、
三、数据加载
数据加载分为两部分,一部分是训练的数据加载,另一部分是预测的数据加载。
1、训练部分
训练的数据加载其实是非常重要的,直接关系到模型的训练,监督模型在训练时加载的数据一般分为两部分,一部分是输入变量,通常是图片;另一部分是标签,在目标检测中就是图片对应的框的坐标,在语义分割中就是每个像素点的种类。
一般来讲,数据加载部分,模型本身在仓库中的名字是data、datasets或者dataloader。
以YoloV5仓库为例,https://github.com/ultralytics/yolov5,网络结构的部分就存储在datasets文件中。初次看yolov5的库,可能会以为数据加载部分的内容在data文件夹中,点进去会发现,其实data都是数据集下载相关的内容,这种判断错误是正常的,毕竟这属于相似概念了。
实际的数据集加载文件,datasets文件在utils文件夹中。
简单翻一下datasets,可以知道该文件通过create_dataloader函数构建文件加载器,然后通过LoadImagesAndLabels这个文件加载器的类来获取图片与标签文件。
然后在这个LoadImagesAndLabels中,算法会进行数据增强、数据预处理等操作,最终返回输入图片与标签。下图为__getitem__方法返回的图片与标签。
接下来以我提供的YoloV4仓库为例,https://github.com/bubbliiiing/yolov4-pytorch,我们比较容易的可以寻找到dataloader.py文件,也在utils文件夹中。
在YoloDataset的__getitem__方法中返回图片与标签。
2、预测部分
预测的数据加载和训练的数据加载相比,少了数据增强与标签处理的部分,因此会相对简单一些,主要是对输入图片进行预处理。
既然是预测部分的数据预处理,我们需要从预测文件开始寻找。
以YoloV5仓库为例,https://github.com/ultralytics/yolov5,在detect.py预测文件中可以发现,YoloV5通过文件加载器的方式获得预测的图片文件,在文件加载器中,我们会对图片文件进行预处理,比如resize到一定的大小,进行图片的归一化等。
接下来以我提供的YoloV4仓库为例,https://github.com/bubbliiiing/yolov4-pytorch,从predict.py预测文件开始进行分析,我们调用了YOLO类的detect_image方法。也就很容易发现,预处理的过程在yolo.py的detect_image方法中。
四、损失函数
一般来讲,损失函数在仓库中的名字是Loss(损失),Loss函数是模型优化的目标,在训练过程中Loss理论上是要被越优化越小的。
以YoloV5仓库为例,https://github.com/ultralytics/yolov5,我们在库中寻找一些损失函数,很容易可以发现,loss.py在utils文件夹中
结合train.py调用的函数来看,可以很容易发现,yolov5计算损失时,调用的是ComputeLoss类,进一步定位Loss的计算。
Loss组成的话,每个仓库有每个仓库不同的组成方式,因此解析的难度是非常大的,特别是在目标检测中,正样本的选取方式多样,很难直接对Loss有个整体的认知,想要进一步了解Loss的工作,通常要对损失进行一行、一行的分析。
接下来以我提供的YoloV4仓库为例,https://github.com/bubbliiiing/yolov4-pytorch,该库并没有直接提供名为Loss.py的损失函数文件,结合train.py的执行流程进行分析,可以比较容易的知道,在yolo_training.py文件夹中,定义了一个名为YOLOLoss的类。
进入yolo_training.py文件夹,这个命名为YOLOLoss的类,就是用于进行损失计算的了。
五、预测后处理
预测的后处理主要包括了预测结果的解码与预测图片的可视化。
既然是预测部分的后处理,我们需要从预测文件开始寻找。
以YoloV5仓库为例,https://github.com/ultralytics/yolov5,在detect.py预测文件中可以发现,YoloV5在获得预测结果之后,进行了非极大抑制,然后进行了图片的绘制与可视化。
接下来以我提供的YoloV4仓库为例,https://github.com/bubbliiiing/yolov4-pytorch,从predict.py预测文件开始进行分析,我们调用了YOLO类的detect_image方法。也就很容易发现,后处理的过程在yolo.py的detect_image方法中。在完成图片的预测后,我们对预测结果进行解码与非极大抑制,最后进行绘图。
文章出处登录后可见!