什么是数据加载器?
- 深度学习是由数据支撑起来的,所以我们一般在做深度学习的时候往往伴随着大量、复杂的数据。如果把所有的数据全部加载到内存上,容易把电脑的内存“撑爆”,所以要分批次一点点加载数据
- 每一种深度学习的框架都有自己所规定的数据格式,数据加载器就有了必要的作用
数据加载器就是把大量的数据,分批次加载和处理成框架所需要的数据格式
数据分批次加载
使用PyTorch内置的模块 torch.utils.data.DataLoader()
数据加载器:
参数
- dataset:数据集
- batch_size: 每一批数据的总量
- shuffle: True or False
为True的时候会将数据打乱再分批
PyTorch自带MNIST数据的分批
手写数字数据集
- 加载数据
MNIST数据集在torchvision.datasets.MNIST中
import torch
import torchvision
train_dataset = torchvision.datasets.MNIST(root="./data1",train=True,transform=torchvision.transforms.ToTensor(),download=False)
- 取出一张图片展示
import numpy as np
import matplotlib.pyplot as plt
# 获取到第一条数据
data,label = train_dataset[0]
# 因为数据集里面的数据进行过归一化,所以要反归一化
img = np.array(data) * 255
img = img.reshape(28,28).astype(np.uint8)
# 展示
plt.imshow(img,'gray')
plt.show()
- 使用DataLoader方法分批次
from torch.utils.data import DataLoader
# 创建DataLoader对象
train_loader = DataLoader(dataset=train_dataset,batch_size=100,shuffle=True)
num_epochs = 1
for epoch in range(num_epochs):
# 第二层循环会每次打开一批次的数据 当前一批次为100
for i,(inputs,labels) in enumerate(train_loader):
print(f'Epoch: {epoch+1}/{num_epochs},Step {i+1}/{len(train_dataset)/100}| Inputs {inputs.shape} | Labels {labels.shape}')
# 当前inputs和labels里面有100条数据
print(labels)
break
print(len(train_loader))
自定义Dataset类
DataLoader()的dataset参数必须继承于PyTorch的Dataset类
- 只有继承了PyTorch中的Dataset接口的类,才能够被传入DataLoader中
自定义一个Dataset
类,让PyTorch去认识我们的数据
步骤:
-
创建一个类继承Dataset
-
__init__
魔法方法内读取数据- 获取到数据的长度
- 获取到特征数据和输出标签
-
__getitem__
方法内返回第index条数据 -
__len__
方法内返回数据的长度
from torch.utils.data import Dataset
class WineDataset(Dataset):
def __init__(self):
# 读取csv数据
xy = pd.read_csv("./wine.csv")
# 获取到数据的长度
self.n_samples = xy.shape[0]
# 特征数据
self.x_data = torch.from_numpy(xy.values[:,1:])
# 输出标签
self.y_data = torch.from_numpy(xy.values[:,0])
def __getitem__(self,index):
# 遍历的时候返回数据 可迭代对象
return self.x_data[index],self.y_data[index]
def __len__(self.n_samples):
# 返回数据长度
return self.n_sampleso
查看自定义的Dataset类:
# 使用DataLoader去加载数据集合
from torch.utils.data import DataLoader
import torch
wineData = WineDataset()
# 传入加载器
train_loader = DataLoader(dataset=wineData,batch_size=4,shuffle=True)
# 分批训练
# 迭代次数
epoch_num = 5
total_samples = len(wineData)
print("total_samples:",total_samples)
# 开始训练
for epoch in range(epoch_num):
for i,(inputs,labels) in enumerate(train_loader):
print(i,labels)
每次批次加载4条数据
因为数据分为4批次,是有余数的,所以最后一行数据不是4条:
文章出处登录后可见!
已经登录?立即刷新