# 改进 DeepLabV3+ (Improved DeepLabV3+)

## 网络整体结构图 (overall network architecture diagram; figure not included here)

## CFF结构图 (CFF module diagram; figure not included here)

import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.xception import xception
from nets.mobilenetv2 import mobilenetv2

class MobileNetV2(nn.Module):
    """MobileNetV2 backbone exposing four feature levels for the DeepLab head.

    The last layer of ``mobilenetv2(pretrained).features`` is dropped. For
    ``downsample_factor`` 8 or 16 the strided convolutions of the deepest
    stages are turned into stride-1 dilated convolutions, so those stages keep
    a higher spatial resolution. Any other value leaves the layers untouched.
    """

    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        from functools import partial

        model = mobilenetv2(pretrained)
        # Keep every feature layer except the last one.
        self.features = model.features[:-1]

        self.total_idx = len(self.features)
        # Layer indices used as stage boundaries for the dilation rewrite below.
        self.down_idx = [2, 4, 7, 14]

        if downsample_factor == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=4))
        elif downsample_factor == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))

    def _nostride_dilate(self, m, dilate):
        """Convert a conv module ``m`` in place from striding to dilation.

        Stride-2 convs become stride 1. A 3x3 kernel then gets dilation and
        padding ``dilate // 2``. Other 3x3 convs get dilation and padding
        ``dilate``. Modules whose class name lacks "Conv" are left unchanged.
        """
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x):
        """Return ``(low_level, level3, level4, deepest)`` feature maps.

        Shape comments assume a 576*576*3 input (output stride 16).
        """
        # Each stage runs exactly once and feeds the next one. Re-running the
        # prefix layers from the input would waste compute and, in training
        # mode, update the early BatchNorm running statistics several times
        # per batch.
        low_level_features = self.features[:4](x)                    # 144*144*24
        the_three_features = self.features[4:7](low_level_features)   # 72*72*32
        the_four_features = self.features[7:11](the_three_features)   # 36*36*64
        x = self.features[11:](the_four_features)                     # 36*36*320
        return low_level_features, the_three_features, the_four_features, x


#-----------------------------------------#
#   ASPP特征提取模块
#   利用不同膨胀率的膨胀卷积进行特征提取
#-----------------------------------------#
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Five parallel branches run on the input: a 1x1 conv, three 3x3 atrous
    convs with dilation 6, 12 and 18 (each multiplied by ``rate``), and an
    image-level branch (global average, 1x1 conv, broadcast back). Their
    outputs are concatenated and fused back to ``dim_out`` channels by a
    1x1 conv.
    """

    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()

        def conv_bn_relu(kernel, dilation, padding):
            # conv -> BN -> ReLU with the shared channel sizes and BN momentum.
            return nn.Sequential(
                nn.Conv2d(dim_in, dim_out, kernel, 1, padding=padding, dilation=dilation, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
            )

        self.branch1 = conv_bn_relu(1, rate, 0)
        self.branch2 = conv_bn_relu(3, 6 * rate, 6 * rate)
        self.branch3 = conv_bn_relu(3, 12 * rate, 12 * rate)
        self.branch4 = conv_bn_relu(3, 18 * rate, 18 * rate)

        # Image-level branch: pooled descriptor -> 1x1 conv -> BN -> ReLU.
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)

        # Fuses the 5 * dim_out concatenated channels back to dim_out.
        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        _, _, height, width = x.size()
        outputs = [branch(x) for branch in (self.branch1, self.branch2, self.branch3, self.branch4)]

        # Average over H, then W, project, and stretch back to the input size.
        pooled = x.mean(2, keepdim=True).mean(3, keepdim=True)
        pooled = self.branch5_relu(self.branch5_bn(self.branch5_conv(pooled)))
        outputs.append(F.interpolate(pooled, size=(height, width), mode='bilinear', align_corners=True))

        return self.conv_cat(torch.cat(outputs, dim=1))

class DeepLab(nn.Module):
    """DeepLabV3+ variant with coordinate attention and a CFF fusion branch.

    Deep path: backbone -> CoordAtt -> ASPP (256 channels). The CFF branch
    projects the 32-channel level-3 features to 192 channels and refines the
    upsampled 64-channel level-4 features with a dilated conv. It concatenates
    the two (256 channels) and adds the upsampled ASPP output. The sum is
    upsampled to the low-level resolution, concatenated with the 48-channel
    low-level projection and classified.

    NOTE(review): the CFF layers (F1, F2_1) and the 4-output unpacking in
    ``forward`` are written for the MobileNetV2 backbone in this file. The
    xception path has not been adapted to them.
    """

    def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16):
        super(DeepLab, self).__init__()
        if backbone == "xception":
            # Feature levels: low-level [128,128,256], deep [30,30,2048].
            self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 2048
            low_level_channels = 256
        elif backbone == "mobilenet":
            # Feature levels: low-level [128,128,24], deep [30,30,320]; the
            # 32- and 64-channel mid levels feed the CFF branch.
            self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 320
            low_level_channels = 24
        else:
            raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone))

        # Coordinate attention on the deepest features, sized from the
        # backbone's channel count rather than a fixed 320.
        self.CA = CoordAtt(in_channels, in_channels)

        # ASPP: atrous convs at several dilation rates -> 256 channels.
        self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16 // downsample_factor)

        # Low-level shortcut: 1x1 projection to 48 channels.
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        self.cat_conv = nn.Sequential(
            nn.Conv2d(48 + 256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Dropout(0.1),
        )
        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)

        # CFF branch.
        # F1: 1x1 projection of the 32-channel level-3 features to 192 channels.
        self.F1 = nn.Sequential(
            nn.Conv2d(32, 192, 1, stride=1, padding=0),
            nn.BatchNorm2d(192)
        )
        # F2_1: dilation-2 3x3 conv on the upsampled 64-channel level-4
        # features. 192 + 64 = 256 matches the ASPP width, so the two can be summed.
        self.F2_1 = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, padding=2, dilation=2, bias=True),
            nn.BatchNorm2d(64, momentum=0.1),
        )

    def forward(self, x):
        """Return per-pixel class logits at the input resolution.

        Shape comments assume a 576x576 input with the mobilenet backbone.
        """
        H, W = x.size(2), x.size(3)
        low_level_features, the_three_features, the_four_features, x = self.backbone(x)

        # Deep path: attention, then ASPP (256 channels).
        x = self.CA(x)
        x = self.aspp(x)
        low_level_features = self.shortcut_conv(low_level_features)  # 144*144*48

        mid_size = the_three_features.shape[-2:]
        low_size = low_level_features.shape[-2:]

        # CFF: fuse level-3 and level-4 features at the level-3 resolution.
        cff_three = self.F1(the_three_features)  # 72*72*32 -> 72*72*192
        cff_four = F.interpolate(the_four_features, size=mid_size, mode='bilinear',
                                 align_corners=True)  # 36*36*64 -> 72*72*64
        cff_four = self.F2_1(cff_four)  # 72*72*64
        fused = F.relu_(torch.cat((cff_three, cff_four), dim=1))  # 72*72*256

        # Add the upsampled ASPP output, then bring the sum to low-level resolution.
        x = F.interpolate(x, size=mid_size, mode='bilinear', align_corners=True)  # 72*72*256
        fused = F.interpolate(fused + x, size=low_size, mode='bilinear',
                              align_corners=True)  # 144*144*256

        # Concatenate with the low-level shortcut, refine and classify.
        x = self.cat_conv(torch.cat((low_level_features, fused), dim=1))  # 144*144*304 -> 144*144*256
        x = self.cls_conv(x)
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
        return x




#-----------------------------------------#
#   CA
#-----------------------------------------#
import torch
import torch.nn as nn
import torch.nn.functional as F

class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``, a piecewise-linear gate in [0, 1]."""

    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        shifted = x + 3
        return self.relu(shifted) / 6

class h_swish(nn.Module):
    """Hard swish activation: ``x * h_sigmoid(x)``."""

    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        gate = self.sigmoid(x)
        return x * gate

class CoordAtt(nn.Module):
    """Coordinate attention (CA) block.

    The input is average-pooled along each spatial axis separately. The two
    directional descriptors are stacked and encoded by a shared
    1x1 conv -> BN -> h_swish bottleneck of ``max(8, inp // reduction)``
    channels. The result is split back into per-row and per-column sigmoid
    gates that rescale the input. ``oup`` is expected to match the input
    channel count so the gates broadcast against the input.
    """

    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        # (N, C, H, W) -> (N, C, H, 1) and (N, C, 1, W).
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        hidden = max(8, inp // reduction)  # bottleneck width, never below 8

        self.conv1 = nn.Conv2d(inp, hidden, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(hidden, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(hidden, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        _, _, height, width = x.size()

        # Lay the W descriptor out as (N, C, W, 1) so both stack along dim 2.
        along_h = self.pool_h(x)
        along_w = self.pool_w(x).transpose(2, 3)

        encoded = self.act(self.bn1(self.conv1(torch.cat([along_h, along_w], dim=2))))

        enc_h, enc_w = encoded.split([height, width], dim=2)
        gate_h = torch.sigmoid(self.conv_h(enc_h))
        gate_w = torch.sigmoid(self.conv_w(enc_w.transpose(2, 3)))

        return x * gate_w * gate_h

## 网络整体结构图 (overall network architecture diagram; figure not included here)

import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.xception import xception
from nets.mobilenetv2 import mobilenetv2

class MobileNetV2(nn.Module):
    """MobileNetV2 backbone exposing four feature levels for the DeepLab head.

    The last layer of ``mobilenetv2(pretrained).features`` is dropped. For
    ``downsample_factor`` 8 or 16 the strided convolutions of the deepest
    stages are turned into stride-1 dilated convolutions, so those stages keep
    a higher spatial resolution. Any other value leaves the layers untouched.
    """

    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        from functools import partial

        model = mobilenetv2(pretrained)
        # Keep every feature layer except the last one.
        self.features = model.features[:-1]

        self.total_idx = len(self.features)
        # Layer indices used as stage boundaries for the dilation rewrite below.
        self.down_idx = [2, 4, 7, 14]

        if downsample_factor == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=4))
        elif downsample_factor == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))

    def _nostride_dilate(self, m, dilate):
        """Convert a conv module ``m`` in place from striding to dilation.

        Stride-2 convs become stride 1. A 3x3 kernel then gets dilation and
        padding ``dilate // 2``. Other 3x3 convs get dilation and padding
        ``dilate``. Modules whose class name lacks "Conv" are left unchanged.
        """
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x):
        """Return ``(low_level, level3, level4, deepest)`` feature maps.

        Shape comments assume a 576*576*3 input (output stride 16).
        """
        # Each stage runs exactly once and feeds the next one. Re-running the
        # prefix layers from the input would waste compute and, in training
        # mode, update the early BatchNorm running statistics several times
        # per batch.
        low_level_features = self.features[:4](x)                    # 144*144*24
        the_three_features = self.features[4:7](low_level_features)   # 72*72*32
        the_four_features = self.features[7:11](the_three_features)   # 36*36*64
        x = self.features[11:](the_four_features)                     # 36*36*320
        return low_level_features, the_three_features, the_four_features, x

#-----------------------------------------#
#   ASPP特征提取模块
#   利用不同膨胀率的膨胀卷积进行特征提取
#-----------------------------------------#
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Five parallel branches run on the input: a 1x1 conv, three 3x3 atrous
    convs with dilation 6, 12 and 18 (each multiplied by ``rate``), and an
    image-level branch (global average, 1x1 conv, broadcast back). Their
    outputs are concatenated and fused back to ``dim_out`` channels by a
    1x1 conv.
    """

    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()

        def conv_bn_relu(kernel, dilation, padding):
            # conv -> BN -> ReLU with the shared channel sizes and BN momentum.
            return nn.Sequential(
                nn.Conv2d(dim_in, dim_out, kernel, 1, padding=padding, dilation=dilation, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
            )

        self.branch1 = conv_bn_relu(1, rate, 0)
        self.branch2 = conv_bn_relu(3, 6 * rate, 6 * rate)
        self.branch3 = conv_bn_relu(3, 12 * rate, 12 * rate)
        self.branch4 = conv_bn_relu(3, 18 * rate, 18 * rate)

        # Image-level branch: pooled descriptor -> 1x1 conv -> BN -> ReLU.
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)

        # Fuses the 5 * dim_out concatenated channels back to dim_out.
        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        _, _, height, width = x.size()
        outputs = [branch(x) for branch in (self.branch1, self.branch2, self.branch3, self.branch4)]

        # Average over H, then W, project, and stretch back to the input size.
        pooled = x.mean(2, keepdim=True).mean(3, keepdim=True)
        pooled = self.branch5_relu(self.branch5_bn(self.branch5_conv(pooled)))
        outputs.append(F.interpolate(pooled, size=(height, width), mode='bilinear', align_corners=True))

        return self.conv_cat(torch.cat(outputs, dim=1))

class DeepLab(nn.Module):
    """DeepLabV3+ variant whose decoder shortcut fuses three backbone levels.

    The level-3 and level-4 backbone features are upsampled to the low-level
    resolution and concatenated with the low-level features before the
    48-channel shortcut projection. Shape comments in ``forward`` assume a
    576x576 input with the mobilenet backbone.

    NOTE(review): the 4-output backbone unpacking and the shortcut input width
    are written for the MobileNetV2 backbone in this file. The xception path
    has not been adapted.
    """

    def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16):
        super(DeepLab, self).__init__()
        if backbone == "xception":
            # Feature levels: low-level [128,128,256], deep [30,30,2048].
            self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 2048
            low_level_channels = 256
            # NOTE(review): kept at the mobilenet total (24 + 32 + 64) so layer
            # shapes are unchanged. It does not match xception's 256-channel
            # low-level features.
            shortcut_in_channels = 120
        elif backbone == "mobilenet":
            # Feature levels: low-level [128,128,24], level 3 (32 ch),
            # level 4 (64 ch), deep [30,30,320].
            self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 320
            low_level_channels = 24
            the_three_channels = 32
            the_four_channels = 64
            # The shortcut sees low-level + level-3 + level-4 features (120 channels).
            shortcut_in_channels = low_level_channels + the_three_channels + the_four_channels
        else:
            raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone))

        # ASPP: atrous convs at several dilation rates -> 256 channels.
        self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16 // downsample_factor)

        # Multi-level shortcut: 1x1 projection of the concatenated levels to 48 channels.
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(shortcut_in_channels, 48, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        self.cat_conv = nn.Sequential(
            nn.Conv2d(256 + 48, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Dropout(0.1),
        )
        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)

    def forward(self, x):
        """Return per-pixel class logits at the input resolution."""
        H, W = x.size(2), x.size(3)
        low_level_features, the_three_features, the_four_features, x = self.backbone(x)
        x = self.aspp(x)  # 256 channels

        low_size = low_level_features.shape[-2:]

        # Upsample level 3 (72*72*32) and level 4 (36*36*64) to 144*144, then
        # project the (24+32+64)-channel concatenation to 48 channels.
        the_three_features_up = F.interpolate(the_three_features, size=low_size, mode='bilinear', align_corners=True)
        the_four_features_up = F.interpolate(the_four_features, size=low_size, mode='bilinear', align_corners=True)
        low_level_features = self.shortcut_conv(
            torch.cat((low_level_features, the_three_features_up, the_four_features_up), dim=1))

        # Upsample the ASPP output (-> 144*144*256), concatenate with the
        # shortcut and refine: 144*144*(256+48) -> 144*144*256.
        x = F.interpolate(x, size=low_size, mode='bilinear', align_corners=True)
        x = self.cat_conv(torch.cat((x, low_level_features), dim=1))
        x = self.cls_conv(x)
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
        return x

## ASPP模块中加入SP条形池化分支 (adding a strip-pooling (SP) branch to the ASPP module)

#-----------------------------------------#
#   ASPP特征提取模块,增加了SP条形池化分支
#   利用不同膨胀率的膨胀卷积进行特征提取
#-----------------------------------------#
class ASPP(nn.Module):
    """ASPP head with an additional strip-pooling (SP) branch.

    Six branches run in parallel on the input:
    - a 1x1 conv;
    - three 3x3 atrous convs (dilation 6, 12 and 18, each multiplied by ``rate``);
    - a ``StripPooling`` branch that keeps the ``dim_in`` input channels;
    - an image-level pooling branch.

    Their concatenation (``dim_out * 5 + dim_in`` channels) is fused back to
    ``dim_out`` channels by a 1x1 conv.
    """

    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()

        def conv_bn_relu(kernel, dilation, padding):
            # conv -> BN -> ReLU with the shared channel sizes and BN momentum.
            return nn.Sequential(
                nn.Conv2d(dim_in, dim_out, kernel, 1, padding=padding, dilation=dilation, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
            )

        self.branch1 = conv_bn_relu(1, rate, 0)
        self.branch2 = conv_bn_relu(3, 6 * rate, 6 * rate)
        self.branch3 = conv_bn_relu(3, 12 * rate, 12 * rate)
        self.branch4 = conv_bn_relu(3, 18 * rate, 18 * rate)

        # Image-level branch: pooled descriptor -> 1x1 conv -> BN -> ReLU.
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)

        # Five dim_out-wide branches plus the SP branch, which returns dim_in channels.
        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 5 + dim_in, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

        # Strip-pooling branch, sized from the actual input width.
        self.SP = StripPooling(dim_in, up_kwargs={'mode': 'bilinear', 'align_corners': True})

    def forward(self, x):
        _, _, row, col = x.size()

        # Four convolutional branches.
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)

        # Strip-pooling branch (dim_in channels).
        sp = self.SP(x)

        # Image-level branch: global average -> 1x1 conv/BN/ReLU -> broadcast back.
        global_feature = x.mean(2, keepdim=True).mean(3, keepdim=True)
        global_feature = self.branch5_relu(self.branch5_bn(self.branch5_conv(global_feature)))
        global_feature = F.interpolate(global_feature, size=(row, col), mode='bilinear', align_corners=True)

        # Concatenate all six branches and fuse with a 1x1 conv.
        feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, sp, global_feature], dim=1)
        return self.conv_cat(feature_cat)



# -----------------------------------------#
#   SP条形池化模块,输入通道=输出通道=320
# -----------------------------------------#
class StripPooling(nn.Module):
    """Strip pooling (SP) block; the output has the same shape as the input.

    Steps:
    1. Reduce the input to ``in_channels // 4`` channels.
    2. Average-pool it into a 1*W strip and an H*1 strip.
    3. Refine each strip with a 1x3 / 3x1 conv and upsample it back to H*W.
    4. Sum the two strips and mix them with a 3x3 conv.
    5. Project back to ``in_channels`` and add the input residually.

    Args:
        in_channels: number of input (and output) channels.
        up_kwargs: extra keyword arguments for ``F.interpolate`` when
            upsampling the strips. Defaults to bilinear with
            ``align_corners=True``. The dict is copied, so later changes by the
            caller do not affect this module.
    """

    def __init__(self, in_channels, up_kwargs=None):
        super(StripPooling, self).__init__()
        self.pool1 = nn.AdaptiveAvgPool2d((1, None))  # 1*W strip
        self.pool2 = nn.AdaptiveAvgPool2d((None, 1))  # H*1 strip
        inter_channels = in_channels // 4
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 1, bias=False),
                                   nn.BatchNorm2d(inter_channels),
                                   nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (1, 3), 1, (0, 1), bias=False),
                                   nn.BatchNorm2d(inter_channels))
        self.conv3 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 1), 1, (1, 0), bias=False),
                                   nn.BatchNorm2d(inter_channels))
        self.conv4 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False),
                                   nn.BatchNorm2d(inter_channels),
                                   nn.ReLU(True))
        self.conv5 = nn.Sequential(nn.Conv2d(inter_channels, in_channels, 1, bias=False),
                                   nn.BatchNorm2d(in_channels))
        # A fresh dict per instance: a mutable default argument would be shared
        # by every StripPooling and would alias the caller's dict.
        if up_kwargs is None:
            self._up_kwargs = {'mode': 'bilinear', 'align_corners': True}
        else:
            self._up_kwargs = dict(up_kwargs)

    def forward(self, x):
        _, _, h, w = x.size()
        x1 = self.conv1(x)
        x2 = F.interpolate(self.conv2(self.pool1(x1)), (h, w), **self._up_kwargs)  # 1*W branch
        x3 = F.interpolate(self.conv3(self.pool2(x1)), (h, w), **self._up_kwargs)  # H*1 branch
        x4 = self.conv4(F.relu_(x2 + x3))  # combine the 1*W and H*1 features
        out = self.conv5(x4)
        return F.relu_(x + out)  # residual connection with the input

## DenseASPP替换ASPP，并在DenseASPP中引入SP分支 (replace ASPP with DenseASPP and add an SP branch inside DenseASPP)

import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.xception import xception
from nets.mobilenetv2 import mobilenetv2

class MobileNetV2(nn.Module):
    """MobileNetV2 backbone returning low-level and deepest feature maps.

    The last layer of ``mobilenetv2(pretrained).features`` is dropped. For
    ``downsample_factor`` 8 or 16 the deepest stages swap their strides for
    dilation. Any other value leaves the layers unchanged.
    """

    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        from functools import partial

        model = mobilenetv2(pretrained)
        self.features = model.features[:-1]

        self.total_idx = len(self.features)
        self.down_idx = [2, 4, 7, 14]

        # (first layer, end layer, dilation) ranges to rewrite, in application order.
        if downsample_factor == 8:
            schedule = [
                (self.down_idx[-2], self.down_idx[-1], 2),
                (self.down_idx[-1], self.total_idx, 4),
            ]
        elif downsample_factor == 16:
            schedule = [(self.down_idx[-1], self.total_idx, 2)]
        else:
            schedule = []

        for start, stop, rate in schedule:
            rewrite = partial(self._nostride_dilate, dilate=rate)
            for idx in range(start, stop):
                self.features[idx].apply(rewrite)

    def _nostride_dilate(self, m, dilate):
        """Replace striding with dilation on conv module ``m`` (in place)."""
        if 'Conv' not in m.__class__.__name__:
            return
        if m.stride == (2, 2):
            m.stride = (1, 1)
            rate = dilate // 2
        else:
            rate = dilate
        # Only 3x3 kernels receive the new dilation/padding.
        if m.kernel_size == (3, 3):
            m.dilation = (rate, rate)
            m.padding = (rate, rate)

    def forward(self, x):
        stem = self.features[:4]
        body = self.features[4:]
        low_level_features = stem(x)
        return low_level_features, body(low_level_features)

# NOTE: this variant replaces the plain ASPP head with _DenseASPPBlock (see
# DeepLab.__init__ below). The previous ASPP implementation, which used to be
# kept here as a disabled module-level string literal, has been removed as
# dead code.

class DeepLab(nn.Module):
    """DeepLabV3+ variant that uses a DenseASPP block (with strip pooling) instead of ASPP."""

    def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16):
        super(DeepLab, self).__init__()
        if backbone == "xception":
            # Feature levels: low-level [128,128,256], deep [30,30,2048].
            self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 2048
            low_level_channels = 256
        elif backbone == "mobilenet":
            # Feature levels: low-level [128,128,24], deep [30,30,320].
            self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 320
            low_level_channels = 24
        else:
            raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone))

        # The DenseASPP output concatenates:
        # - the block input (in_channels);
        # - five atrous branches of dense_branch_channels each;
        # - a strip-pooling branch that returns as many channels as its input.
        # For mobilenet: 320 * 2 + 256 * 5 = 1920.
        dense_branch_channels = 256
        dense_out_channels = in_channels * 2 + dense_branch_channels * 5
        self.denseaspp = _DenseASPPBlock(in_channels, 512, dense_branch_channels,
                                         norm_layer=nn.BatchNorm2d, norm_kwargs=None)

        # Low-level shortcut: 1x1 projection to 48 channels.
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        self.cat_conv = nn.Sequential(
            nn.Conv2d(48 + dense_out_channels, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Dropout(0.1),
        )
        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)

    def forward(self, x):
        """Return per-pixel class logits at the input resolution.

        Shape comments assume a 576x576 input with the mobilenet backbone.
        """
        H, W = x.size(2), x.size(3)
        low_level_features, x = self.backbone(x)
        x = self.denseaspp(x)  # 320 in -> 1920 out

        low_level_features = self.shortcut_conv(low_level_features)  # 144*144*24 -> 144*144*48

        # Upsample the DenseASPP output to the low-level resolution
        # (144*144*1920), concatenate and refine: 144*144*1968 -> 144*144*256.
        x = F.interpolate(x, size=low_level_features.shape[-2:], mode='bilinear', align_corners=True)
        x = self.cat_conv(torch.cat((x, low_level_features), dim=1))
        x = self.cls_conv(x)
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
        return x




# -----------------------------------------#
#  	DenseASPP,含有SP分支,输入通道是320,输出通道是1600+320
# -----------------------------------------#
class _DenseASPPConv(nn.Sequential):
    """One DenseASPP layer: 1x1 conv -> norm -> ReLU -> 3x3 atrous conv -> norm -> ReLU.

    Args:
        in_channels: input channels.
        inter_channels: width of the 1x1 bottleneck.
        out_channels: output channels of the atrous conv.
        atrous_rate: dilation (and padding) of the 3x3 conv, so the spatial
            size is preserved.
        drop_rate: dropout probability applied to the output in training mode.
        norm_layer: normalization layer constructor.
        norm_kwargs: optional extra keyword arguments for ``norm_layer``.
    """

    def __init__(self, in_channels, inter_channels, out_channels, atrous_rate,
                 drop_rate=0.1, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPConv, self).__init__()
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.add_module('conv1', nn.Conv2d(in_channels, inter_channels, 1))
        self.add_module('bn1', norm_layer(inter_channels, **norm_kwargs))
        self.add_module('relu1', nn.ReLU(True))
        self.add_module('conv2', nn.Conv2d(inter_channels, out_channels, 3,
                                           dilation=atrous_rate, padding=atrous_rate))
        self.add_module('bn2', norm_layer(out_channels, **norm_kwargs))
        self.add_module('relu2', nn.ReLU(True))
        self.drop_rate = drop_rate

    def forward(self, x):
        features = super(_DenseASPPConv, self).forward(x)
        if self.drop_rate > 0:
            # F.dropout is the identity when self.training is False.
            features = F.dropout(features, p=self.drop_rate, training=self.training)
        return features


class _DenseASPPBlock(nn.Module):
    """DenseASPP block with an extra strip-pooling branch.

    Five ``_DenseASPPConv`` layers with atrous rates 3, 6, 12, 18 and 24 are
    densely connected. Each layer sees the block input concatenated with all
    earlier layer outputs, and its own output is prepended to that stack. A
    ``StripPooling`` branch computed on the raw input is appended at the end.
    The output therefore has ``2 * in_channels + 5 * inter_channels2``
    channels (1920 for in_channels=320, inter_channels2=256).
    """

    def __init__(self, in_channels, inter_channels1, inter_channels2,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPBlock, self).__init__()
        self.aspp_3 = _DenseASPPConv(in_channels, inter_channels1, inter_channels2, 3, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_6 = _DenseASPPConv(in_channels + inter_channels2 * 1, inter_channels1, inter_channels2, 6, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_12 = _DenseASPPConv(in_channels + inter_channels2 * 2, inter_channels1, inter_channels2, 12, 0.1,
                                      norm_layer, norm_kwargs)
        self.aspp_18 = _DenseASPPConv(in_channels + inter_channels2 * 3, inter_channels1, inter_channels2, 18, 0.1,
                                      norm_layer, norm_kwargs)
        self.aspp_24 = _DenseASPPConv(in_channels + inter_channels2 * 4, inter_channels1, inter_channels2, 24, 0.1,
                                      norm_layer, norm_kwargs)
        # Strip pooling sized from in_channels so it matches the incoming feature map.
        self.SP = StripPooling(in_channels, up_kwargs={'mode': 'bilinear', 'align_corners': True})

    def forward(self, x):
        sp_features = self.SP(x)
        # Dense connectivity: each layer's output is prepended to the growing stack.
        for layer in (self.aspp_3, self.aspp_6, self.aspp_12, self.aspp_18, self.aspp_24):
            x = torch.cat([layer(x), x], dim=1)
        return torch.cat([x, sp_features], dim=1)


# -----------------------------------------#
#   SP条形池化模块,输入通道=输出通道=320
# -----------------------------------------#
class StripPooling(nn.Module):
    """Strip pooling (SP) module; output channels equal input channels.

    The input is reduced to a quarter of its channels. It is then pooled into
    a 1 x W strip and an H x 1 strip. Each strip is refined by a 1x3 / 3x1
    convolution and upsampled back to H x W, and the two are summed. The fused
    map is projected back to ``in_channels`` and added to the input as a
    residual.

    Args:
        in_channels: channels of the input (and output) feature map.
        up_kwargs: keyword arguments forwarded to ``F.interpolate`` when the
            strips are upsampled. ``None`` (default) means bilinear with
            ``align_corners=True``.
    """

    def __init__(self, in_channels, up_kwargs=None):
        super(StripPooling, self).__init__()
        if up_kwargs is None:
            # Build a fresh dict per instance instead of sharing a mutable default.
            up_kwargs = {'mode': 'bilinear', 'align_corners': True}
        self.pool1 = nn.AdaptiveAvgPool2d((1, None))  # -> 1 x W
        self.pool2 = nn.AdaptiveAvgPool2d((None, 1))  # -> H x 1
        inter_channels = int(in_channels / 4)
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 1, bias=False),
                                   nn.BatchNorm2d(inter_channels),
                                   nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (1, 3), 1, (0, 1), bias=False),
                                   nn.BatchNorm2d(inter_channels))
        self.conv3 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 1), 1, (1, 0), bias=False),
                                   nn.BatchNorm2d(inter_channels))
        self.conv4 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False),
                                   nn.BatchNorm2d(inter_channels),
                                   nn.ReLU(True))
        self.conv5 = nn.Sequential(nn.Conv2d(inter_channels, in_channels, 1, bias=False),
                                   nn.BatchNorm2d(in_channels))
        # Private copy so that later changes to the caller's dict do not leak in.
        self._up_kwargs = dict(up_kwargs)

    def forward(self, x):
        _, _, h, w = x.size()
        x1 = self.conv1(x)
        # 1 x W strip, refined by a 1x3 conv and stretched back to H x W.
        x2 = F.interpolate(self.conv2(self.pool1(x1)), (h, w), **self._up_kwargs)
        # H x 1 strip, refined by a 3x1 conv and stretched back to H x W.
        x3 = F.interpolate(self.conv3(self.pool2(x1)), (h, w), **self._up_kwargs)
        # Fuse both strip directions.
        x4 = self.conv4(F.relu_(x2 + x3))
        out = self.conv5(x4)
        # Residual connection with the original input.
        return F.relu_(x + out)

DenseASPP替换ASPP,并采用上面两种级联方式

import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.xception import xception
from nets.mobilenetv2 import mobilenetv2

class MobileNetV2(nn.Module):
    """MobileNetV2 backbone that exposes four feature maps for DeepLab.

    Args:
        downsample_factor: 8 or 16. Strided convolutions in the later stages
            are turned into stride-1 dilated convolutions, so the deepest
            feature map stays at 1/downsample_factor of the input size.
        pretrained: forwarded to ``mobilenetv2`` (presumably selects
            pretrained weights; confirm against nets/mobilenetv2.py).
    """

    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        from functools import partial

        model = mobilenetv2(pretrained)
        # Keep every entry of model.features except the last one.
        self.features = model.features[:-1]

        self.total_idx = len(self.features)
        # Presumably the first block index of each downsampling stage.
        self.down_idx = [2, 4, 7, 14]

        if downsample_factor == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=4))
        elif downsample_factor == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))

    def _nostride_dilate(self, m, dilate):
        """Convert conv module ``m`` in place from strided to dilated.

        Modules whose class name does not contain 'Conv' are left unchanged.
        A (2, 2)-strided conv gets stride (1, 1) and, if 3x3, dilation and
        padding ``dilate // 2``. Any other 3x3 conv gets dilation and padding
        ``dilate``.
        """
        if 'Conv' not in m.__class__.__name__:
            return
        if m.stride == (2, 2):
            m.stride = (1, 1)
            if m.kernel_size == (3, 3):
                m.dilation = (dilate // 2, dilate // 2)
                m.padding = (dilate // 2, dilate // 2)
        elif m.kernel_size == (3, 3):
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def forward(self, x):
        """Return (low-level, stage-7, stage-11, deepest) feature maps.

        Each stage slice continues from the previous slice's output, so every
        layer runs exactly once per call. Re-running the stem for every output
        wasted compute and, in training mode, updated the BatchNorm running
        statistics several times per step.
        """
        # Shapes for a 576x576x3 input with downsample_factor=16 (H*W*C).
        low_level_features = self.features[:4](x)                    # 144*144*24
        the_three_features = self.features[4:7](low_level_features)  # 72*72*32
        the_four_features = self.features[7:11](the_three_features)  # 36*36*64
        x = self.features[11:](the_four_features)                    # 36*36*320
        return low_level_features, the_three_features, the_four_features, x

'''
#-----------------------------------------#
#   ASPP特征提取模块
#   利用不同膨胀率的膨胀卷积进行特征提取
#-----------------------------------------#
class ASPP(nn.Module):
    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        self.branch1 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate,bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
        )
        self.branch2 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 3, 1, padding=6*rate, dilation=6*rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),	
        )
        self.branch3 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 3, 1, padding=12*rate, dilation=12*rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),	
        )
        self.branch4 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 3, 1, padding=18*rate, dilation=18*rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),	
        )
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0,bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)

        self.conv_cat = nn.Sequential(
                nn.Conv2d(dim_out*5, dim_out, 1, 1, padding=0,bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),		
        )

    def forward(self, x):
        [b, c, row, col] = x.size()
        #-----------------------------------------#
        #   一共五个分支
        #-----------------------------------------#
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)
        #-----------------------------------------#
        #   第五个分支,全局平均池化+卷积
        #-----------------------------------------#
        global_feature = torch.mean(x,2,True)
        global_feature = torch.mean(global_feature,3,True)
        global_feature = self.branch5_conv(global_feature)
        global_feature = self.branch5_bn(global_feature)
        global_feature = self.branch5_relu(global_feature)
        global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
        
        #-----------------------------------------#
        #   将五个分支的内容堆叠起来
        #   然后1x1卷积整合特征。
        #-----------------------------------------#
        feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
        result = self.conv_cat(feature_cat)
        return result
        '''

class DeepLab(nn.Module):
    """DeepLabV3+ variant with a DenseASPP head and a CFF-style decoder.

    For the MobileNetV2 backbone and a 576x576 input the data flow is:

    * backbone -> low-level (144x144x24), stage-7 (72x72x32),
      stage-11 (36x36x64) and deepest (36x36x320) features;
    * DenseASPP (36x36x1920) -> ``down_conv`` to 256 channels;
    * CFF: ``F1`` projects the stage-7 features to 192 channels, and ``F2_1``
      refines the upsampled stage-11 features with a rate-2 dilated 3x3 conv.
      Both are concatenated (256 channels), passed through ReLU and added to
      the upsampled high-level features;
    * the fused map is upsampled to the low-level resolution, concatenated
      with the 48-channel shortcut (304 channels), refined and classified.

    NOTE(review): ``forward`` unpacks four backbone outputs, and the CFF /
    ``down_conv`` widths are fixed to the MobileNetV2 values (32, 64, 1920).
    The "xception" option therefore presumably does not work as-is; confirm.

    Args:
        num_classes: number of output classes.
        backbone: "mobilenet" or "xception".
        pretrained: forwarded to the backbone constructor.
        downsample_factor: forwarded to the backbone constructor.
    """

    def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16):
        super(DeepLab, self).__init__()
        if backbone not in ("xception", "mobilenet"):
            raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone))

        if backbone == "xception":
            # Low-level features [128,128,256], backbone output [30,30,2048].
            self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels, low_level_channels = 2048, 256
        else:
            # Low-level features [128,128,24], backbone output [30,30,320].
            self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels, low_level_channels = 320, 24

        # DenseASPP head (replaces the plain ASPP module).
        self.denseaspp = _DenseASPPBlock(in_channels, 512, 256, norm_layer=nn.BatchNorm2d, norm_kwargs=None)

        # Low-level shortcut: 1x1 projection to 48 channels.
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )

        # Decoder head: shortcut (48) + fused CFF map (256) = 304 input channels.
        self.cat_conv = nn.Sequential(
            nn.Conv2d(304, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )
        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)

        # CFF branch 1: 1x1 projection of the stage-7 features (32 -> 192).
        self.F1 = nn.Sequential(
            nn.Conv2d(32, 192, 1, stride=1, padding=0),
            nn.BatchNorm2d(192),
        )
        # CFF branch 2: rate-2 dilated 3x3 conv on the upsampled stage-11 features (64 -> 64).
        self.F2_1 = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, padding=2, dilation=2, bias=True),
            nn.BatchNorm2d(64, momentum=0.1),
        )
        # Channel reduction of the DenseASPP output (1920 -> 256).
        self.down_conv = nn.Sequential(
            nn.Conv2d(1920, 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        input_size = (x.size(2), x.size(3))
        low, stage3, stage4, deep = self.backbone(x)

        # High-level path: DenseASPP (36x36x320 -> 36x36x1920), then 1x1 down to 256.
        deep = self.down_conv(self.denseaspp(deep))

        # Low-level shortcut (144x144x24 -> 144x144x48).
        low = self.shortcut_conv(low)

        # CFF at the stage-7 resolution (72x72).
        mid_size = (stage3.size(2), stage3.size(3))
        branch1 = self.F1(stage3)  # 72x72x192
        branch2 = self.F2_1(F.interpolate(stage4, size=mid_size, mode='bilinear', align_corners=True))  # 72x72x64
        cff = F.relu_(torch.cat((branch1, branch2), dim=1))  # 72x72x256

        # Add the upsampled high-level features (both sides must have 256 channels).
        fused = cff + F.interpolate(deep, size=mid_size, mode='bilinear', align_corners=True)
        fused = F.interpolate(fused, size=(low.size(2), low.size(3)),
                              mode='bilinear', align_corners=True)  # 144x144x256

        # Decoder head: 144x144x304 -> 144x144x256 -> num_classes, back to input size.
        out = self.cat_conv(torch.cat((low, fused), dim=1))
        out = self.cls_conv(out)
        return F.interpolate(out, size=input_size, mode='bilinear', align_corners=True)




# -----------------------------------------#
#  	DenseASPP,含有SP分支,输入通道是320,输出通道是1600+320
# -----------------------------------------#
class _DenseASPPConv(nn.Sequential):
    def __init__(self, in_channels, inter_channels, out_channels, atrous_rate,
                 drop_rate=0.1, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPConv, self).__init__()
        self.add_module('conv1', nn.Conv2d(in_channels, inter_channels, 1)),
        self.add_module('bn1', norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs))),
        self.add_module('relu1', nn.ReLU(True)),
        self.add_module('conv2', nn.Conv2d(inter_channels, out_channels, 3, dilation=atrous_rate, padding=atrous_rate)),
        self.add_module('bn2', norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs))),
        self.add_module('relu2', nn.ReLU(True)),
        self.drop_rate = drop_rate

    def forward(self, x):
        features = super(_DenseASPPConv, self).forward(x)
        if self.drop_rate > 0:
            features = F.dropout(features, p=self.drop_rate, training=self.training)
        return features


class _DenseASPPBlock(nn.Module):
    """DenseASPP block with an additional strip-pooling (SP) branch.

    Five ``_DenseASPPConv`` layers with atrous rates 3, 6, 12, 18 and 24 are
    densely connected. Every layer receives the block input concatenated with
    the outputs of all earlier layers. A ``StripPooling`` branch runs on the
    block input in parallel, and its result is appended last.

    Output channels are ``2 * in_channels + 5 * inter_channels2``
    (320 + 5 * 256 + 320 = 1920 for the MobileNetV2 backbone).

    Args:
        in_channels: channels of the input feature map.
        inter_channels1: bottleneck width inside each ``_DenseASPPConv``.
        inter_channels2: channels added by each ``_DenseASPPConv``.
        norm_layer: normalization layer class for the atrous layers.
        norm_kwargs: optional extra keyword arguments for ``norm_layer``.
    """

    def __init__(self, in_channels, inter_channels1, inter_channels2,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPBlock, self).__init__()
        self.aspp_3 = _DenseASPPConv(in_channels, inter_channels1, inter_channels2, 3, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_6 = _DenseASPPConv(in_channels + inter_channels2 * 1, inter_channels1, inter_channels2, 6, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_12 = _DenseASPPConv(in_channels + inter_channels2 * 2, inter_channels1, inter_channels2, 12, 0.1,
                                      norm_layer, norm_kwargs)
        self.aspp_18 = _DenseASPPConv(in_channels + inter_channels2 * 3, inter_channels1, inter_channels2, 18, 0.1,
                                      norm_layer, norm_kwargs)
        self.aspp_24 = _DenseASPPConv(in_channels + inter_channels2 * 4, inter_channels1, inter_channels2, 24, 0.1,
                                      norm_layer, norm_kwargs)
        # The SP branch consumes the raw block input, so its width must follow
        # in_channels. A hard-coded 320 only matched the MobileNetV2 backbone.
        self.SP = StripPooling(in_channels, up_kwargs={'mode': 'bilinear', 'align_corners': True})

    def forward(self, x):
        # Strip-pooling branch on the unmodified block input.
        strip = self.SP(x)
        # Dense cascade: each atrous layer's output is prepended to the stack.
        for layer in (self.aspp_3, self.aspp_6, self.aspp_12, self.aspp_18, self.aspp_24):
            x = torch.cat([layer(x), x], dim=1)
        return torch.cat([x, strip], dim=1)


# -----------------------------------------#
#   SP条形池化模块,输入通道=输出通道=320
# -----------------------------------------#
class StripPooling(nn.Module):
    """Strip pooling (SP) module; output channels equal input channels.

    The input is reduced to a quarter of its channels. It is then pooled into
    a 1 x W strip and an H x 1 strip. Each strip is refined by a 1x3 / 3x1
    convolution and upsampled back to H x W, and the two are summed. The fused
    map is projected back to ``in_channels`` and added to the input as a
    residual.

    Args:
        in_channels: channels of the input (and output) feature map.
        up_kwargs: keyword arguments forwarded to ``F.interpolate`` when the
            strips are upsampled. ``None`` (default) means bilinear with
            ``align_corners=True``.
    """

    def __init__(self, in_channels, up_kwargs=None):
        super(StripPooling, self).__init__()
        if up_kwargs is None:
            # Build a fresh dict per instance instead of sharing a mutable default.
            up_kwargs = {'mode': 'bilinear', 'align_corners': True}
        self.pool1 = nn.AdaptiveAvgPool2d((1, None))  # -> 1 x W
        self.pool2 = nn.AdaptiveAvgPool2d((None, 1))  # -> H x 1
        inter_channels = int(in_channels / 4)
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 1, bias=False),
                                   nn.BatchNorm2d(inter_channels),
                                   nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (1, 3), 1, (0, 1), bias=False),
                                   nn.BatchNorm2d(inter_channels))
        self.conv3 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 1), 1, (1, 0), bias=False),
                                   nn.BatchNorm2d(inter_channels))
        self.conv4 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False),
                                   nn.BatchNorm2d(inter_channels),
                                   nn.ReLU(True))
        self.conv5 = nn.Sequential(nn.Conv2d(inter_channels, in_channels, 1, bias=False),
                                   nn.BatchNorm2d(in_channels))
        # Private copy so that later changes to the caller's dict do not leak in.
        self._up_kwargs = dict(up_kwargs)

    def forward(self, x):
        _, _, h, w = x.size()
        x1 = self.conv1(x)
        # 1 x W strip, refined by a 1x3 conv and stretched back to H x W.
        x2 = F.interpolate(self.conv2(self.pool1(x1)), (h, w), **self._up_kwargs)
        # H x 1 strip, refined by a 3x1 conv and stretched back to H x W.
        x3 = F.interpolate(self.conv3(self.pool2(x1)), (h, w), **self._up_kwargs)
        # Fuse both strip directions.
        x4 = self.conv4(F.relu_(x2 + x3))
        out = self.conv5(x4)
        # Residual connection with the original input.
        return F.relu_(x + out)

import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.mobilenetv2 import mobilenetv2
from nets.xception import xception

class MobileNetV2(nn.Module):
    """MobileNetV2 backbone that exposes four feature maps for DeepLab.

    Args:
        downsample_factor: 8 or 16. Strided convolutions in the later stages
            are turned into stride-1 dilated convolutions, so the deepest
            feature map stays at 1/downsample_factor of the input size.
        pretrained: forwarded to ``mobilenetv2`` (presumably selects
            pretrained weights; confirm against nets/mobilenetv2.py).
    """

    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        from functools import partial

        model = mobilenetv2(pretrained)
        # Keep every entry of model.features except the last one.
        self.features = model.features[:-1]

        self.total_idx = len(self.features)
        # Presumably the first block index of each downsampling stage.
        self.down_idx = [2, 4, 7, 14]

        if downsample_factor == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=4))
        elif downsample_factor == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))

    def _nostride_dilate(self, m, dilate):
        """Convert conv module ``m`` in place from strided to dilated.

        Modules whose class name does not contain 'Conv' are left unchanged.
        A (2, 2)-strided conv gets stride (1, 1) and, if 3x3, dilation and
        padding ``dilate // 2``. Any other 3x3 conv gets dilation and padding
        ``dilate``.
        """
        if 'Conv' not in m.__class__.__name__:
            return
        if m.stride == (2, 2):
            m.stride = (1, 1)
            if m.kernel_size == (3, 3):
                m.dilation = (dilate // 2, dilate // 2)
                m.padding = (dilate // 2, dilate // 2)
        elif m.kernel_size == (3, 3):
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def forward(self, x):
        """Return (low-level, stage-7, stage-11, deepest) feature maps.

        Each stage slice continues from the previous slice's output, so every
        layer runs exactly once per call. Re-running the stem for every output
        wasted compute and, in training mode, updated the BatchNorm running
        statistics several times per step.
        """
        low_level_features = self.features[:4](x)
        the_three_features = self.features[4:7](low_level_features)
        the_four_features = self.features[7:11](the_three_features)
        x = self.features[11:](the_four_features)
        return low_level_features, the_three_features, the_four_features, x

'''
#-----------------------------------------#
#   ASPP特征提取模块
#   利用不同膨胀率的膨胀卷积进行特征提取
#-----------------------------------------#
class ASPP(nn.Module):
	def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
		super(ASPP, self).__init__()
		self.branch1 = nn.Sequential(
				nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True), #dilation=1即没使用膨胀卷积
				nn.BatchNorm2d(dim_out, momentum=bn_mom),
				nn.ReLU(inplace=True), #30,30,256
		)
		self.branch2 = nn.Sequential(
				nn.Conv2d(dim_in, dim_out, 3, 1, padding=6*rate, dilation=6*rate, bias=True), #dilation=6的膨胀卷积
				nn.BatchNorm2d(dim_out, momentum=bn_mom),
				nn.ReLU(inplace=True), #30,30,256
		)
		self.branch3 = nn.Sequential(
				nn.Conv2d(dim_in, dim_out, 3, 1, padding=12*rate, dilation=12*rate, bias=True), #dilation12的膨胀卷积
				nn.BatchNorm2d(dim_out, momentum=bn_mom),
				nn.ReLU(inplace=True), #30,30,256
		)
		self.branch4 = nn.Sequential(
				nn.Conv2d(dim_in, dim_out, 3, 1, padding=18*rate, dilation=18*rate, bias=True), #dilation=18的膨胀卷积
				nn.BatchNorm2d(dim_out, momentum=bn_mom),
				nn.ReLU(inplace=True), #30,30,256
		)
		self.branch5 = nn.Sequential(
				nn.AdaptiveAvgPool2d((1, 1)),
				nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True),
				nn.BatchNorm2d(dim_out, momentum=bn_mom),
				nn.ReLU(inplace=True)
		)

		# self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
		# self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
		# self.branch5_relu = nn.ReLU(inplace=True)

		self.conv_cat = nn.Sequential(
				nn.Conv2d(dim_out*5+320, dim_out, 1, 1, padding=0, bias=True),
				nn.BatchNorm2d(dim_out, momentum=bn_mom),
				nn.ReLU(inplace=True), #30,30,256
		)

		self.head = StripPooling(320, up_kwargs={'mode': 'bilinear', 'align_corners': True})

	def forward(self, x):
		#获取输入特征图的高宽
		[b, c, row, col] = x.size()
		#-----------------------------------------#
		#   一共五个分支
		#-----------------------------------------#
		conv1x1 = self.branch1(x) #30,30,256
		# print("X1.shape", conv1x1.size())
		conv3x3_1 = self.branch2(x) #30,30,256
		# print("X2.shape", conv3x3_1.size())
		conv3x3_2 = self.branch3(x) #30,30,256
		# print("X3.shape", conv3x3_2.size())
		conv3x3_3 = self.branch4(x) #30,30,256
		# print("X4.shape", conv3x3_3.size())
		spm = self.head(x)
		#-----------------------------------------#
		#   第五个分支,全局平均池化+卷积
		#-----------------------------------------#
		# global_feature = torch.mean(x,2,True)
		# global_feature = torch.mean(global_feature,3,True)
		# global_feature = self.branch5_conv(global_feature)
		# global_feature = self.branch5_bn(global_feature)
		# global_feature = self.branch5_relu(global_feature)
		global_feature = self.branch5(x)
		# print("X5.shape", global_feature.size())
		global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True) #30,30,256
		
		#-----------------------------------------#
		#   将五个分支的内容堆叠起来
		#   然后1x1卷积整合特征。
		#-----------------------------------------#
		feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, spm, global_feature], dim=1) #30,30,256*5
		result = self.conv_cat(feature_cat) #堆叠完后利用1*1卷积对通道数进行调整,30,30,256
		return result
		'''

class DeepLab(nn.Module):
    """DeepLabV3+ variant with a DenseASPP head and a multi-scale shortcut.

    The low-level shortcut stacks the low-level features with the stage-7
    and stage-11 backbone features, both upsampled to the low-level
    resolution, and then projects them to 48 channels with a 1x1 conv. The
    full DenseASPP output (1920 channels for MobileNetV2) is upsampled to the
    same resolution and concatenated with the shortcut before the decoder
    head.

    NOTE(review): the shortcut width uses ``the_three_channels`` and
    ``the_four_channels``, which are only assigned for the "mobilenet"
    backbone. The "xception" path therefore fails during construction;
    confirm whether xception is meant to be supported here.

    Args:
        num_classes: number of output classes.
        backbone: "mobilenet" or "xception".
        pretrained: forwarded to the backbone constructor.
        downsample_factor: forwarded to the backbone constructor.
    """

    def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16):
        super(DeepLab, self).__init__()
        if backbone not in ("xception", "mobilenet"):
            raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone))

        if backbone == "xception":
            # Low-level features [128,128,256], backbone output [30,30,2048].
            self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 2048
            low_level_channels = 256
        else:
            # Low-level features [128,128,24], backbone output [30,30,320].
            self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 320
            low_level_channels = 24
            the_three_channels = 32
            the_four_channels = 64

        # DenseASPP head (replaces the plain ASPP module).
        self.denseaspp = _DenseASPPBlock(in_channels, 512, 256, norm_layer=nn.BatchNorm2d, norm_kwargs=None)

        # Multi-scale shortcut: (low-level + stage-7 + stage-11) channels -> 48.
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(low_level_channels + the_three_channels + the_four_channels, 48, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )

        # Decoder head: upsampled DenseASPP output (1920) + shortcut (48) -> 256.
        self.cat_conv = nn.Sequential(
            nn.Conv2d(1920 + 48, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )
        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)

    def forward(self, x):
        def _resize(t, size):
            return F.interpolate(t, size=size, mode='bilinear', align_corners=True)

        # e.g. input b,3,512,512 -> low-level features 128,128,24
        input_size = (x.size(2), x.size(3))
        low, stage3, stage4, deep = self.backbone(x)
        deep = self.denseaspp(deep)  # 1920 channels for MobileNetV2

        # Stack low-level, stage-7 and stage-11 features at the low-level
        # resolution, then compress: 128,128,(24+32+64) -> 128,128,48.
        low_size = (low.size(2), low.size(3))
        stacked = torch.cat((low, _resize(stage3, low_size), _resize(stage4, low_size)), dim=1)
        shortcut = self.shortcut_conv(stacked)

        # Upsample DenseASPP output and fuse: 128,128,1920+48 -> 128,128,256.
        out = self.cat_conv(torch.cat((_resize(deep, (shortcut.size(2), shortcut.size(3))), shortcut), dim=1))
        out = self.cls_conv(out)  # 128,128,num_classes
        return _resize(out, input_size)  # back to the input resolution


# -----------------------------------------#
#   SP条形池化模块
# -----------------------------------------#
class StripPooling(nn.Module):
    """Strip pooling (SP) module; output channels equal input channels.

    The input is reduced to a quarter of its channels. It is then pooled into
    a 1 x W strip and an H x 1 strip. Each strip is refined by a 1x3 / 3x1
    convolution and upsampled back to H x W, and the two are summed. The fused
    map is projected back to ``in_channels`` and added to the input as a
    residual.

    Args:
        in_channels: channels of the input (and output) feature map.
        up_kwargs: keyword arguments forwarded to ``F.interpolate`` when the
            strips are upsampled. ``None`` (default) means bilinear with
            ``align_corners=True``.
    """

    def __init__(self, in_channels, up_kwargs=None):
        super(StripPooling, self).__init__()
        if up_kwargs is None:
            # Build a fresh dict per instance instead of sharing a mutable default.
            up_kwargs = {'mode': 'bilinear', 'align_corners': True}
        self.pool1 = nn.AdaptiveAvgPool2d((1, None))  # -> 1 x W
        self.pool2 = nn.AdaptiveAvgPool2d((None, 1))  # -> H x 1
        inter_channels = int(in_channels / 4)
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 1, bias=False),
                                   nn.BatchNorm2d(inter_channels),
                                   nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (1, 3), 1, (0, 1), bias=False),
                                   nn.BatchNorm2d(inter_channels))
        self.conv3 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 1), 1, (1, 0), bias=False),
                                   nn.BatchNorm2d(inter_channels))
        self.conv4 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False),
                                   nn.BatchNorm2d(inter_channels),
                                   nn.ReLU(True))
        self.conv5 = nn.Sequential(nn.Conv2d(inter_channels, in_channels, 1, bias=False),
                                   nn.BatchNorm2d(in_channels))
        # Private copy so that later changes to the caller's dict do not leak in.
        self._up_kwargs = dict(up_kwargs)

    def forward(self, x):
        _, _, h, w = x.size()
        x1 = self.conv1(x)
        # 1 x W strip, refined by a 1x3 conv and stretched back to H x W.
        x2 = F.interpolate(self.conv2(self.pool1(x1)), (h, w), **self._up_kwargs)
        # H x 1 strip, refined by a 3x1 conv and stretched back to H x W.
        x3 = F.interpolate(self.conv3(self.pool2(x1)), (h, w), **self._up_kwargs)
        # Fuse both strip directions.
        x4 = self.conv4(F.relu_(x2 + x3))
        out = self.conv5(x4)
        # Residual connection with the original input.
        return F.relu_(x + out)


# -----------------------------------------#
#  	DenseASPP
# -----------------------------------------#
class _DenseASPPConv(nn.Sequential):
	def __init__(self, in_channels, inter_channels, out_channels, atrous_rate,
				 drop_rate=0.1, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
		super(_DenseASPPConv, self).__init__()
		self.add_module('conv1', nn.Conv2d(in_channels, inter_channels, 1)),
		self.add_module('bn1', norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs))),
		self.add_module('relu1', nn.ReLU(True)),
		self.add_module('conv2', nn.Conv2d(inter_channels, out_channels, 3, dilation=atrous_rate, padding=atrous_rate)),
		self.add_module('bn2', norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs))),
		self.add_module('relu2', nn.ReLU(True)),
		self.drop_rate = drop_rate

	def forward(self, x):
		features = super(_DenseASPPConv, self).forward(x)
		if self.drop_rate > 0:
			features = F.dropout(features, p=self.drop_rate, training=self.training)
		return features

class _DenseASPPBlock(nn.Module):
    """DenseASPP block with an additional strip-pooling (SP) branch.

    Five ``_DenseASPPConv`` layers with atrous rates 3, 6, 12, 18 and 24 are
    densely connected. Every layer receives the block input concatenated with
    the outputs of all earlier layers. A ``StripPooling`` branch runs on the
    block input in parallel, and its result is appended last.

    Output channels are ``2 * in_channels + 5 * inter_channels2``
    (320 + 5 * 256 + 320 = 1920 for the MobileNetV2 backbone).

    Args:
        in_channels: channels of the input feature map.
        inter_channels1: bottleneck width inside each ``_DenseASPPConv``.
        inter_channels2: channels added by each ``_DenseASPPConv``.
        norm_layer: normalization layer class for the atrous layers.
        norm_kwargs: optional extra keyword arguments for ``norm_layer``.
    """

    def __init__(self, in_channels, inter_channels1, inter_channels2,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPBlock, self).__init__()
        self.aspp_3 = _DenseASPPConv(in_channels, inter_channels1, inter_channels2, 3, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_6 = _DenseASPPConv(in_channels + inter_channels2 * 1, inter_channels1, inter_channels2, 6, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_12 = _DenseASPPConv(in_channels + inter_channels2 * 2, inter_channels1, inter_channels2, 12, 0.1,
                                      norm_layer, norm_kwargs)
        self.aspp_18 = _DenseASPPConv(in_channels + inter_channels2 * 3, inter_channels1, inter_channels2, 18, 0.1,
                                      norm_layer, norm_kwargs)
        self.aspp_24 = _DenseASPPConv(in_channels + inter_channels2 * 4, inter_channels1, inter_channels2, 24, 0.1,
                                      norm_layer, norm_kwargs)
        # The SP branch consumes the raw block input, so its width must follow
        # in_channels. A hard-coded 320 only matched the MobileNetV2 backbone.
        self.SP = StripPooling(in_channels, up_kwargs={'mode': 'bilinear', 'align_corners': True})

    def forward(self, x):
        # Strip-pooling branch on the unmodified block input.
        strip = self.SP(x)
        # Dense cascade: each atrous layer's output is prepended to the stack.
        for layer in (self.aspp_3, self.aspp_6, self.aspp_12, self.aspp_18, self.aspp_24):
            x = torch.cat([layer(x), x], dim=1)
        return torch.cat([x, strip], dim=1)

 

10.28更新(解码复习)

import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.xception import xception
from nets.mobilenetv2 import mobilenetv2

class MobileNetV2(nn.Module):
    """MobileNetV2 backbone that exposes three feature maps for DeepLab.

    Args:
        downsample_factor: 8 or 16. Strided convolutions in the later stages
            are turned into stride-1 dilated convolutions, so the deepest
            feature map stays at 1/downsample_factor of the input size.
        pretrained: forwarded to ``mobilenetv2`` (presumably selects
            pretrained weights; confirm against nets/mobilenetv2.py).
    """

    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        from functools import partial

        model = mobilenetv2(pretrained)
        # Keep every entry of model.features except the last one.
        self.features = model.features[:-1]

        self.total_idx = len(self.features)
        # Presumably the first block index of each downsampling stage.
        self.down_idx = [2, 4, 7, 14]

        if downsample_factor == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=4))
        elif downsample_factor == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))

    def _nostride_dilate(self, m, dilate):
        """Convert conv module ``m`` in place from strided to dilated.

        Modules whose class name does not contain 'Conv' are left unchanged.
        A (2, 2)-strided conv gets stride (1, 1) and, if 3x3, dilation and
        padding ``dilate // 2``. Any other 3x3 conv gets dilation and padding
        ``dilate``.
        """
        if 'Conv' not in m.__class__.__name__:
            return
        if m.stride == (2, 2):
            m.stride = (1, 1)
            if m.kernel_size == (3, 3):
                m.dilation = (dilate // 2, dilate // 2)
                m.padding = (dilate // 2, dilate // 2)
        elif m.kernel_size == (3, 3):
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def forward(self, x):
        """Return (low-level, stage-7, deepest) feature maps.

        Each stage slice continues from the previous slice's output, so every
        layer runs exactly once per call. Re-running the stem for every output
        wasted compute and, in training mode, updated the BatchNorm running
        statistics several times per step.
        """
        # Shapes for a 576x576x3 input with downsample_factor=16 (H*W*C).
        low_level_features = self.features[:4](x)                    # 144*144*24
        the_three_features = self.features[4:7](low_level_features)  # 72*72*32
        x = self.features[7:](the_three_features)                    # 36*36*320
        return low_level_features, the_three_features, x


#-----------------------------------------#
#   ASPP特征提取模块
#   利用不同膨胀率的膨胀卷积进行特征提取
#-----------------------------------------#
class ASPP(nn.Module):
    """Cascaded ASPP: atrous branches with rates 6/12/18 (times ``rate``) plus image pooling.

    Branches 3 and 4 see the module input concatenated with the output of
    the previous dilated branch (``dim_in + dim_out`` channels). All five
    branch outputs are stacked and fused by a 1x1 conv.

    Args:
        dim_in: input channels.
        dim_out: channels of every branch and of the fused output.
        rate: multiplier applied to the dilation rates 6, 12 and 18.
        bn_mom: BatchNorm momentum.
    """

    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()

        def _conv_bn_relu(cin, kernel, pad, dil):
            # conv (stride 1, with bias) -> BN -> in-place ReLU, producing dim_out channels.
            return nn.Sequential(
                nn.Conv2d(cin, dim_out, kernel, 1, padding=pad, dilation=dil, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
            )

        cascaded_in = dim_in + dim_out
        self.branch1 = _conv_bn_relu(dim_in, 1, 0, rate)
        self.branch2 = _conv_bn_relu(dim_in, 3, 6 * rate, 6 * rate)
        self.branch3 = _conv_bn_relu(cascaded_in, 3, 12 * rate, 12 * rate)
        self.branch4 = _conv_bn_relu(cascaded_in, 3, 18 * rate, 18 * rate)

        # Image-pooling branch, applied to the globally averaged input.
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)

        # 1x1 fusion of the five stacked branches.
        self.conv_cat = _conv_bn_relu(dim_out * 5, 1, 0, 1)

    def forward(self, x):
        _, _, row, col = x.size()

        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        # Cascade: each later dilated branch sees the input plus the previous branch.
        conv3x3_2 = self.branch3(torch.cat((x, conv3x3_1), dim=1))
        conv3x3_3 = self.branch4(torch.cat((x, conv3x3_2), dim=1))

        # Global average over H then W, conv/BN/ReLU, then broadcast back to row x col.
        pooled = torch.mean(torch.mean(x, 2, True), 3, True)
        pooled = self.branch5_relu(self.branch5_bn(self.branch5_conv(pooled)))
        pooled = F.interpolate(pooled, (row, col), None, 'bilinear', True)

        stacked = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, pooled], dim=1)
        return self.conv_cat(stacked)

class DeepLab(nn.Module):
    """DeepLabV3+-style segmentation network with a three-level decoder.

    Pipeline: backbone -> ASPP on the deepest feature map -> upsample and
    concatenate with a projected mid-level feature map -> channel attention
    (``cSE``) -> upsample to the shallow-feature resolution -> sigmoid-gated
    fusion with the shallow features -> 3x3 classifier -> bilinear upsample
    back to the input size.

    The decoder channel counts (32-channel mid-level input, 320 channels into
    ``cSE``, 688 into the classifier) are hard-coded for the MobileNetV2
    backbone.

    Args:
        num_classes: Number of output channels of the classifier.
        backbone: ``"mobilenet"`` or ``"xception"``.
        pretrained: Forwarded to the backbone constructor.
        downsample_factor: Forwarded to the backbone; also sets the ASPP rate
            multiplier ``16 // downsample_factor``.

    Raises:
        ValueError: If ``backbone`` is not a supported name.
    """

    def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16):
        super(DeepLab, self).__init__()
        if backbone == "xception":
            # Xception: 256-channel shallow features, 2048-channel deep features.
            # NOTE(review): the decoder below is sized for MobileNetV2 feature
            # maps, so this backbone is not expected to run end to end as written.
            self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 2048
            low_level_channels = 256
        elif backbone == "mobilenet":
            # MobileNetV2: 24-channel shallow features, 320-channel deep features.
            self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 320
            low_level_channels = 24
        else:
            raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone))

        # ASPP on the deepest features; its dilation rates scale with
        # 16 // downsample_factor.
        self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16//downsample_factor)

        # Mid-level branch: 32 -> 64 channels. It is later concatenated with
        # the 256-channel ASPP output, giving 320 channels.
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(32, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # NOTE: cat_conv and three_conv are not used by forward(). They are kept
        # so the module's state_dict keys (and any saved checkpoints) are unchanged.
        self.cat_conv = nn.Sequential(
            nn.Conv2d(48+256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Dropout(0.1),
        )
        # Classifier over 368 gated shallow channels + 320 fused channels = 688.
        self.cls_conv = nn.Conv2d(688, num_classes, 3, stride=1, padding=1)

        self.three_conv = nn.Sequential(
            nn.Conv2d(32, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True))

        # Shallow branch: project the backbone's shallow features to 48 channels.
        self.low_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, stride=1, padding=0),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True))

        # 48 -> 368 channels so it can be multiplied by the 368-channel gate.
        self.low_conv_0 = nn.Sequential(
            nn.Conv2d(48, 368, 1, stride=1, padding=0),
            nn.BatchNorm2d(368),
            nn.ReLU(inplace=True))

        # Channel attention on the 320-channel fused deep/mid-level features.
        self.cSE = cSE_Module(320)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-pixel class logits resized to the input's height/width.

        Args:
            x: 4-D input batch (batch, channels, H, W).

        Returns:
            Tensor of shape (batch, num_classes, H, W).
        """
        H, W = x.size(2), x.size(3)

        # The backbone may return (shallow, mid, deep) or
        # (shallow, mid, extra, deep); only shallow, mid and the last (deep)
        # feature map are used here.
        features = self.backbone(x)
        low_level_features, the_three_features, x = features[0], features[1], features[-1]

        # Deep branch: ASPP, then upsample to the mid-level resolution.
        x = self.aspp(x)
        x = F.interpolate(x, size=(the_three_features.size(2), the_three_features.size(3)), mode='bilinear',
                          align_corners=True)
        # Fuse with the projected mid-level features (256 + 64 = 320 channels).
        the_three_features = self.shortcut_conv(the_three_features)
        x1 = torch.cat((x, the_three_features), dim=1)
        # Upsample the fused map to the shallow-feature resolution.
        x2_0 = F.interpolate(x1, size=(low_level_features.size(2), low_level_features.size(3)), mode='bilinear',
                             align_corners=True)

        # Channel-attended copy of the fused features.
        x2 = self.cSE(x2_0)

        # Shallow branch: 48 channels for the gate, 368 channels for gating.
        low_level_features = self.low_conv(low_level_features)
        low_level_features_0 = self.low_conv_0(low_level_features)
        # Gate in (0, 1) computed from attended features + shallow features (320 + 48 = 368).
        x3 = self.sigmoid(torch.cat((x2, low_level_features), dim=1))
        x4 = x3 * low_level_features_0
        # Gated shallow features + un-attended fused features (368 + 320 = 688).
        x5 = torch.cat((x4, x2_0), dim=1)
        x5 = self.cls_conv(x5)
        return F.interpolate(x5, size=(H, W), mode='bilinear', align_corners=True)

class cSE_Module(nn.Module):   # channel attention (squeeze-and-excitation)
    """Channel-wise squeeze-and-excitation gate.

    Each channel is global-average-pooled to 1x1. The pooled vector passes
    through a bottleneck of two bias-free 1x1 convs
    (``channel -> channel // ratio -> channel``) with a ReLU in between and a
    sigmoid at the end. The input is then rescaled channel-wise by the
    resulting weights in (0, 1).

    Args:
        channel: Number of input (and output) channels.
        ratio: Bottleneck reduction factor. The hidden width is clamped to at
            least 1, so ``channel < ratio`` still yields a learnable gate
            instead of a zero-width bottleneck.
    """

    def __init__(self, channel, ratio = 16):
        super(cSE_Module, self).__init__()
        # Guard against channel // ratio == 0, which would build zero-width
        # convs and collapse the gate to a constant sigmoid(0) = 0.5.
        hidden = max(1, channel // ratio)
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by per-channel weights; the shape is unchanged."""
        # (..., channel, 1, 1) weights broadcast over the spatial axes.
        weights = self.excitation(self.squeeze(x))
        return x * weights

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
乘风的头像乘风管理团队
上一篇 2023年11月8日
下一篇 2023年11月8日

相关推荐