PyTorch实现计算图像数据集的均值和标准差

一、实施过程

使用Pytorch进行预处理时,通常使用torchvision.transforms.Normalize(mean, std)方法进行数据标准化,其中参数mean和std分别表示图像集每个通道的均值和标准差序列。
首先,给出mean和std的定义,数学表示如下:
假设有一组数据集X_i%2C%5C%2C%5C%2Ci%5Cin%5C%7B1%2C2%2C%5Ccdots%2Cn%5C%7D,这组数据集的均值是:mean%3D%5Cfrac%7B%5Cdisplaystyle%5Csum_%7Bi%3D1%7D%5EnX_i%7D%7Bn%7D%5Ctag%7B1%7D通常用%5Coverline%20X来表示数据的均值。
该组数据集的标准差为:std%3D%5Csqrt%7B%5Cfrac%7B%5Cdisplaystyle%5Csum_%7Bi%3D1%7D%5En%5Cleft%28X_i-%5Coverline%20X%5Cright%29%5E2%7D%7Bn%7D%7D%5C%5C%5B2ex%5D%3D%5Csqrt%7B%5Cfrac%7B%5Cdisplaystyle%5Csum_%7Bi%3D1%7D%5En%28X_i%5E2-2X_i%5Coverline%20X%2B%5Coverline%20X%5E2%29%7D%7Bn%7D%7D%5C%5C%5B2ex%5D%3D%5Csqrt%7B%5Cfrac%7B%5Cleft%28%5Cdisplaystyle%5Csum_%7Bi%3D1%7D%5EnX_i%5E2%5Cright%29-n%5Coverline%20X%5E2%7D%7Bn%7D%7D%5C%5C%5B2ex%5D%3D%5Csqrt%7B%5Cfrac%7B%5Cdisplaystyle%5Csum_%7Bi%3D1%7D%5EnX_i%5E2%7D%7Bn%7D-%5Coverline%20X%5E2%7D%5Ctag%7B2%7D计算图像数据集各通道均值和标准差的函数代码如下:

def get_mean_std_value(loader):
    '''
    求数据集的均值和标准差
    :param loader:
    :return:
    '''
    data_sum,data_squared_sum,num_batches = 0,0,0

    for data,_ in loader:
        # data: [batch_size,channels,height,width]
        # 计算dim=0,2,3维度的均值和,dim=1为通道数量,不用参与计算
        data_sum += torch.mean(data,dim=[0,2,3])    # [batch_size,channels,height,width]
        # 计算dim=0,2,3维度的平方均值和,dim=1为通道数量,不用参与计算
        data_squared_sum += torch.mean(data**2,dim=[0,2,3])  # [batch_size,channels,height,width]
        # 统计batch的数量
        num_batches += 1
    # 计算均值
    mean = data_sum/num_batches
    # 计算标准差
    std = (data_squared_sum/num_batches - mean**2)**0.5
    return mean,std

CIFAR10数据集的均值和标准差为:

mean = tensor([0.4914, 0.4821, 0.4465]),std = tensor([0.2470, 0.2435, 0.2616])

MNIST数据集的均值和标准差为:

mean = tensor([0.1307]),std = tensor([0.3081])

2. 参考文献

[1]https://zhuanlan.zhihu.com/p/378810257

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2022年3月14日 下午6:49
下一篇 2022年3月14日 下午7:02

相关推荐