怕忘记这段代码，现在贴上来，以防丢失。
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms
def get_mean_and_std(data_path, in_chans=3):
    """Compute the per-channel pixel mean and standard deviation of an image folder.

    Images are loaded with ``ImageFolder`` (one sub-directory per class under
    ``data_path``) and converted by ``transforms.ToTensor()``, so values are
    floats in [0, 1]. Statistics are pooled over every pixel of every image:
    larger images weigh proportionally more, and the std includes the
    variation *between* images. Averaging per-image means/stds instead
    mis-weights differently sized images and underestimates the dataset std.

    Args:
        data_path: Root directory handed to ``ImageFolder``.
        in_chans: Number of leading channels to compute statistics for.

    Returns:
        ``(mean, std)`` -- two lists of ``in_chans`` Python floats, e.g. for
        ``transforms.Normalize(mean, std)``. The std uses Bessel's correction,
        like ``torch.Tensor.std``'s default.

    Raises:
        ValueError: if an image has fewer than ``in_chans`` channels, or the
            dataset yields no pixels.
    """
    dataset = ImageFolder(root=data_path, transform=transforms.ToTensor())
    # batch_size=1 because images may differ in size and the default collate
    # function can only stack equally sized tensors. No pin_memory: nothing
    # is copied to a GPU here.
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    # float64 accumulators: summing squares over millions of pixels in
    # float32 would lose noticeable precision.
    channel_sum = torch.zeros(in_chans, dtype=torch.float64)
    channel_sq_sum = torch.zeros(in_chans, dtype=torch.float64)
    num_pixels = 0
    for images, _ in loader:
        if images.shape[1] < in_chans:
            raise ValueError(
                f"expected at least {in_chans} channels, got {images.shape[1]}"
            )
        # (B, C, H, W) -> (in_chans, B*H*W): one row of pixel values per channel.
        pixels = images[:, :in_chans].transpose(0, 1).reshape(in_chans, -1).double()
        channel_sum += pixels.sum(dim=1)
        channel_sq_sum += (pixels * pixels).sum(dim=1)
        num_pixels += pixels.shape[1]

    if num_pixels == 0:
        raise ValueError(f"no pixels found in dataset at {data_path!r}")

    mean = channel_sum / num_pixels
    # Unbiased variance from the running sums; clamp tiny negatives caused by
    # floating-point cancellation before taking the square root.
    var = (channel_sq_sum - num_pixels * mean * mean) / max(num_pixels - 1, 1)
    std = var.clamp(min=0).sqrt()

    # .tolist() yields plain Python floats rather than numpy scalars.
    mean = mean.tolist()
    std = std.tolist()
    print(f"Mean: {mean}")
    print(f"Standard Deviation: {std}")
    return mean, std
if __name__ == "__main__":
    # Only run the (slow, machine-specific) dataset scan when executed as a
    # script, not when this module is imported for get_mean_and_std.
    data_path = "G:\\04_deep-learning-for-image-processing-master\\pytorch_classification\\swin_transformer\\flower_photos"
    mean, std = get_mean_and_std(data_path)
文章出处登录后可见！
已经登录？立即刷新