Dataset
torch.utils.data.Dataset
PyTorch 域库提供了许多预加载的数据集(例如 FashionMNIST),这些数据集子类torch.utils.data.Dataset化并实现了特定于特定数据的功能。
Dataset只能返回单个数据
自定义 Dataset 类必须实现三个函数:
__init__、__len__、__getitem__
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
#__init__ 函数在实例化 Dataset 对象时运行一次。初始化包含图像、注释文件和两种转换的目录
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
# 返回数据集中的样本数
def __len__(self):
return len(self.img_labels)
#__getitem__ 函数从给定索引处的数据集中加载并返回一个样本idx。
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
Dataloaders
检索我们数据集的Dataset特征并一次标记一个样本。在训练模型时,我们通常希望以“小批量”的形式传递样本,在每个 epoch 重新洗牌以减少模型过度拟合,并使用 Python multiprocessing加速数据检索。
DataLoader是一个可迭代的,它在一个简单的 API 中为我们抽象了这种复杂性。
Dataloader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,work_init_fn=None)
- dataset是一个torch.unils.data.Dataset类的实例
- batch_size是迷你批次的大小
- shuffle代表数据会不会被打乱
- sampler是自定义的采样器(shuffle=True时会构建默认采样器,如想自定义采样方法,可构造一个torch.unils.data.Sampler的实例进行采样,并设shuffle=False),其中采样器是一个python迭代器,每次迭代时返回一个数据的下标索引
- batch_sampler类似于sampler,不过返回的是一个迷你批次的数据索引
- num_workers是数据载入器使用的进程数目,默认为0
- collate_fn定义如何把一批dataset的实例转换为包含迷你批次数据的张量
- pin_memory参数会把数据转移到和GPU内存相关联的CPU内存中,从而加快GPU载入数据的速度
- drop_last的设置决定了是否要把最后一个迷你批次的数据丢弃掉
- timeout的值如果大于0,就会决定在多进程情况下对数据的等待时间
- worker_init_fn决定了每个数据载入的子进程开始时运行的函数,这个函数运行在随机种子设置以后、数据载入之前
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
- 遍历Dataloaders
我们已将该数据集加载到 中,DataLoader并且可以根据需要遍历数据集。下面的每次迭代都会返回一批train_features和train_labels(分别包含batch_size=64特征和标签)。因为我们指定shuffle=True了 ,所以在我们遍历所有批次之后,数据被打乱
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
文章出处登录后可见!
已经登录?立即刷新