torchvision.datasets下载数据集后,怎样取其中部分数据做训练

Mnist数据集为例

一、直接在整个数据集上训练

数据下载和预处理

trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
dataset_train = torchvision.datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
dataset_test = torchvision.datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)

然后即可放入Dataloader中,

trainDataLoader = torch.utils.data.DataLoader(dataset=trainData, batch_size=batch_size, shuffle=True)  # 批量读取并打乱
testDataLoader = torch.utils.data.DataLoader(dataset=testData, batch_size=batch_size)

训练时迭代读取数据

for epoch in range(1, epochs + 1):
    processBar = tqdm(trainDataLoader, unit='step')
    model.train(True)
    train_loss, train_correct = 0, 0
    for step, (train_imgs, labels) in enumerate(processBar):

        if torch.cuda.is_available():  # GPU可用
            train_imgs = train_imgs.cuda()
            labels = labels.cuda()
        model.zero_grad()  # 梯度清零
        outputs = model(train_imgs)  # 输入训练集
        loss = criterion(outputs, labels)  # 计算损失函数
        predictions = torch.argmax(outputs, dim=1)  # 得到预测值
        correct = torch.sum(predictions == labels)
        accuracy = correct / labels.shape[0]  # 计算这一批次的正确率
        loss.backward()  # 反向传播
        optimizer.step()  # 更新优化器参数
        processBar.set_description("[%d/%d] Loss: %.4f, Acc: %.4f" %  # 可视化训练进度条设置
                                   (epoch, epochs, loss.item(), accuracy.item()))
二、取数据集上部分数据训练

以数据集索引提取训练数据
取其中序号为data_idx的数据

dataset_train[data_idx][0] #取图像数据(image)
dataset_train[data_idx][1] #取对应的标签(label)

于是,采样一些数据作为训练集可使用如下代码

sample_index = [i for i in range(500)] #假设取前500个训练数据
X_train = []
y_train = []
for i in sample_index:
    X = dataset_train[i][0]
    X_train.append(X)
    y = dataset_train[i][1]
    y_train.append(y)

sampled_train_data = [(X, y) for X, y in zip(X_train, y_train)] #包装为数据对
trainDataLoader = torch.utils.data.DataLoader(sampled_train_data, batch_size=16, shuffle=True)

将trainDataloader带入训练过程中即可。

参考资料

[1] PyTorch入门——实现MNIST分类

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2023年12月7日
下一篇 2023年12月7日

相关推荐