Pytorch之Dataset和Dataloaders

Dataset

torch.utils.data.Dataset

PyTorch 域库提供了许多预加载的数据集(例如 FashionMNIST),这些数据集子类torch.utils.data.Dataset化并实现了特定于特定数据的功能。

Dataset只能返回单个数据

自定义 Dataset 类必须实现三个函数:

__init__、__len__、__getitem__
import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    #__init__ 函数在实例化 Dataset 对象时运行一次。初始化包含图像、注释文件和两种转换的目录
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
    # 返回数据集中的样本数
    def __len__(self):
        return len(self.img_labels)
    #__getitem__ 函数从给定索引处的数据集中加载并返回一个样本idx。
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

Dataloaders

检索我们数据集的Dataset特征并一次标记一个样本。在训练模型时,我们通常希望以“小批量”的形式传递样本,在每个 epoch 重新洗牌以减少模型过度拟合,并使用 Python multiprocessing加速数据检索。

DataLoader是一个可迭代的,它在一个简单的 API 中为我们抽象了这种复杂性。

Dataloader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,work_init_fn=None)
  • dataset是一个torch.unils.data.Dataset类的实例
  • batch_size是迷你批次的大小
  • shuffle代表数据会不会被打乱
  • sampler是自定义的采样器(shuffle=True时会构建默认采样器,如想自定义采样方法,可构造一个torch.unils.data.Sampler的实例进行采样,并设shuffle=False),其中采样器是一个python迭代器,每次迭代时返回一个数据的下标索引
  • batch_sampler类似于sampler,不过返回的是一个迷你批次的数据索引
  • num_workers是数据载入器使用的进程数目,默认为0
  • collate_fn定义如何把一批dataset的实例转换为包含迷你批次数据的张量
  • pin_memory参数会把数据转移到和GPU内存相关联的CPU内存中,从而加快GPU载入数据的速度
  • drop_last的设置决定了是否要把最后一个迷你批次的数据丢弃掉
  • timeout的值如果大于0,就会决定在多进程情况下对数据的等待时间
  • worker_init_fn决定了每个数据载入的子进程开始时运行的函数,这个函数运行在随机种子设置以后、数据载入之前
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
  • 遍历Dataloaders

我们已将该数据集加载到 中,DataLoader并且可以根据需要遍历数据集。下面的每次迭代都会返回一批train_features和train_labels(分别包含batch_size=64特征和标签)。因为我们指定shuffle=True了 ,所以在我们遍历所有批次之后,数据被打乱

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2022年6月13日 下午12:24
下一篇 2022年6月13日 下午12:26

相关推荐