How am I incorrectly using SubsetRandomSampler?
I have a custom dataset:
rcvdataset = rcvLSTMDataSet('foo.csv', 'foolabels.csv')
I have also defined the following:
batch_size = 50
validation_split = .2
shuffle_rcvdataset = True
random_seed= 42
```
import numpy as np
import torch
from torch.utils.data import SubsetRandomSampler

# Build a (optionally shuffled) list of indices and split it into two parts.
rcvdataset_size = len(rcvdataset)
indices = list(range(rcvdataset_size))
split = int(np.floor(validation_split * rcvdataset_size))
if shuffle_rcvdataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Each sampler draws only from its own subset of indices.
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(val_indices)
train_loader = torch.utils.data.DataLoader(rcvdataset, batch_size=batch_size,
                                           sampler=train_sampler)
test_loader = torch.utils.data.DataLoader(rcvdataset, batch_size=batch_size,
                                          sampler=test_sampler)
```
I train with this function:
```
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
```
But when I try to run it, I get:
Epoch 1
-------------------------------
Traceback (most recent call last):
  File "lstmTrainer.py", line 94, in <module>
    train(train_sampler, model, loss_fn, optimizer)
  File "lstmTrainer.py", line 58, in train
    size = len(dataloader.dataset)
AttributeError: 'SubsetRandomSampler' object has no attribute 'dataset'
If I instead load the dataset indirectly, through the loader:
train(train_loader, model, loss_fn, optimizer)
it tells me:
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'pandas.core.series.Series'>
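I don't understand the first error at all. Is the second error trying to tell me that something in the dataset is not a tensor?
Thank you.
As requested, here is rcvDataSet.py: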
from __future__ import print_function, division
import os
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader


class rcvLSTMDataSet(Dataset):
    """rcv dataset."""
    TIMESTEPS = 10

    def __init__(self, csv_data_file, annotations_file):
        """
        Args:
            csv_data_file (string): Path to the csv file with the training data
            annotations_file (string): Path to the file with the annotations
        """
        self.csv_data_file = csv_data_file
        self.annotations_file = annotations_file
        self.labels = pd.read_csv(annotations_file)
        self.data = pd.read_csv(csv_data_file)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """
        pytorch expects whatever data is returned is in the form of a tensor. Included, it expects the label for the data.
        Together, they make a tuple.
        """
        # convert every ten indexes and label into one observation
        Observation = []
        counter = 0
        start_pos = self.TIMESTEPS * idx
        avg_1 = 0
        avg_2 = 0
        avg_3 = 0
        while counter < self.TIMESTEPS:
            Observation.append(self.data.iloc[idx + counter])
            avg_1 += self.labels.iloc[idx + counter][2]
            avg_2 += self.labels.iloc[idx + counter][1]
            avg_3 += self.labels.iloc[idx + counter][0]
            counter += 1
        avg_1 = avg_1 / self.TIMESTEPS
        avg_2 = avg_2 / self.TIMESTEPS
        avg_3 = avg_3 / self.TIMESTEPS
        current_labels = [avg_1, avg_2, avg_3]
        print(current_labels)
        return Observation, current_labels


def main():
    loader = rcvLSTMDataSet('foo1.csv', 'foo2.csv')
    j = 0
    while j < len(loader.data % loader.TIMESTEPS):
        print(loader.__getitem__(j))
        j += 1


if "__main__" == __name__:
    main()
Phoenix commented:
This answer has been accepted.
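Cause: if you look at the error message, you will see that you called the train function like this:
train(train_sampler, model, loss_fn, optimizer)
This is incorrect: `train()` should be called with `train_loader`, not `train_sampler`. A sampler only yields indices and has no `dataset` attribute, which is exactly what the AttributeError reports. Solution: correct the call to:
train(train_loader, model, loss_fn, optimizer)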
Error message:
Epoch 1
-------------------------------
Traceback (most recent call last):
  File "lstmTrainer.py", line 94, in <module>
    train(train_sampler, model, loss_fn, optimizer)   <------ look here
  File "lstmTrainer.py", line 58, in train
    size = len(dataloader.dataset)
AttributeError: 'SubsetRandomSampler' object has no attribute 'dataset'
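To make the difference between the two objects concrete, here is a minimal sketch. The `TensorDataset` and its contents are made-up stand-ins for the asker's dataset, not part of the original code. It shows that iterating a `SubsetRandomSampler` yields plain indices, while the `DataLoader` built on top of it yields batches and exposes `.dataset`.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Hypothetical stand-in data: 10 samples with 3 features each, plus scalar labels.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
targets = torch.arange(10, dtype=torch.float32)
dataset = TensorDataset(features, targets)

sampler = SubsetRandomSampler([0, 2, 4, 6, 8])
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

# Iterating the sampler yields dataset indices (in random order), not samples.
sampled_indices = sorted(sampler)

# The sampler knows nothing about the dataset, hence the AttributeError.
has_dataset_attr = hasattr(sampler, "dataset")

# The loader is the object that yields (X, y) batches and carries .dataset.
batch_shapes = [(tuple(X.shape), tuple(y.shape)) for X, y in loader]
print(sampled_indices, has_dataset_attr, len(loader.dataset), batch_shapes)
```

Passing `loader` to `train()` therefore satisfies both `len(dataloader.dataset)` and the `for batch, (X, y) in enumerate(dataloader)` loop, while passing `sampler` satisfies neither.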
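Second error message:
If you look at your Dataset class `rcvLSTMDataSet`, you will see that the `Observation` list is filled with items of type `pandas.core.series.Series` rather than plain Python scalars, because `self.data.iloc[idx + counter]` returns a whole row containing every column of the CSV file. `default_collate` cannot batch a Series, which is what the TypeError says. You should use `.iloc[...].values` instead of `.iloc[...]`. That way the list holds NumPy arrays of plain numeric values (`float` or `int`), which can be converted to tensors without error.
Final remarks: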
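You can read up on DataLoader and Samplers; I have summarized a few points here:
A sampler specifies the order of the indices/keys used when loading data.
A DataLoader combines a dataset and a sampler, and provides an iterable over the given dataset.
PyTorch provides two data primitives, `torch.utils.data.DataLoader` and `torch.utils.data.Dataset`, which let you use pre-loaded datasets as well as your own data. A Dataset stores the samples and their corresponding labels, and a DataLoader wraps an iterable around the Dataset for easy access to the samples.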
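As a rough illustration of the second fix, the sketch below shows a dataset whose `__getitem__` returns tensors, so `default_collate` has nothing pandas-specific to deal with. It is not the asker's class: `WindowDataset` is a hypothetical name, it takes DataFrames instead of CSV paths, it assumes non-overlapping windows of `TIMESTEPS` rows, it keeps the label columns in their original order, and the data is randomly generated.

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler

TIMESTEPS = 10  # same window length as in the question


class WindowDataset(Dataset):
    """Hypothetical variant of rcvLSTMDataSet that returns tensors."""

    def __init__(self, data: pd.DataFrame, labels: pd.DataFrame):
        # to_numpy() strips the pandas wrappers, like .values in the answer.
        self.data = data.to_numpy(dtype=np.float32)
        self.labels = labels.to_numpy(dtype=np.float32)

    def __len__(self):
        # Assumption: non-overlapping windows; a trailing partial window is dropped.
        return len(self.data) // TIMESTEPS

    def __getitem__(self, idx):
        start = idx * TIMESTEPS
        window = torch.from_numpy(self.data[start:start + TIMESTEPS])
        # Average every label column over the window.
        target = torch.from_numpy(self.labels[start:start + TIMESTEPS].mean(axis=0))
        return window, target


# Made-up data: 200 rows, 4 feature columns, 3 label columns.
rng = np.random.default_rng(42)
data = pd.DataFrame(rng.normal(size=(200, 4)))
labels = pd.DataFrame(rng.normal(size=(200, 3)))
dataset = WindowDataset(data, labels)

# Same 80/20 index split as in the question, on the windowed dataset.
indices = rng.permutation(len(dataset)).tolist()
split = int(0.2 * len(dataset))
train_sampler = SubsetRandomSampler(indices[split:])
train_loader = DataLoader(dataset, batch_size=5, sampler=train_sampler)

X, y = next(iter(train_loader))
print(tuple(X.shape), tuple(y.shape), len(train_sampler), len(train_loader.dataset))
```

One side effect worth noting: with a subset sampler, `len(dataloader.dataset)` in `train()` still reports the size of the full dataset. If the progress printout should reflect only the training subset, `len(dataloader.sampler)` gives that count instead.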