深度学习计算数据集里所有图像像素点的均值方差

代码怕忘记,现在贴上来,以防丢失

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms

def get_mean_and_std(data_path, in_chans=3):
    dataset = ImageFolder(root=data_path, transform=transforms.ToTensor())
    loader = DataLoader(dataset, batch_size=1, shuffle=False, pin_memory=True)
    mean = torch.zeros(in_chans)
    std = torch.zeros(in_chans)
    num_samples = 0

    for X, _ in loader:
        for d in range(in_chans):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
        num_samples += 1

    mean.div_(num_samples)
    std.div_(num_samples)

    mean = list(mean.numpy())
    std = list(std.numpy())

    print(f"Mean: {mean}")
    print(f"Standard Deviation: {std}")

    return mean, std

data_path = "G:\\04_deep-learning-for-image-processing-master\\pytorch_classification\\swin_transformer\\flower_photos"
mean, std = get_mean_and_std(data_path)

#if __name__ == '__main__':
    #main("G:\\04_deep-learning-for-image-processing-master\\pytorch_classification\\swin_transformer\\flower_photos")

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2023年12月26日
下一篇 2023年12月26日

相关推荐