PyTorch深度学习实战——数据读取~

前言

工欲善其事,必先利其器。在学习神经网络时,熟练使用一种神经网络框架是十分有必要的,可以让你省去许多繁琐的步骤。PyTorch是目前最主流的深度学习框架之一,被广泛应用于各个领域。

正确读取数据是训练一个神经网络模型的第一步,本文主要介绍pytorch中是如何定义和读取这些数据集的。在pytorch中已经包含了部分常用数据集(例如MNIST等),可以直接使用,但在实际工程应用中仅仅使用pytorch自带的数据集远远不够,有时还需要自定义数据集来满足需求。下面内容中,将从pytorch自带数据集和自定义数据集两部分介绍数据集制作和读取方法。

🔔若是想了解更多pytorch信息以及环境搭建步骤,可以参考以下文章:
PyTorch简介及环境搭建

一、PyTorch自带数据集及读取方法简介

1. 简介

pytorch中所有的数据集均继承自torch.utils.data.Dataset,它们都需要实现 __getitem__和 __ len __ 两个接口,因此,实现一个数据集的核心也就是实现这两个接口。

Pytorch的torchvision中已经包含了很多常用数据集以供我们使用,如Imagenet,MNIST,CIFAR10、VOC等,利用torchvision可以很方便地读取。对于pytorch自带的图像数据集,它们都已经实现好了上述的两个核心接口。因此这里先忽略这部分细节,先介绍用法,关于 __ getitem __ 和 __ len __ 两个方法,我们将在后面的自定义数据集读取方法中详细介绍。

Pytorch支持哪些常用数据加载呢?可以参见:torchvision.datasets

以读取pytorch自带的CIFAR10数据集为例进行介绍,CIFAR10数据集的定义方法如下:

dataset_dir = '../../../dataset/'
torchvision.datasets.CIFAR10(dataset_dir, train=True, transform=None, target_transform=None, download=False) 

范围:

  • dataset_dir:存放数据集的路径。
  • train(bool,可选):如果为True,则构建训练集,否则构建测试集。
  • transform:定义数据预处理,数据增强方案都是在这里指定。
  • target_transform:标注的预处理,分类任务不常用。
  • download:是否下载,若为True则从互联网下载,如果已经在dataset_dir下存在,就不会再次下载。

2. 读取示例

为了可视化数据读取方法,给出以下两个示例:
读取示例1(从网上自动下载):

import torchvision 
        
# 读取训练集
train_data = torchvision.datasets.CIFAR10('../../../dataset', 
                                                      train=True, 
                                                      transform=None,  
                                                      target_transform=None, 
                                                      download=True)          
# 读取测试集
test_data = torchvision.datasets.CIFAR10('../../../dataset', 
                                                      train=False, 
                                                      transform=None, 
                                                      target_transform=None, 
                                                      download=True)      

以上是一个简单的阅读示例,接下来我们尝试添加一些内容。
读取示例2(示例1基础上附带数据增强):

在使用API读取数据时,API中的transform参数指定了导入数据集时需要对图像进行何种变换操作。对于图像进行各种变换来增加数据的丰富性称为数据增强,是一种常用操作。

一般的,我们使用torchvision.transforms中的函数来实现数据增强,并用transforms.Compose将所要进行的变换操作都组合在一起,其变换操作的顺序按照在transforms.Compose中出现的先后顺序排列。在transforms中有很多实现好的数据增强方法,在这里我们尝试使用缩放,随机颜色变换、随机旋转、图像像素归一化等组合变换。

import torchvision 
import torchvision.transforms as transforms        

# 读取训练集
custom_transform=transforms.transforms.Compose([
              transforms.Resize((64, 64)),    # 缩放到指定大小 64*64
              transforms.ColorJitter(0.2, 0.2, 0.2),    # 随机颜色变换
              transforms.RandomRotation(5),    # 随机旋转
              transforms.Normalize([0.485,0.456,0.406],    # 对图像像素进行归一化
                                   [0.229,0.224,0.225])])
train_data=torchvision.datasets.CIFAR10('../../../dataset', 
                                        train=True,                                       
                                        transform=custom_transforms,
                                        target_transform=None, 
                                        download=False)          

数据集定义完成后,我们还需要进行数据加载。Pytorch提供DataLoader来完成对于数据集的加载,并且支持多进程并行读取。
DataLoader使用示例:

import torch
import torchvision 
from torch.utils.data.dataset import Dataset  

# 读取数据集
train_data=torchvision.datasets.CIFAR10('../../../dataset', train=True, 
                                                      transform=None,  
                                                      target_transform=None, 
                                                      download=True)          
# 实现数据批量读取
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=2,
                                           shuffle=True,
                                           num_workers=4)        

这里batch_size设置了批量大小,shuffle设置为True在装载过程中为随机乱序,num_workers>=1表示多进程读取数据,在Win下num_workers只能设置为0,否则会报错。

3. CIFAR完整读取示例

接下来,我们对之前的内容进行整合。

这里将两个预处理操作通过Compose放在了transform中,第一步ToTensor将数据转化为张量,第二步通过Normalize将数据化为正态分布值。前面的(0.5,0.5,0.5)是 RGB三个通道上的均值,后面(0.5, 0.5, 0.5)是三个通道的标准差,Normalize对每个通道执行以下操作:image =(图像-平均值)/ std。当mean,std都是0.5时将使图像在[-1,1]范围内归一化。例如,最小值0将转换(0-0.5)/0.5=-1。

接着使用DataLoader将数据分为多个批次,batch_size指定每个批次包含几个图片,shuffle为是否打乱图片,num_workers指定多个线程去加载数据。当训练很快、加载数据时间过慢时会导致模型等待数据加载而变慢,这时可以采用多线程来加载数据。

import torch
import torchvision

# 定义数据预处理操作
transform=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载数据
data_path = 'D:/Temp/MachineLearning/data'
train_set = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform)
# 封装为批数据
train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=4, shuffle=False, num_workers=2)
# 定义标签值
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

我们通过matplotlib打印其中的一个批次图片和标签。由于之前将图片标准化,所以需要进行反标准化操作。由于CiFAR10的图片数据为3×32×32,需要使用transpose将其转为32×32×3。

import matplotlib.pyplot as plt
import numpy as np

# 输出图像的函数
def imshow(img):
    img = img / 2 + 0.5     # 反标准化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# 获取一个批次的训练图片、标签
images, labels = iter(train_loader).next()
# 显示图片
imshow(torchvision.utils.make_grid(images))
# 打印图片标签
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

2.自定义数据集及读取方式

除了pytorch自带的数据集外,在实际应用中,我们可能还需要从其他各种不同的数据集或自己构建的数据集(将其统称为自定义数据集)中读取图像,这些图像可能来自于开源数据集网站,也可能是我们自己整理得到的。对于这样的图像数据:首先,我们要确定是否包含标签文件,如果没有就要自己先创建标签文件;然后,我们就可以使用pytorch来读取数据集了。道理是不是很简单?接下来,该小节我们将着重讲解pytorch自定义数据集的制作和读取方法。

在上一节中,我们已经能够使用Dataset和DataLoader两个类实现pytorch自带数据集的读写。其实,我们完全可以将上节的内容看作是pytorch读取数据“通用解”中的一种特殊情况,只不过它满足了一些特殊的条件——pytorch帮你下载好了数据并制作了数据标签,然后通过使用Dataset和DataLoader两个类完成了数据集的构建和读取。

图像数据不必多说,就是训练测试模型使用的图片。这里的索引文件指的就是记录数据标注信息的文件,我们必须有一个这样的文件来充当“引路人”,告诉程序哪个图片对应哪些标注信息,例如图片img_0013.jpg对应的类别为狗。之后便可以像套公式一样使用Dataset和DataLoader两个类完成数据读取。

1. 图像索引文件制作

图像索引文件只要能够合理记录标注信息即可,内容可以简单也可以复杂,但有一条要注意:内容是待读取图像的名称(或路径)及标签,并且读取后能够方便实现索引。该文件可以是txt文件,csv文件等多种形式,甚至是一个list都可以,只要是能够被Dataset类索引到即可。

我们以读取MNIST数据为例,构建分类任务的图像索引文件,对于其他任务的索引文件,相信你在学过分类任务的索引文件制作后将会无师自通。

通过https://www.cs.utoronto.ca/~kriz/cifar.html我们下载MNIST的图像和标签数据到Dive-into-CV-PyTorch/dataset/MNIST/目录下,得到下面的压缩文件并解压暂存,以用来充当自己的图像数据集。

train-images-idx3-ubyte.gz: training set images (9912422 bytes) ➡ train-images-idx3-ubyte(解压后)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes) ➡ train-labels-idx1-ubyte(解压后)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) ➡ t10k-images-idx3-ubyte(解压后)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes) ➡ t10k-labels-idx1-ubyte(解压后)

我们运行下面的代码来实现图像数据的本地存储和索引文件的制作。我们根据训练集和测试集分别存储图像,并分别为训练集和测试集创建索引文件,并将图像文件记录在索引文件中。名称和标签信息。

import os
from skimage import io
import torchvision.datasets.mnist as mnist

# 数据文件读取
root = r'./MNIST/'  # MNIST解压文件根目录
train_set = (
    mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
)
test_set = (
    mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
)

# 数据量展示
print('train set:', train_set[0].size())
print('test set:', test_set[0].size())


def convert_to_img(save_path, train=True):
    '''
    将图片存储在本地,并制作索引文件
    @para: save_path  图像保存路径,将在路径下创建train、test文件夹分别存储训练集和测试集
    @para: train      默认True,本地存储训练集图像,否则本地存储测试集图像 
    '''
    if train:
        f = open(save_path + 'train.txt', 'w')
        data_path = save_path + '/train/'
        if (not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img, label) in enumerate(zip(train_set[0], train_set[1])):
            img_path = data_path + str(i) + '.jpg'
            io.imsave(img_path, img.numpy())
            int_label = str(label).replace('tensor(', '')
            int_label = int_label.replace(')', '')
            f.write(str(i)+'.jpg' + ',' + str(int_label) + '\n')
        f.close()
    else:
        f = open(save_path + 'test.txt', 'w')
        data_path = save_path + '/test/'
        if (not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img, label) in enumerate(zip(test_set[0], test_set[1])):
            img_path = data_path + str(i) + '.jpg'
            io.imsave(img_path, img.numpy())
            int_label = str(label).replace('tensor(', '')
            int_label = int_label.replace(')', '')
            f.write(str(i)+'.jpg' + ',' + str(int_label) + '\n')
        f.close()

        
# 根据需求本地存储训练集或测试集
save_path = r'./MNIST/mnist_data/'
convert_to_img(save_path, True)
convert_to_img(save_path, False)

上面的代码虽然比较繁琐,但是可以清晰的展示出图片和我们索引文件内容的对应关系,还可以实现图片的本地存储和索引文件的构建。我们在索引文件中记录了每张图片的文件名和标签,每一行对应一张图片的信息,这也是为了方便数据索引。其实我们可以直接在索引文件中记录每张图片的路径和标签信息,但是考虑到数据的可移植性,只记录了图片的名称。

通过上面的示例,其实是为了展示自制分类数据集的数据形式与索引文件之间的关系,以方便后续构建自己的Dataset。

2. 构建自己的Dataset

想要读取我们自己数据集中的数据,就需要写一个Dataset的子类来定义我们的数据集,并必须对 __ init __ 、__ getitem __ 和 __ len __ 方法进行重载。

首先我们看源码如下:

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
	raise NotImplementedError
def __len__(self):
	raise NotImplementedError
def __add__(self, other):
	return ConcatDataset([self, other])

下面我们看一下构建Dataset类的基本结构:

from torch.utils.data.dataset import Dataset

class MyDataset(Dataset):  # 继承Dataset类
   def __init__(self):
       # 初始化图像文件路径或图像文件名列表等
       pass
   
   def __getitem__(self, index):
        # 1.根据索引index从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open,cv2.imread)
        # 2.预处理数据(例如torchvision.Transform)
        # 3.返回数据对(例如图像和标签)
       pass
   
   def __len__(self):
       return count  # 返回数据量
  • __ init __() : 初始化模块,初始化该类的一些基本参数。
  • __ getitem __() : 接收一个index,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息,返回数据对(图像和标签)。
  • __ len __() : 返回所有数据的数量。

重点说明一下 __ getitem __ () 函数,该函数接收一个index,也就是索引值。只要是具有索引的数据类型都能够被读取,如list,Series,Dataframe等形式。为了方便,我们一般采用list形式将文件代入函数中,该list中的每一个元素包含了图片的路径或标签等信息,以方便index用来逐一读取单一样本数据。在__ getitem __() 函数内部,我们可以选择性的对图像和标签进行预处理等操作,最后返回图像数据和标签。

我们延续上一小节自制MNIST索引文件,构建自己的Dataset类,以便通过该类读取特定图像数据。

import pandas as pd
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class MnistDataset(Dataset):

    def __init__(self, image_path, image_label, transform=None):
        super(MnistDataset, self).__init__()
        self.image_path = image_path  # 初始化图像路径列表
        self.image_label = image_label  # 初始化图像标签列表
        self.transform = transform  # 初始化数据增强方法

    def __getitem__(self, index):
        """
        获取对应index的图像,并视情况进行数据增强
        """
        image = Image.open(self.image_path[index])
        image = np.array(image)
        label = float(self.image_label[index])

        if self.transform is not None:
            image = self.transform(image)

        return image, torch.tensor(label)

    def __len__(self):
        return len(self.image_path)

    
def get_path_label(img_root, label_file_path):
    """
    获取数字图像的路径和标签并返回对应列表
    @para: img_root: 保存图像的根目录
    @para:label_file_path: 保存图像标签数据的文件路径 .csv 或 .txt 分隔符为','
    @return: 图像的路径列表和对应标签列表
    """
    data = pd.read_csv(label_file_path, names=['img', 'label'])
    data['img'] = data['img'].apply(lambda x: img_root + x)
    return data['img'].tolist(), data['label'].tolist()


# 获取训练集路径列表和标签列表
train_data_root = './dataset/MNIST/mnist_data/train/'
train_label = './dataset/MNIST/mnist_data/train.txt'
train_img_list, train_label_list = get_path_label(train_data_root, train_label)  
# 训练集dataset
train_dataset = MnistDataset(train_img_list,
                             train_label_list,
                             transform=transforms.Compose([transforms.ToTensor()]))

# 获取测试集路径列表和标签列表
test_data_root = './dataset/MNIST/mnist_data/test/'
test_label = './dataset/MNIST/mnist_data/test.txt'
test_img_list, test_label_list = get_path_label(test_data_root, test_label)
# 测试集sdataset
test_dataset = MnistDataset(test_img_list,
                            test_label_list,
                            transform=transforms.Compose([transforms.ToTensor()]))

上面的代码通过构建 MnistDataset 类,完成了数据集的定义。

首先通过 get_path_label() 函数获得图像的路径和标签列表,并通过 MnistDataset 类中 __ init __() 的 self.image_path 和 self.image_label 进行存储,我们能够看到此处的图像列表中的数据和标签列表中的数据是一一对应的关系,同时我们在初始化中还初始化了 transform,以实现后续中图像增强操作。

MnistDataset 类的 __ getitem __() 函数完成了图像读取和增强。该函数的前三行,我们通过 index 读取了 self.image_path 和 self.image_label (两个list,也是前文中提到的list)中的图像和标签。第四、五行,对图像进行处理,在 transform 中可以实现旋转、裁剪、仿射变换、标准化等等一系列操作。最后返回处理好的图像数据和标签。

通过 MnistDataset 类的定义,pytorch就知道了如何获取一张图片并完成相应的预处理工作。这里我们尝试从数据集中读取一些数据,打印下输出结果进行观察:

>>> train_iter = iter(train_dataset)
>>> next(train_iter)

(tensor([[[0.0000, 0.0000, 0.0039, 0.0039, 0.0118, 0.0196, 0.0118, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0039, 0.0039, 0.0000, 0.0000, 0.0039,
           0.0000, 0.0000, 0.0157, 0.0314, 0.0000, 0.0667, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          ...,
          [0.0667, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0510, 0.0471, 0.0078, 0.0118, 0.0000, 0.0157, 0.0000, 0.0196,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000]]]),
 tensor(5.))

>>> next(train_iter)

(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0392, 0.0039, 0.0000, 0.0157, 0.0000, 0.0000, 0.0314,
           0.0000, 0.0157, 0.0314, 0.0039, 0.0000, 0.0431, 0.0039, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000]]]),
 tensor(0.))

每个图像和标签都封装成一个双元组,第一个元素是图像矩阵,第二个元素是图像标签。让我们尝试打印每张图片的尺寸和标签信息,看看结果:

>>> for i in train_dataset:
        img, label = i
        print(img.size(), label)

torch.Size([1, 28, 28]) tensor(5.)
torch.Size([1, 28, 28]) tensor(0.)
torch.Size([1, 28, 28]) tensor(4.)
...
torch.Size([1, 28, 28]) tensor(5.)
torch.Size([1, 28, 28]) tensor(6.)
torch.Size([1, 28, 28]) tensor(8.)

>>> print(train_dataset.__len__())
train num: 60000

需要注意的是,当 Dataset 创建好后并没有将数据生产出来,我们只是定义了数据及标签生产的流水线,只有在真正使用时,如手动调用 next(iter(train_dataset)),或被 DataLoader调用,才会触发数据集内部的 __ getitem __() 函数来读取数据,通常CV入门者对于这一块会存在困惑。

3. 使用DataLoader批量读取数据

在构建好自己的 Dataset 之后,就可以使用 DataLoader 批量的读取数据,相当于帮我们完成一个batch的数据组装工作。Dataloader 为一个迭代器,最基本的使用方法就是传入一个 Dataset 对象,在Dataloader中,会触发Dataset对象中的 __ gititem __() 函数,逐次读取数据,并根据 batch_size 产生一个 batch 的数据,实现批量化的数据读取。

DataLoader使用方式如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)

范围:

  • dataset:加载的数据集(Dataset对象)。
  • batch_size:一个批量数目大小。
  • shuffle::是否打乱数据顺序。
  • sampler: 样本抽样方式。
  • num_workers:使用多进程加载的进程数,0代表不使用多进程。
  • collate_fn: 将多个样本数据组成一个batch的方式,一般使用默认的拼接方式,可以通过自定义这个函数来完成一些特殊的读取逻辑。
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些。
  • drop_last:为True时,dataset中的数据个数不是batch_size整数倍时,将多出来不足一个batch的数据丢弃。

承接上一节中的 train_dataset 和 test_dataset,使用 DataLoader 进行批量化读取,此处仅使用了常用的几个参数。

from torch.utils.data import DataLoader

# 训练数据加载
train_loader = DataLoader(dataset=train_dataset,  # 加载的数据集(Dataset对象)
                         batch_size=3,  # 一个批量大小
                         shuffle=True,  # 是否打乱数据顺序
                         num_workers=4)  # 使用多进程加载的进程数,0代表不使用多进程(win系统建议改成0)
# 测试数据加载
test_loader = DataLoader(dataset=test_dataset,
                        batch_size=3,
                        shuffle=False,
                        num_workers=4)

如上面的代码,为方便展示加载后的结果,我们定义了一个批量大小为 3 的 DataLoader 来加载训练集,并且打乱了数据顺序,在测试集的加载中,我们并没有打乱顺序,这都可以根据自己的需求进行调整。现在,train_loader 已经将原来训练集中的60000张图像重新“洗牌”后按照每3张一个batch划分完成(test_loader同理),进一步查看划分后的数据格式。

>>> loader = iter(train_loader)
>>> next(loader)

[tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0157, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0157, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]],
 
         [[[0.0000, 0.0000, 0.0118,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
           [0.0118, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           ...,
           [0.0510, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0157, 0.0196,  ..., 0.0000, 0.0000, 0.0000]]],
 
         [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]]]),
 tensor([2., 3., 9.])]

>>> next(loader)

[tensor([[[[0.0118, 0.0000, 0.0275,  ..., 0.0000, 0.0000, 0.0000],
           [0.0039, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0118, 0.0039, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]],
 
         [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0275, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0078, 0.0078, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0118, 0.0275],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]],
 
         [[[0.0196, 0.0000, 0.0118,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0510,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0392,  ..., 0.0000, 0.0000, 0.0000],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]]]),
 tensor([3., 8., 3.])]

经过DataLoader的封装,每3(一个batch_size数量)张图像数据及对应的标签被封装为一个二元元组,第一个元素为四维的tensor形式,第二个元素为对应的图像标签数据。按照如下方式将所有train_loader中的数据进行展示。

>>> for i, img_data in enumerate(train_loader, 1):
        images, labels = img_data
        print('batch{0}:images shape info-->{1} labels-->{2}'.format(i, images.shape, labels))
   
batch1:images shape info-->torch.Size([3, 1, 28, 28]) labels-->tensor([2., 3., 9.])
batch2:images shape info-->torch.Size([3, 1, 28, 28]) labels-->tensor([3., 8., 3.])
batch3:images shape info-->torch.Size([3, 1, 28, 28]) labels-->tensor([4., 7., 6.])
...
batch19998:images shape info-->torch.Size([3, 1, 28, 28]) labels-->tensor([0., 7., 7.])
batch19999:images shape info-->torch.Size([3, 1, 28, 28]) labels-->tensor([3., 7., 0.])
batch20000:images shape info-->torch.Size([3, 1, 28, 28]) labels-->tensor([9., 7., 5.])

>>> len(train_loader)
20000

我们将DataLoader与Dataset分别处理后的数据比较可以发现出两者的不同:Dataset是对本地数据读取逻辑的定义;而DataLoader是对Dataset对象的封装,执行调度,将一个batch size的图像数据组装在一起,实现批量读取数据。

总结

以上就是使用PyTorch读取数据的方法介绍。本文主要参考了动手学CV-Pytorch(链接见参考资料),是一个高质量的实战教程项目,包含了许多计算机视觉和Pytorch的内容,新手友好、注重实战,萌新可冲!

参考

[1]动手学CV-Pytorch.

版权声明:本文为博主吮指原味鸡!原创文章,版权归属原作者,如果侵权,请联系我们删除!

原文链接:https://blog.csdn.net/m0_52573426/article/details/123299254

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
社会演员多的头像社会演员多普通用户
上一篇 2022年3月6日
下一篇 2022年3月7日

相关推荐