1、定制优化器
假设你想添加一个名为MyOptimizer的优化器,它有参数a, b和c。你需要首先在一个文件中实现新的优化器,例如在mmseg/core/optimizer/my_optimizer.py中:
from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer
@OPTIMIZERS.register_module()
class MyOptimizer(Optimizer):
    """Template for a custom optimizer registered in ``OPTIMIZERS``.

    Once registered, the optimizer can be selected from a config file with
    ``optimizer = dict(type='MyOptimizer', a=..., b=..., c=...)``.

    Args:
        a: First optimizer-specific hyper-parameter (placeholder).
        b: Second optimizer-specific hyper-parameter (placeholder).
        c: Third optimizer-specific hyper-parameter (placeholder).
    """

    def __init__(self, a, b, c):
        # Placeholder body: implement the optimizer here.
        # NOTE(review): a concrete torch.optim.Optimizer subclass normally
        # also receives the parameters to optimize and forwards them to
        # ``super().__init__(params, defaults)`` -- confirm when filling in.
        pass
然后将此模块添加到mmseg/core/optimizer/__init__.py中,这样注册表就会找到新模块并添加它:
from .my_optimizer import MyOptimizer
然后您可以在配置文件的optimizer字段中使用MyOptimizer。在配置中,优化器由optimizer字段定义,如下所示:
# Default optimizer: SGD with momentum and weight decay.
optimizer = {
    'type': 'SGD',
    'lr': 0.02,
    'momentum': 0.9,
    'weight_decay': 0.0001,
}
要使用您自己的优化器,可以将字段更改为
# Select the custom optimizer; a_value / b_value / c_value are placeholders
# for the values passed to MyOptimizer's ``a``, ``b`` and ``c`` arguments.
optimizer = {
    'type': 'MyOptimizer',
    'a': a_value,
    'b': b_value,
    'c': c_value,
}
我们已经支持使用PyTorch实现的所有优化器,唯一需要的修改就是更改配置文件的optimizer字段。例如,如果您想使用Adam(尽管性能可能会下降很多),可以修改如下:
# Switch to PyTorch's Adam optimizer.
optimizer = {
    'type': 'Adam',
    'lr': 0.0003,
    'weight_decay': 0.0001,
}
2、定制优化器的构造函数
一些模型可能有一些参数特定的设置来进行优化,例如BatchNorm层的权重衰减。用户可以通过定制优化器构造函数来进行这些细粒度的参数调优。
from mmcv.utils import build_from_cfg
from mmcv.runner import OPTIMIZER_BUILDERS
from .cocktail_optimizer import CocktailOptimizer
@OPTIMIZER_BUILDERS.register_module()
class CocktailOptimizerConstructor(object):
    """Template for a custom optimizer constructor.

    An optimizer constructor is called with the model and returns the
    optimizer, which lets it apply parameter-specific settings.

    Args:
        optimizer_cfg: Config of the optimizer to build.
        paramwise_cfg: Optional config holding the parameter-wise settings.
            Default: None.
    """

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        # Keep both configs so that ``__call__`` can use them.
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = paramwise_cfg

    def __call__(self, model):
        """Build and return the optimizer for ``model``."""
        # TODO: construct the optimizer from ``model`` using the stored
        # configs and bind it to ``my_optimizer`` before returning it.
        return my_optimizer
3、开发新的组件
MMSegmentation主要有两类组件:
backbone:通常是卷积网络的堆叠来提取特征图,例如ResNet, HRNet. head:语义分割图解码组件。
1、添加新的骨干
这里我们以MobileNet为例展示如何开发新组件。
1. 创建一个新文件mmseg/models/backbones/mobilenet.py
import torch.nn as nn
from ..registry import BACKBONES
@BACKBONES.register_module()
class MobileNet(nn.Module):
    """Skeleton of a new backbone registered in ``BACKBONES``.

    Args:
        arg1: First backbone-specific argument (placeholder).
        arg2: Second backbone-specific argument (placeholder).
    """

    def __init__(self, arg1, arg2):
        # nn.Module must be initialised before any sub-module is assigned.
        super().__init__()

    def forward(self, x):  # should return a tuple
        """Placeholder: compute the feature maps for ``x``."""
        pass

    def init_weights(self, pretrained=None):
        """Placeholder: initialise the weights, optionally from ``pretrained``."""
        pass
2、在mmseg/models/backbones/__init__.py中导入该模块。
from .mobilenet import MobileNet
3、在配置文件中使用它。
# Use the new backbone in the model config; ``xxx`` stands for the values
# passed to MobileNet's ``arg1`` / ``arg2``.
model = dict(
    # ... other model settings ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    # ... other model settings ...
)
2、添加新头
在MMSegmentation中,我们为所有的分割头提供了一个BaseDecodeHead。
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, auto_fp16, force_fp32
from mmseg.core import build_pixel_sampler
from mmseg.ops import resize
from ..builder import build_loss
from ..losses import accuracy
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
    """Base class for all decode heads.

    Subclasses must implement :meth:`forward`; training, testing and loss
    computation are provided here on top of it.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one selected feature map is allowed.
            Default: None.
        loss_decode (dict | Sequence[dict]): Config of decode loss.
            The `loss_name` is property of corresponding loss function which
            could be shown in training log. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ce'.
             e.g. dict(type='CrossEntropyLoss'),
             [dict(type='CrossEntropyLoss', loss_name='loss_ce'),
              dict(type='DiceLoss', loss_name='loss_dice')]
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255.
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 ignore_index=255,
                 sampler=None,
                 align_corners=False,
                 init_cfg=dict(
                     type='Normal', std=0.01, override=dict(name='conv_seg'))):
        super(BaseDecodeHead, self).__init__(init_cfg)
        # Validates the input spec and sets in_channels / in_index /
        # input_transform on the instance.
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index
        self.ignore_index = ignore_index
        self.align_corners = align_corners

        # A single config builds one loss; a sequence builds a ModuleList.
        if isinstance(loss_decode, dict):
            self.loss_decode = build_loss(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(build_loss(loss))
        else:
            # Adjacent literals instead of a backslash continuation inside
            # the string, so the message keeps a single separating space.
            raise TypeError('loss_decode must be a dict or sequence of dict, '
                            f'but got {type(loss_decode)}')

        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        # 1x1 convolution mapping the head features to per-class logits.
        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        self.fp16_enabled = False

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only single feature map
        will be selected. So in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        list or tuple, with the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one selected feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                # Concatenation along the channel dim sums the channels.
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list[Tensor]: The transformed inputs. A list is returned
                when ``input_transform`` is 'multiple_select', otherwise a
                single tensor.
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            # Resize every selected map to the size of the first one.
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @auto_fp16()
    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs)

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    @force_fp32(apply_to=('seg_logit', ))
    def losses(self, seg_logit, seg_label):
        """Compute segmentation loss.

        The logits are resized to ``seg_label.shape[2:]``, every configured
        decode loss is evaluated and stored under its ``loss_name`` (losses
        sharing a name are summed), and the accuracy is added as 'acc_seg'.

        Args:
            seg_logit (Tensor): Output of the head's forward function.
            seg_label (Tensor): Ground-truth label; its dim 1 is squeezed
                before the losses are computed.

        Returns:
            dict[str, Tensor]: The loss components plus 'acc_seg'.
        """
        loss = dict()
        seg_logit = resize(
            input=seg_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logit, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)

        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            loss_value = loss_decode(
                seg_logit,
                seg_label,
                weight=seg_weight,
                ignore_index=self.ignore_index)
            # Losses that share a name are accumulated into one entry.
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_value
            else:
                loss[loss_decode.loss_name] += loss_value

        loss['acc_seg'] = accuracy(
            seg_logit, seg_label, ignore_index=self.ignore_index)
        return loss
所有新实现的解码头都应该从它派生。下面我们将使用PSPNet示例演示如何开发一个新的头部,如下所示。
首先,在mmseg/models/decode_heads/psp_head.py中添加一个新的解码头。PSPNet实现了一个用于分割解码的解码头。为了实现一个解码头,我们基本上需要实现新模块的以下三个函数。
@HEADS.register_module()
class PSPHead(BaseDecodeHead):
    """Skeleton of the PSPNet decode head.

    Args:
        pool_scales (tuple[int]): Pooling scales of the head.
            Default: (1, 2, 3, 6).
        **kwargs: Forwarded unchanged to ``BaseDecodeHead.__init__``.
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)
        # Keep the scales so the head's modules can be built from them.
        self.pool_scales = pool_scales

    def init_weights(self):
        """Placeholder: initialise the weights of the head."""
        pass

    def forward(self, inputs):
        """Placeholder: compute the segmentation logits from ``inputs``."""
        pass
接下来,用户需要在mmseg/models/decode_heads/__init__.py中添加模块,这样相应的注册表就可以找到并加载它们。PSPNet的配置文件如下:
# Normalisation config shared by the backbone and the decode head.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}

# PSPNet: ResNetV1c-50 backbone followed by a PSP decode head.
model = {
    'type': 'EncoderDecoder',
    'pretrained': 'pretrain_model/resnet50_v1c_trick-2cccc1ad.pth',
    'backbone': {
        'type': 'ResNetV1c',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        'dilations': (1, 1, 2, 4),
        'strides': (1, 2, 1, 1),
        'norm_cfg': norm_cfg,
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    },
    'decode_head': {
        'type': 'PSPHead',
        'in_channels': 2048,
        'in_index': 3,
        'channels': 512,
        'pool_scales': (1, 2, 3, 6),
        'dropout_ratio': 0.1,
        'num_classes': 19,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
}
3、添加新的损失
假设您想添加一个名为MyLoss的新损失用于分割解码。要添加一个新的损失函数,用户需要在mmseg/models/losses/my_loss.py中实现它。装饰器weighted_loss允许对每个元素的损失进行加权。
import torch
import torch.nn as nn
from ..builder import LOSSES
from .utils import weighted_loss
@weighted_loss
def my_loss(pred, target):
    """Return the element-wise absolute difference of ``pred`` and ``target``.

    Both tensors must have the same size and ``target`` must not be empty.
    """
    sizes_match = pred.size() == target.size()
    assert sizes_match and target.numel() > 0
    return (pred - target).abs()
@LOSSES.register_module()
class MyLoss(nn.Module):
    """Absolute-difference loss module built on :func:`my_loss`.

    Args:
        reduction (str): Reduction applied to the loss; ``forward`` accepts
            'none', 'mean' or 'sum' as override. Default: 'mean'.
        loss_weight (float): Factor the loss is multiplied by. Default: 1.0.
        loss_name (str): Key under which a decode head stores this loss.
            Default: 'loss_my'.
    """

    def __init__(self, reduction='mean', loss_weight=1.0,
                 loss_name='loss_my'):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the weighted loss.

        Args:
            pred (Tensor): Prediction.
            target (Tensor): Target, same size as ``pred``.
            weight (Tensor, optional): Element-wise loss weight.
            avg_factor (optional): Passed through to ``my_loss``.
            reduction_override (str, optional): One of 'none', 'mean', 'sum';
                replaces ``self.reduction`` when given.
            **kwargs: Extra keyword arguments are accepted and ignored, so
                that callers passing e.g. ``ignore_index`` do not fail.

        Returns:
            Tensor: ``loss_weight`` times the reduced loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss

    @property
    def loss_name(self):
        """str: Name under which this loss is reported by a decode head."""
        return self._loss_name
然后用户需要将它添加到mmseg/models/losses/__init__.py中。
from .my_loss import MyLoss, my_loss
要使用它,请修改loss_xxx字段。然后需要修改头部中的loss_decode字段。可以使用loss_weight来平衡多个损失。
文章出处登录后可见!