在 Pytorch 中,我如何对 DataLoader 进行洗牌?
pytorch 535
原文标题 :In Pytorch, how can i shuffle a DataLoader?
我有一个包含 10000 个样本的数据集,其中类以有序的方式存在。首先,我将数据加载到 ImageFolder 中,然后加载到 DataLoader 中,我想将此数据集拆分为 train-val-test 集。我知道 DataLoader 类有一个 shuffle 参数,但这对我不利,因为它只会在枚举发生时对数据进行随机播放。我知道 RandomSampler 函数,但是有了它,我只能从数据集中随机抽取 n 个数据,而且我无法控制要取出的数据,因此 train、test 和 val 集中可能存在一个样本同时。
有没有办法对 DataLoader 中的数据进行洗牌?我唯一需要的是洗牌,之后我可以对数据进行子集化。
回复
我来回复-
Umang Gupta 评论
该回答已被采纳!
Subset
dataset 类采用索引(https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset)。您可能可以利用它来获得此功能,如下所示。基本上,您可以逃脱通过改组索引然后选择数据集的子集。# suppose dataset is the variable pointing to whole datasets N = len(dataset) # generate & shuffle indices indices = numpy.arange(N) indices = numpy.random.permutation(indices) # there are many ways to do the above two operation. (Example, using np.random.choice can be used here too # select train/test/val, for demo I am using 70,15,15 train_indices = indices [:int(0.7*N)] val_indices = indices[int(0.7*N):int(0.85*N)] test_indices = indices[int(0.85*N):] train_dataset = Subset(dataset, train_indices) val_dataset = Subset(dataset, val_indices) test_dataset = Subset(dataset, test_indices)
2年前