在 Python 中从数据框创建迭代器
pytorch 473
原文标题 :Create iterator from a Data Frame in Python
我正在使用 Seq2Seq 进行 NLP 项目。我从我的数据集中创建了一个数据框,然后使用数据加载器创建了一个批处理迭代器,请参见以下代码:
# creates lists containing each pair
original_word_pairs = [[w for w in l.split('\t')] for l in lines[:num_examples]]
data = pd.DataFrame(original_word_pairs, columns=["src", "trg"])
# conver the data to tensors and pass to the Dataloader
# to create a batch iterator
class MyData(Dataset):
def __init__(self, X, y):
self.data = X
self.target = y
# TODO: convert this into torch code is possible
self.length = [ np.sum(1 - np.equal(x, 0)) for x in X]
def __getitem__(self, index):
x = self.data[index]
y = self.target[index]
x_len = self.length[index]
return x,y,x_len
def __len__(self):
return len(self.data)
train_dataset = MyData(input_tensor_train, target_tensor_train)
val_dataset = MyData(input_tensor_val, target_tensor_val)
train_dataset = DataLoader(train_dataset, batch_size = BATCH_SIZE,
drop_last=True,
shuffle=True)
test_dataset= DataLoader(val_dataset, batch_size = BATCH_SIZE,
drop_last=True,
shuffle=True)
那是我代码的一部分,问题是我想像这样使用迭代器
for i, batch in enumerate(iterator):
src = batch.src
trg = batch.trg
但是我得到一个错误“AttributeError:’list’对象没有属性’src’”我怎样才能使用迭代器并访问特定的列?
回复
我来回复-
aretor 评论
你可以在你的
Dataset
中重新定义__getitem__
来返回一个字典:def __getitem__(self, index): x = self.data[index] y = self.target[index] x_len = self.length[index] return {"src": x, "trg": y, "x_len": x_len}
默认
collate_fn
ofDataLoader
会注意提供包含批次而不是单个观察的字典,但是您需要将x_len
转换为tensor
into__getitem__
才能使其工作(或者您可以传递自定义collate_fn
)。2年前