有没有办法为文件夹中的多个 pytorch 文件创建训练数据集?

社会演员多 pytorch 465

原文标题Is there a way to create a train dataset for multiple pytorch files in a folder?

我有一个名为processed_data的文件夹。在这个文件中,我有多个.pt文件,分别命名为data_0.ptdata_1.ptdata_2.ptdata_3.ptdata_4.ptdata_5.pt,…..,data_998.ptdata_999.ptdata_1000.ptdata_1001.pt。所有这些.pt文件代表使用 pytorch-geometric 创建的图形。

我的问题是,如何保存加载所有这些文件以创建我的训练数据集,以便我可以在 DataLoader 中使用它们?

原文链接:https://stackoverflow.com//questions/71678709/is-there-a-way-to-create-a-train-dataset-for-multiple-pytorch-files-in-a-folder

回复

我来回复
  • jhschwartz的头像
    jhschwartz 评论

    火炬 DataLoader 需要一个 Dataset 对象。定义 Dataset 类时,需要实现__init____len____getitem__

    __init__很简单,但也取决于您的确切用例/上下文。假设最简单的情况,我将定义 init 以接收数据文件夹和一个包含训练集文件名称的文件(每行一个)。然后,我将每个文件名存储在一个列表中作为该类的成员。所以我们有:

    def __init__(self, data_folder, data_list_filename):
        self.data_folder = data_folder
        self.data_file_list = open(data_list_filename, 'r').read().splitlines()
    

    好的,现在我们在 Dataloader 中存储了两件事:1)数据文件夹和 2)数据文件名列表。这使得__len__特别容易:

    def __len__(self):
        return len(self.data_file_list)
    

    最后,我们只需要处理__get_item__

    def __getitem__(self, idx):
        filename = self.data_file_list[idx]
        data, label = extract_data_from_file(filename) # this is arbitrary because I don't know how you need to do this
        return data, label
    

    然后将所有这些放在一个类下:

    class MyDataset(Dataset):
        def __init__(self, data_folder, data_list_filename):
            self.data_folder = data_folder
            self.data_file_list = open(data_list_filename, 'r').read().splitlines()
    
        def __len__(self):
            return len(self.data_file_list)
    
        def __getitem__(self, idx):
            filename = self.data_file_list[idx]
            data, label = extract_data_from_file(filename) # idk how you plan to do this
            return data, label
    

    显然,您的确切用途看起来会有所不同。但这应该让你开始。

    我要提到的最后一件事:我故意将文件名放入这样的列表中(而不是使用idx直接获取编号的文件名),因为这种方式允许 DataLoader 随机播放数据,我假设你会想要这样做.

    2年前 0条评论