将不同的自定义函数应用于不同的数据集
pytorch 408
原文标题 :Apply different custom functions to different datasets
给定以下数据集类:
class MYDataset(Dataset):
    """CSV-backed dataset yielding ``[features, target]`` pairs.

    The whitespace-delimited, header-less file at *path* is read once in
    ``__init__``; ``get_splits`` partitions the dataset into
    train/validation/test subsets via ``random_split``.
    """

    def __init__(self, path):
        # Read the raw table, then drop the first and last columns.
        frame = read_csv(path, header=None, delimiter=r"\s+")
        values = frame.iloc[:, 1:-1].values
        self.X = values[:, 1:]   # feature columns (everything after column 0)
        self.y = values[:, 0]    # target column
        # Reshape the target into an (n, 1) column vector.
        self.y = self.y.reshape((len(self.y), 1))

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return [self.X[idx], self.y[idx]]

    def get_splits(self, n_test=0.3):
        """Split into (train, val, test) subsets.

        *n_test* is the held-out fraction; it is halved between the
        validation and test subsets (70/15/15 by default).
        """
        held_out = round(n_test * len(self.X))
        n_train = len(self.X) - held_out
        n_val = round(held_out / 2)
        return random_split(self, [n_train, n_val, held_out - n_val])
def prepare_data(path):
    """Build train/val/test ``DataLoader``s for the CSV file at *path*.

    The train loader shuffles with batch size 64; validation and test are
    each served as one full, unshuffled batch.
    """
    dataset = MYDataset(path)
    train_set, val_set, test_set = dataset.get_splits()
    return (
        DataLoader(train_set, batch_size=64, shuffle=True),
        DataLoader(val_set, batch_size=len(val_set), shuffle=False),
        DataLoader(test_set, batch_size=len(test_set), shuffle=False),
    )
以及下面另外 2 个自定义函数:
def augment_data(array):
    """Return augmented samples built from *array* (body elided in the post)."""
    ...
    return gen_data
def BINNED(dataframe):
    """Bin *dataframe* — presumably one-hot encoding, judging by the name
    ``data_ohe`` — and return ``(data_ohe, 0, 1)`` (body elided in the post)."""
    ...
    return data_ohe, 0, 1
如何将函数 BINNED 应用于训练、验证和测试三个子集,而将函数 augment_data 仅应用于训练子集?
回复
我来回复-
Phoenix 评论
您可以将这两个函数作为方法放进类里:从 MYDataset 实例化 3 个对象时,BINNED 会在 __init__ 中被调用;而 augment_data
稍后通过train_dataset调用它。class MYDataset(Dataset): def __init__(self, path): df = read_csv(path, header=None, delimiter=r"\s+") df = df.iloc[:, 1:-1].values self.X = df[:, 1:] self.y = df[:, 0] self.y = self.y.reshape((len(self.y), 1)) self.return_data = self.BINNED(df) # <---- def __len__(self): return len(self.X) def __getitem__(self, idx): return [self.X[idx], self.y[idx]] def get_splits(self, n_test=0.3): testing_size = round(n_test * len(self.X)) train_size = len(self.X) - testing_size val_size = round(testing_size /2) test_size = testing_size - val_size return random_split(self, [train_size, val_size, test_size]) def BINNED(self, dataframe): # <---- ... return(data_ohe, 0,1) def augment_data(self, array): # <---- ... return(gen_data) train_dataset = MYDataset("path to train dataset") valid_dataset = MYDataset("path valid dataset") test_dataset = MYDataset("path test dataset") augmntd_train = train_dataset.augment_data(array) # <----
2年前