在 Pytorch 中,我如何对 DataLoader 进行洗牌?

xiaoxingxing pytorch 535

原文标题In Pytorch, how can i shuffle a DataLoader?

我有一个包含 10000 个样本的数据集,其中类以有序的方式存在。首先,我将数据加载到 ImageFolder 中,然后加载到 DataLoader 中,我想将此数据集拆分为 train-val-test 集。我知道 DataLoader 类有一个 shuffle 参数,但这对我不利,因为它只会在枚举发生时对数据进行随机播放。我知道 RandomSampler 函数,但是有了它,我只能从数据集中随机抽取 n 个数据,而且我无法控制要取出的数据,因此 train、test 和 val 集中可能存在一个样本同时。

有没有办法对 DataLoader 中的数据进行洗牌?我唯一需要的是洗牌,之后我可以对数据进行子集化。

原文链接:https://stackoverflow.com//questions/71576668/in-pytorch-how-can-i-shuffle-a-dataloader

回复

我来回复
  • Umang Gupta的头像
    Umang Gupta 评论

    Subsetdataset 类采用索引(https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset)。您可能可以利用它来获得此功能,如下所示。基本上,您可以逃脱通过改组索引然后选择数据集的子集。

    # suppose dataset is the variable pointing to whole datasets
    N = len(dataset)
    
    # generate & shuffle indices
    indices = numpy.arange(N)
    indices = numpy.random.permutation(indices)
    # there are many ways to do the above two operation. (Example, using np.random.choice can be used here too
    
    # select train/test/val, for demo I am using 70,15,15
    train_indices = indices [:int(0.7*N)]
    val_indices = indices[int(0.7*N):int(0.85*N)]
    test_indices = indices[int(0.85*N):]
    
    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)
    test_dataset = Subset(dataset, test_indices)
    
    2年前 0条评论