MMDetection: train.py Source Code Explained

1. tools/train.py

Optional arguments:

# =========== optional arguments ===========
# --work-dir        directory to store logs and models
# --resume-from     the checkpoint file to resume from
# --no-validate     whether to skip evaluation during training
# mutually exclusive group:
#   --gpus          number of GPUs to use
#   --gpu-ids       ids of the GPUs to use
# --seed            random seed
# --deterministic   whether to make the cuDNN backend deterministic
# --options         extra config options
# --launcher        job launcher for distributed training, one of ['none', 'pytorch', 'slurm', 'mpi'];
#                   'none' disables distributed training; dist_train.sh launches with 'pytorch' by default.
# --local_rank      local process rank, passed in automatically by torch.distributed.launch.

The main flow of tools/train.py is as follows:

train.py first reads the configuration from the command line and the config file, then builds the model with build_detector and the dataset with build_dataset, and finally passes both into train_detector for training.

(1) Read the configuration from the command line and the config file

cfg = Config.fromfile(args.config)  

(2) Build the model

# Build the model: requires cfg.model, cfg.train_cfg and cfg.test_cfg
model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

(3) Build the dataset

# Build the dataset: pass cfg.data.train, i.e. the training-set config
datasets = [build_dataset(cfg.data.train)]

(4) Train the model

# Train the detector: pass in the model, the datasets, the config, etc.
train_detector(
    model,
    datasets,
    cfg,
    distributed=distributed,
    validate=(not args.no_validate),
    timestamp=timestamp,
    meta=meta)

2. Source code walkthrough

import argparse
import copy
import os
import os.path as osp
import time

import mmcv
import torch
# Config reads the config file; DictAction converts dict-style command-line arguments into key-value pairs
from mmcv import Config, DictAction
from mmcv.runner import init_dist

from mmdet import __version__
from mmdet.apis import set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger


# python tools/train.py ${CONFIG_FILE} [optional arguments]

# =========== optional arguments ===========
# --work-dir        directory to store logs and models
# --resume-from     the checkpoint file to resume from
# --no-validate     whether to skip evaluation during training
# mutually exclusive group:
#   --gpus          number of GPUs to use
#   --gpu-ids       ids of the GPUs to use
# --seed            random seed
# --deterministic   whether to make the cuDNN backend deterministic
# --options         extra config options
# --launcher        job launcher for distributed training, one of ['none', 'pytorch', 'slurm', 'mpi'];
#                   'none' disables distributed training; dist_train.sh launches with 'pytorch' by default.
# --local_rank      local process rank, passed in automatically by torch.distributed.launch.

def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')

    # action: store (the default; simply stores the value)
    # action: store_true / store_false (the value becomes True / False when the flag is given)
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')

    # --------- Create a mutually exclusive group: argparse ensures at most one argument of the group is given ---------
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')

    # e.g. python train.py --gpu-ids 0 1 2 3 selects the GPUs to use by id
    # resulting value: [0, 1, 2, 3]
    # nargs = '*': zero or more values
    # nargs = '+': one or more values
    # nargs = '?': zero or one value
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    # ------------------------------------------------------------------------

    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')

    # Extra options: e.g. --options a=1,2,3 passes additional config overrides
    # resulting value: {'a': [1, 2, 3]}
    parser.add_argument(
        '--options', nargs='+', action=DictAction, help='arguments in dict')

    # When dist_train.sh is used for distributed training, launcher defaults to 'pytorch'
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')

    # Local process rank, passed in automatically by torch.distributed.launch.
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    # If LOCAL_RANK is not set in the environment, set it to the current local_rank
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args


def main():
    args = parse_args()

    # Read the config from file
    cfg = Config.fromfile(args.config)
    # Merge extra options passed on the command line
    if args.options is not None:
        cfg.merge_from_dict(args.options)

    # Setting cudnn_benchmark = True speeds up models with fixed input sizes, e.g. SSD300
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # Priority of work_dir: command line > config file
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    # If work_dir is None, use ./work_dirs/<config file name> as the default working directory
    elif cfg.get('work_dir', None) is None:
        # os.path.basename(path)    returns the file name
        # os.path.splitext(path)    splits the path into a (root, extension) tuple
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    # Whether to resume a previous training run
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    # gpu id
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # If launcher is 'none', distributed training is disabled (the default when dist_train.sh is not used).
    if args.launcher == 'none':
        distributed = False
    # Otherwise enable distributed training; dist_train.sh passes 'pytorch'.
    else:
        distributed = True
        # init_dist calls init_process_group internally
        init_dist(args.launcher, **cfg.dist_params)

    # Create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # Dump the config into work_dir
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    # eg: 20200726_105413
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    # Get the root logger.
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # Set the random seed
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed

    # Build the model: requires cfg.model, cfg.train_cfg and cfg.test_cfg
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    # Build the dataset: pass cfg.data.train
    datasets = [build_dataset(cfg.data.train)]
    # workflow describes the running flow:
    # [('train', 2), ('val', 1)] means: train for 2 epochs, then validate for 1 epoch
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__,
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES

    # Train the detector: pass in the model, the datasets, the config, etc.
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()

3. Core functions explained

train.py mainly relies on three functions: build_detector (build the model), build_dataset (build the dataset), and train_detector (train the model).

(1) build_detector (mmdet/models/builder.py)

build_detector receives the model, train_cfg and test_cfg entries of the config file as its arguments.

The config file faster_rcnn_r50_fpn_1x_coco.py is used as an example below:

model

model = dict(
    type='FasterRCNN',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))))

train_cfg

train_cfg = dict(
    rpn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            match_low_quality=True,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    rpn_proposal=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            match_low_quality=False,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        pos_weight=-1,
        debug=False))

test_cfg

test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=1000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=100)
    # soft-nms is also supported for rcnn testing
    # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
)

At runtime, the three dicts above are passed into build_detector. build_detector calls build, which in turn calls build_from_cfg to construct the detector object; train_cfg and test_cfg are passed along as default arguments when the detector is instantiated. A short usage sketch follows the two builder functions below.

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        # Call build_from_cfg to build the registered object described by the config dict
        return build_from_cfg(cfg, registry, default_args)


def build_detector(cfg, train_cfg=None, test_cfg=None):
    # Call build with cfg and the DETECTORS registry,
    # passing train_cfg and test_cfg as the default-args dict
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
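
For example, with a standard MMDetection 2.x checkout, the Faster R-CNN config above can be turned into a model as in the sketch below (a sketch only; the config path is the usual one in the repo):

from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile('configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')
# cfg.model['type'] is 'FasterRCNN', so build_from_cfg looks it up in the
# DETECTORS registry and instantiates it with train_cfg / test_cfg merged in.
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
print(type(model).__name__)  # FasterRCNN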

build_from_cfg is defined in mmcv/utils/registry.py. The value of the 'type' key in the cfg dict specifies the type of object to create; build_from_cfg looks that class up among the classes registered in the Registry and instantiates it with the default arguments merged in.

def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise KeyError(
            f'the cfg dict must contain the key "type", but got {cfg}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()
    # Get the value of the 'type' key
    obj_type = args.pop('type')
    if is_str(obj_type):
        # Look up the class to instantiate
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    # If default_args is not None, merge in the defaults before instantiating.
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)

So what is a registry? A registry associates a string with a class: indexing the registry with the string returns the class. The Registry class implements this bookkeeping, and classes can be registered with it as shown in the sketch below.
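
A minimal sketch of that registration workflow (DETECTORS_DEMO and ToyDetector are made-up names for illustration, not part of MMDetection):

from mmcv.utils import Registry, build_from_cfg

# A toy registry; MMDetection defines real ones such as DETECTORS, BACKBONES, DATASETS.
DETECTORS_DEMO = Registry('detector_demo')

@DETECTORS_DEMO.register_module()      # adds 'ToyDetector' -> ToyDetector to _module_dict
class ToyDetector:
    def __init__(self, num_classes, train_cfg=None, test_cfg=None):
        self.num_classes = num_classes
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

cfg = dict(type='ToyDetector', num_classes=80)
# default_args plays the same role as train_cfg / test_cfg in build_detector
model = build_from_cfg(cfg, DETECTORS_DEMO, default_args=dict(train_cfg=dict(), test_cfg=dict()))
print(type(model).__name__, model.num_classes)  # ToyDetector 80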

Below is the code of the Registry class. Internally it maintains _module_dict, a dict of the classes registered so far; every registration adds a mapping from a string (the class name by default) to the class. The register_module method, typically used as a decorator, adds the name and the class to _module_dict. Registered modules can then be built with build_from_cfg.

import inspect
import warnings
from functools import partial

from .misc import is_str


class Registry:
    """A registry to map strings to classes.

    Args:
        name (str): Registry name.
    """

    def __init__(self, name):
        self._name = name
        # dict of the registered classes
        self._module_dict = dict()

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class.
        """
        return self._module_dict.get(key, None)

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {self.name}')
        self._module_dict[module_name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a walkaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {type(name)}')

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register

(2) build_dataset (mmdet/datasets/builder.py)

build_dataset works in a similar way: for a plain dataset config it calls build_from_cfg, with extra branches for concatenated and wrapper datasets. A config sketch is given after the function.

def build_dataset(cfg, default_args=None):
    from .dataset_wrappers import (ConcatDataset, RepeatDataset,
                                   ClassBalancedDataset)
    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    elif cfg['type'] == 'ClassBalancedDataset':
        dataset = ClassBalancedDataset(
            build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
    elif isinstance(cfg.get('ann_file'), (list, tuple)):
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset
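
As an illustration, a wrapped dataset config of the kind this function expects might look like the sketch below (the file paths and the minimal pipeline are placeholders, not taken from a real config):

from mmdet.datasets import build_dataset

train_data_cfg = dict(
    type='RepeatDataset',        # handled by the 'RepeatDataset' branch above
    times=3,                     # the wrapped dataset is repeated 3 times per epoch
    dataset=dict(
        type='CocoDataset',
        ann_file='data/coco/annotations/instances_train2017.json',
        img_prefix='data/coco/train2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True),
        ]))

# datasets = [build_dataset(train_data_cfg)]  # requires the annotation file to exist on disk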

(3) train_detector (mmdet/apis/train.py)

The main steps of train_detector are:

(1.) Build the data loaders:

data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]

(2.) Wrap the model for (distributed) parallel training:

model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)

(3.) Build the optimizer:

optimizer = build_optimizer(model, cfg.optimizer)

(4.) Create an EpochBasedRunner and start training:

runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)

The full source is as follows:

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    # Get the logger
    logger = get_root_logger(cfg.log_level)

    # ==================== Build the data loaders ====================
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    # Resolve samples_per_gpu (handle the deprecated imgs_per_gpu)
    if 'imgs_per_gpu' in cfg.data:
        logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiments')
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{cfg.data.imgs_per_gpu} in this experiments')
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]


    # ==================== Wrap the model for parallel training =====================
    # Multi-GPU (distributed) training enters this branch
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    # Single-GPU (non-distributed) training enters this branch
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)


    # ====================== Build the optimizer ==========================
    optimizer = build_optimizer(model, cfg.optimizer)

    # ============= Create the EpochBasedRunner and train ==============
    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

(4) set_random_seed

This function seeds Python's random module, NumPy, and PyTorch.

With the same seed, convolution results are reproducible on the CPU but can still differ on the GPU, because cuDNN's convolution algorithms are non-deterministic. Setting torch.backends.cudnn.deterministic = True fixes this.

cuDNN can be disabled entirely with torch.backends.cudnn.enabled = False. When it is enabled, setting torch.backends.cudnn.benchmark = True lets cuDNN benchmark the available algorithms and automatically pick the fastest one for the current configuration.

In general, the following guidelines apply (a small seeding sketch follows the function below):

  1. If the input sizes and dtypes of the network vary little, setting torch.backends.cudnn.benchmark = True improves runtime efficiency.
  2. If the input changes at every iteration, cuDNN searches for the best algorithm over and over again, which actually hurts performance; set torch.backends.cudnn.benchmark = False to avoid the repeated search.

def set_random_seed(seed, deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every GPU.
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
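
A small toy check (not part of the repo) of what the seed plus the deterministic flag buy you:

import torch
from mmdet.apis import set_random_seed

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def conv_checksum():
    # new random input and a newly initialized conv on every call
    x = torch.rand(1, 3, 32, 32, device=device)
    conv = torch.nn.Conv2d(3, 8, 3).to(device)
    return float(conv(x).sum())

set_random_seed(0, deterministic=True)
a = conv_checksum()
set_random_seed(0, deterministic=True)
b = conv_checksum()
print(a == b)  # True: same seed + deterministic cuDNN -> identical results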

(5) get_root_logger

get_root_logger calls get_logger to obtain the logger object.

import logging

from mmcv.utils import get_logger


def get_root_logger(log_file=None, log_level=logging.INFO):
    logger = get_logger(name='mmdet', log_file=log_file, log_level=log_level)

    return logger

The get_logger implementation is quite flexible. Calling it again with the same name returns the already-configured logger, and a dot-separated child name (a submodule) also returns a logger that shares the same configuration: for example, 'a' and 'a.b' behave the same. If log_file is passed, the log output is also written to that path; otherwise nothing is saved and logs only go to the console. A short usage sketch follows the code below.

import logging

import torch.distributed as dist

# Records whether a logger with the given name has been created (set to True once created)
logger_initialized = {}


def get_logger(name, log_file=None, log_level=logging.INFO):
    # Get the logger object.
    logger = logging.getLogger(name)
    # If this name has already been initialized, return the logger directly
    if name in logger_initialized:
        return logger
    # If the name is a '.'-separated child of an initialized logger, also return it directly
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    # Get the current rank (global process index)
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # Only rank 0 (the process with local_rank 0 on the master node) writes the log file
    if rank == 0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)
    # Non-rank-0 processes only show messages at ERROR level and above
    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)
    # Mark this logger name as initialized.
    logger_initialized[name] = True

    return logger
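
A short usage sketch of the behaviour described above (the logger names and the demo.log file name are just examples):

import logging
from mmcv.utils import get_logger

# The first call configures the 'mmdet' logger: console output plus demo.log (rank 0 only).
logger = get_logger('mmdet', log_file='demo.log', log_level=logging.INFO)
# A dot-separated child name reuses the parent's configuration; no new handlers are added.
child = get_logger('mmdet.models')

logger.info('message from the root logger')  # console + demo.log
child.info('message from a submodule')       # propagates to the same handlers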
