一、实施过程
使用Pytorch进行预处理时,通常使用torchvision.transforms.Normalize(mean, std)方法进行数据标准化,其中参数mean和std分别表示图像集每个通道的均值和标准差序列。
首先,给出mean和std的定义,数学表示如下:
假设有一组数据集,这组数据集的均值是:
通常用
来表示数据的均值。
该组数据集的标准差为:计算图像数据集各通道均值和标准差的函数代码如下:
def get_mean_std_value(loader):
'''
求数据集的均值和标准差
:param loader:
:return:
'''
data_sum,data_squared_sum,num_batches = 0,0,0
for data,_ in loader:
# data: [batch_size,channels,height,width]
# 计算dim=0,2,3维度的均值和,dim=1为通道数量,不用参与计算
data_sum += torch.mean(data,dim=[0,2,3]) # [batch_size,channels,height,width]
# 计算dim=0,2,3维度的平方均值和,dim=1为通道数量,不用参与计算
data_squared_sum += torch.mean(data**2,dim=[0,2,3]) # [batch_size,channels,height,width]
# 统计batch的数量
num_batches += 1
# 计算均值
mean = data_sum/num_batches
# 计算标准差
std = (data_squared_sum/num_batches - mean**2)**0.5
return mean,std
CIFAR10数据集的均值和标准差为:
mean = tensor([0.4914, 0.4821, 0.4465]),std = tensor([0.2470, 0.2435, 0.2616])
MNIST数据集的均值和标准差为:
mean = tensor([0.1307]),std = tensor([0.3081])
2. 参考文献
[1]https://zhuanlan.zhihu.com/p/378810257
文章出处登录后可见!
已经登录?立即刷新