yolov4-keras使用多GPU进行训练

首先附上参考链接:
https://www.freesion.com/article/9785581741/
修改train.py文件,代码如下:

import keras
import keras.backend as K
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from keras.optimizers import Adam
import os
from nets.yolo import get_train_model, yolo_body
from utils.callbacks import (ExponentDecayScheduler, LossHistory,
                             WarmUpCosineDecayScheduler)
from utils.dataloader import YoloDatasets
from utils.utils import get_anchors, get_classes
import datetime
from keras.utils import multi_gpu_model

time_begin = datetime.datetime.now()
print('time_begin:', time_begin)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'


class ParallelModelCheckpoint(ModelCheckpoint):
    def __init__(self,model,filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        self.single_model = model
        super(ParallelModelCheckpoint,self).__init__(filepath, monitor, verbose,save_best_only, save_weights_only,mode, period)

    def set_model(self, model):
        super(ParallelModelCheckpoint,self).set_model(self.single_model)


def totalloss(y_true, y_pred):
    return K.sum(y_pred)/K.cast(K.shape(y_pred)[0],K.dtype(y_pred))


if __name__ == "__main__":

    #--------------------------------------------------------#
    #   训练前一定要修改classes_path,使其对应自己的数据集
    #--------------------------------------------------------#
    classes_path    = 'model_data/my_classes.txt'
    #---------------------------------------------------------------------#
    #   anchors_path代表先验框对应的txt文件,一般不修改。
    #   anchors_mask用于帮助代码找到对应的先验框,一般不修改。
    #---------------------------------------------------------------------#
    anchors_path    = 'model_data/yolo_anchors.txt'
    anchors_mask    = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
    #---------------------------------------------------------------------
    model_path      = 'model_data/yolo4_weight.h5'
    # model_path      = 'logs/ep116-loss11.549-val_loss11.641.h5'
    #------------------------------------------------------#
    #   输入的shape大小,一定要是32的倍数
    #------------------------------------------------------#
    input_shape     = [416, 416]
    #------------------------------------------------------#
    #   Yolov4的tricks应用
    #   mosaic 马赛克数据增强 True or False 
    #   实际测试时mosaic数据增强并不稳定,所以默认为False
    #   Cosine_scheduler 余弦退火学习率 True or False
    #   label_smoothing 标签平滑 0.01以下一般 如0.01、0.005
    #------------------------------------------------------#
    mosaic              = False
    Cosine_scheduler    = False
    label_smoothing     = 0

    #----------------------------------------------------#
    #   训练分为两个阶段,分别是冻结阶段和解冻阶段。
    #   显存不足与数据集大小无关,提示显存不足请调小batch_size。
    #   受到BatchNorm层影响,batch_size最小为2,不能为1。
    #----------------------------------------------------#
    #----------------------------------------------------#
    #   冻结阶段训练参数
    #   此时模型的主干被冻结了,特征提取网络不发生改变
    #   占用的显存较小,仅对网络进行微调
    #----------------------------------------------------#
    Init_Epoch          = 0
    Freeze_Epoch        = 100
    Freeze_batch_size   = 16
    Freeze_lr           = 1e-3
    #----------------------------------------------------#
    #   解冻阶段训练参数
    #   此时模型的主干不被冻结了,特征提取网络会发生改变
    #   占用的显存较大,网络所有的参数都会发生改变
    #----------------------------------------------------#
    UnFreeze_Epoch      = 200
    Unfreeze_batch_size = 8
    Unfreeze_lr         = 1e-4
    #------------------------------------------------------#
    #   是否进行冻结训练,默认先冻结主干训练后解冻训练。
    #------------------------------------------------------#
    Freeze_Train        = True
    #------------------------------------------------------#
    #   用于设置是否使用多线程读取数据,1代表关闭多线程
    #   开启后会加快数据读取速度,但是会占用更多内存
    #   keras里开启多线程有些时候速度反而慢了许多
    #   在IO为瓶颈的时候再开启多线程,即GPU运算速度远大于读取图片的速度。
    #------------------------------------------------------#
    num_workers         = 1
    #----------------------------------------------------#
    #   获得图片路径和标签
    #----------------------------------------------------#
    train_annotation_path   = '2007_train.txt'
    val_annotation_path     = '2007_val.txt'

    #----------------------------------------------------#
    #   获取classes和anchor
    #----------------------------------------------------#
    class_names, num_classes = get_classes(classes_path)
    anchors, num_anchors     = get_anchors(anchors_path)

    K.clear_session()
    #------------------------------------------------------#
    #   创建yolo模型
    #------------------------------------------------------#
    model_body  = yolo_body((None, None, 3), anchors_mask, num_classes)
    if model_path != '':
        #------------------------------------------------------#
        #   载入预训练权重
        #------------------------------------------------------#
        print('Load weights {}.'.format(model_path))
        model_body.load_weights(model_path, by_name=True, skip_mismatch=True)

    model = get_train_model(model_body, input_shape, num_classes, anchors, anchors_mask, label_smoothing)

    model_parallel = multi_gpu_model(model, gpus=2)
    #-------------------------------------------------------------------------------#
    #   训练参数的设置
    #   logging表示tensorboard的保存地址
    #   checkpoint用于设置权值保存的细节,period用于修改多少epoch保存一次
    #   reduce_lr用于设置学习率下降的方式
    #   early_stopping用于设定早停,val_loss多次不下降自动结束训练,表示模型基本收敛
    #-------------------------------------------------------------------------------#
    logging         = TensorBoard(log_dir = 'logs/')
    checkpoint = ParallelModelCheckpoint(model, filepath='logs/ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5', monitor='val_loss',save_weights_only=True, verbose=1,
                                         save_best_only=True)  # 解决多GPU运行下保存模型报错的问题
    # checkpoint      = ModelCheckpoint('logs/ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',
    #                         monitor = 'val_loss', save_weights_only = True, save_best_only = False, period = 1)
    if Cosine_scheduler:
        reduce_lr   = WarmUpCosineDecayScheduler(T_max = 5, eta_min = 1e-5, verbose = 1)
    else:
        reduce_lr   = ExponentDecayScheduler(decay_rate = 0.94, verbose = 1)
    early_stopping  = EarlyStopping(monitor='val_loss', min_delta = 0, patience = 10, verbose = 1)
    loss_history    = LossHistory('logs/')

    #---------------------------#
    #   读取数据集对应的txt
    #---------------------------#
    with open(train_annotation_path) as f:
        train_lines = f.readlines()
    with open(val_annotation_path) as f:
        val_lines   = f.readlines()
    num_train   = len(train_lines)
    num_val     = len(val_lines)

    if Freeze_Train:
        freeze_layers = 249
        for i in range(freeze_layers): model_body.layers[i].trainable = False
        print('Freeze the first {} layers of total {} layers.'.format(freeze_layers, len(model_body.layers)))
        
    #------------------------------------------------------#
    #   主干特征提取网络特征通用,冻结训练可以加快训练速度
    #   也可以在训练初期防止权值被破坏。
    #   Init_Epoch为起始世代
    #   Freeze_Epoch为冻结训练的世代
    #   UnFreeze_Epoch总训练世代
    #   提示OOM或者显存不足请调小Batch_size
    #------------------------------------------------------#
    if True:
        batch_size  = Freeze_batch_size
        lr          = Freeze_lr
        start_epoch = Init_Epoch
        end_epoch   = Freeze_Epoch

        epoch_step          = num_train // batch_size
        epoch_step_val      = num_val // batch_size

        if epoch_step == 0 or epoch_step_val == 0:
            raise ValueError('数据集过小,无法进行训练,请扩充数据集。')
        
        model_parallel.compile(optimizer=Adam(lr = lr), loss=totalloss)

        train_dataloader    = YoloDatasets(train_lines, input_shape, anchors, batch_size, num_classes, anchors_mask, mosaic = mosaic, train = True)
        val_dataloader      = YoloDatasets(val_lines, input_shape, anchors, batch_size, num_classes, anchors_mask, mosaic = False, train = False)

        print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
        print('****************************')
        model_parallel.fit_generator(
            generator           = train_dataloader,
            steps_per_epoch     = epoch_step,
            validation_data     = val_dataloader,
            validation_steps    = epoch_step_val,
            epochs              = end_epoch,
            initial_epoch       = start_epoch,
            use_multiprocessing = True if num_workers > 1 else False,
            workers             = num_workers,
            callbacks           = [logging, checkpoint, reduce_lr, early_stopping, loss_history]
        )

    if Freeze_Train:
        for i in range(freeze_layers): model_body.layers[i].trainable = True

    if True:
        batch_size  = Unfreeze_batch_size
        lr          = Unfreeze_lr
        start_epoch = Freeze_Epoch
        end_epoch   = UnFreeze_Epoch

        epoch_step          = num_train // batch_size
        epoch_step_val      = num_val // batch_size

        if epoch_step == 0 or epoch_step_val == 0:
            raise ValueError('数据集过小,无法进行训练,请扩充数据集。')
        
        model_parallel.compile(optimizer=Adam(lr = lr), loss=totalloss)

        train_dataloader    = YoloDatasets(train_lines, input_shape, anchors, batch_size, num_classes, anchors_mask, mosaic = mosaic, train = True)
        val_dataloader      = YoloDatasets(val_lines, input_shape, anchors, batch_size, num_classes, anchors_mask, mosaic = False, train = False)

        print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
        print('##################################')
        model_parallel.fit_generator(
            generator           = train_dataloader,
            steps_per_epoch     = epoch_step,
            validation_data     = val_dataloader,
            validation_steps    = epoch_step_val,
            epochs              = end_epoch,
            initial_epoch       = start_epoch,
            use_multiprocessing = True if num_workers > 1 else False,
            workers             = num_workers,
            callbacks           = [logging, checkpoint, reduce_lr, early_stopping, loss_history]
        )
    time_end = datetime.datetime.now()
    print('time_consuming:', time_end - time_begin)

主要变化:
1、from keras.utils import multi_gpu_model,在model形成后,加入:

model_parallel = multi_gpu_model(model, gpus=2)

2、使用ParallelModelCheckpoint(ModelCheckpoint)类;

checkpoint = ParallelModelCheckpoint(model, filepath='logs/ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5', monitor='val_loss',save_weights_only=True, verbose=1,
                                         save_best_only=True)  # 解决多GPU运行下保存模型报错的问题

3、定义一个totalloss,因为改变了loss的输出格式;

def totalloss(y_true, y_pred):
    return K.sum(y_pred)/K.cast(K.shape(y_pred)[0],K.dtype(y_pred))
model_parallel.compile(optimizer=Adam(lr = lr), loss=totalloss)

4、为什么有3的变化,因为改变了nets/yolo.py里面的get_train_model,去掉了output_shape = (1, ), :

def get_train_model(model_body, input_shape, num_classes, anchors, anchors_mask, label_smoothing):
    y_true = [Input(shape = (input_shape[0] // {0:32, 1:16, 2:8}[l], input_shape[1] // {0:32, 1:16, 2:8}[l], \
                                len(anchors_mask[l]), num_classes + 5)) for l in range(len(anchors_mask))]
    model_loss  = Lambda(
        yolo_loss,  
        name            = 'yolo_loss', 
        arguments       = {'input_shape' : input_shape, 'anchors' : anchors, 'anchors_mask' : anchors_mask, 
                           'num_classes' : num_classes, 'label_smoothing' : label_smoothing}
    )([*model_body.output, *y_true])
    model       = Model([model_body.input, *y_true], model_loss)
    return model

5、最后一个改变loss的计算那个地方nets/yolo_training.py:

# location_loss   = K.sum(ciou_loss)
location_loss = K.sum(K.sum(ciou_loss,axis=[2,3,4]),1,keepdims=True)
# confidence_loss = K.sum(confidence_loss)
confidence_loss = K.sum(K.sum(confidence_loss,axis=[2,3,4]),1,keepdims=True)
# class_loss      = K.sum(class_loss)
class_loss = K.sum(K.sum(class_loss,axis=[2,3,4]),1,keepdims=True)
        #-----------------------------------------------------------#
        #   计算正样本数量
        #-----------------------------------------------------------#
num_pos += tf.maximum(K.sum(K.cast(object_mask, tf.float32)), 1)
loss    += location_loss + confidence_loss + class_loss

应该没有漏掉的地方,如果报错,希望指出!

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2022年5月26日
下一篇 2022年5月26日

相关推荐