在 Python 中从数据框创建迭代器

扎眼的阳光 pytorch 473

原文标题Create iterator from a Data Frame in Python

我正在使用 Seq2Seq 进行 NLP 项目。我从我的数据集中创建了一个数据框,然后使用数据加载器创建了一个批处理迭代器,请参见以下代码:

# creates lists containing each pair
original_word_pairs = [[w for w in l.split('\t')] for l in lines[:num_examples]]
data = pd.DataFrame(original_word_pairs, columns=["src", "trg"])

# conver the data to tensors and pass to the Dataloader 
# to create a batch iterator

class MyData(Dataset):
    def __init__(self, X, y):
        self.data = X
        self.target = y
        # TODO: convert this into torch code is possible
        self.length = [ np.sum(1 - np.equal(x, 0)) for x in X]
        
    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        x_len = self.length[index]
        return x,y,x_len
    
    def __len__(self):
        return len(self.data)

train_dataset = MyData(input_tensor_train, target_tensor_train)
val_dataset = MyData(input_tensor_val, target_tensor_val)

train_dataset = DataLoader(train_dataset, batch_size = BATCH_SIZE, 
                     drop_last=True,
                     shuffle=True)
test_dataset= DataLoader(val_dataset, batch_size = BATCH_SIZE, 
                     drop_last=True,
                     shuffle=True)

那是我代码的一部分,问题是我想像这样使用迭代器

for i, batch in enumerate(iterator):
        
        src = batch.src
        trg = batch.trg

但是我得到一个错误“AttributeError:’list’对象没有属性’src’”我怎样才能使用迭代器并访问特定的列?

原文链接:https://stackoverflow.com//questions/71515161/create-iterator-from-a-data-frame-in-python

回复

我来回复
  • aretor的头像
    aretor 评论

    你可以在你的Dataset中重新定义__getitem__来返回一个字典:

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        x_len = self.length[index]
        return {"src": x, "trg": y, "x_len": x_len}
    

    默认collate_fnofDataLoader会注意提供包含批次而不是单个观察的字典,但是您需要将x_len转换为tensorinto__getitem__才能使其工作(或者您可以传递自定义collate_fn)。

    2年前 0条评论
此站出售,如需请站内私信或者邮箱!