Apply different custom functions to different datasets


Original title: Apply different custom functions to different datasets

Given the following dataset class:

from pandas import read_csv
from torch.utils.data import Dataset, DataLoader, random_split

class MYDataset(Dataset):
    def __init__(self, path):
        # Load the whitespace-delimited file, drop the first and last columns
        df = read_csv(path, header=None, delimiter=r"\s+")
        df = df.iloc[:, 1:-1].values
        # First remaining column is the target, the rest are features
        self.X = df[:, 1:]
        self.y = df[:, 0]
        self.y = self.y.reshape((len(self.y), 1))

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return [self.X[idx], self.y[idx]]

    def get_splits(self, n_test=0.3):
        # Split into train / validation / test subsets
        testing_size = round(n_test * len(self.X))
        train_size = len(self.X) - testing_size
        val_size = round(testing_size / 2)
        test_size = testing_size - val_size
        return random_split(self, [train_size, val_size, test_size])

def prepare_data(path):
    dataset = MYDataset(path)
    train, val, test = dataset.get_splits()
    train_dl = DataLoader(train, batch_size=64, shuffle=True)
    val_dl = DataLoader(val, batch_size=len(val), shuffle=False)
    test_dl = DataLoader(test, batch_size=len(test), shuffle=False)
    return train_dl, val_dl, test_dl

and the following two other custom functions:

def augment_data(array):
    ...
    return gen_data

def BINNED(dataframe):
    ...
    return data_ohe, 0, 1

How can the function BINNED be applied to the train, validation, and test subsets, while applying the function augment_data only to the "train" subset?
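
For context, a minimal sketch of how the splits above are produced and consumed (the file name below is hypothetical, and numeric columns are assumed):

# Hypothetical whitespace-delimited data file with numeric columns
train_dl, val_dl, test_dl = prepare_data("data.txt")

# Training data arrives in shuffled batches of up to 64 samples;
# validation and test are each served as a single batch
for X_batch, y_batch in train_dl:
    print(X_batch.shape, y_batch.shape)
    break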

Original link: https://stackoverflow.com//questions/71483593/apply-different-custom-functions-to-different-datasets

Replies

  • Phoenix replied:

    You can put both functions inside your class and call BINNED when you instantiate the 3 objects from MYDataset, while augment_data is called later, only through train_dataset.

    class MYDataset(Dataset):
        def __init__(self, path):
            df = read_csv(path, header=None, delimiter=r"\s+")
            df = df.iloc[:, 1:-1].values
            self.X = df[:, 1:]
            self.y = df[:, 0]
            self.y = self.y.reshape((len(self.y), 1))
            self.return_data = self.BINNED(df)   # <---- BINNED runs for every dataset instance

        def __len__(self):
            return len(self.X)

        def __getitem__(self, idx):
            return [self.X[idx], self.y[idx]]

        def get_splits(self, n_test=0.3):
            testing_size = round(n_test * len(self.X))
            train_size = len(self.X) - testing_size
            val_size = round(testing_size / 2)
            test_size = testing_size - val_size
            return random_split(self, [train_size, val_size, test_size])

        def BINNED(self, dataframe):  # <---- now a method of the dataset
            ...
            return data_ohe, 0, 1

        def augment_data(self, array):  # <---- now a method of the dataset
            ...
            return gen_data


    # One dataset object per split; BINNED is applied to all three in __init__
    train_dataset = MYDataset("path to train dataset")
    valid_dataset = MYDataset("path valid dataset")
    test_dataset = MYDataset("path test dataset")

    # augment_data is called explicitly, and only on the training dataset
    augmntd_train = train_dataset.augment_data(array)  # <----
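
    If you would rather keep a single data file and the get_splits workflow from the question, one possible pattern (a minimal sketch, assuming augment_data can transform one sample's feature array at a time) is to wrap only the training Subset in a small transforming Dataset, so augment_data never touches the validation and test subsets:

    from torch.utils.data import Dataset, DataLoader

    class TransformedSubset(Dataset):
        """Wraps a Subset and applies a transform to each item's features."""
        def __init__(self, subset, transform):
            self.subset = subset
            self.transform = transform

        def __len__(self):
            return len(self.subset)

        def __getitem__(self, idx):
            X, y = self.subset[idx]
            return [self.transform(X), y]

    dataset = MYDataset("path to the single data file")  # BINNED already ran in __init__
    train, val, test = dataset.get_splits()

    # Only the training subset is wrapped, so augment_data is applied per sample there
    train_dl = DataLoader(TransformedSubset(train, dataset.augment_data),
                          batch_size=64, shuffle=True)
    val_dl = DataLoader(val, batch_size=len(val), shuffle=False)
    test_dl = DataLoader(test, batch_size=len(test), shuffle=False)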
    