如何从pytorch中的多个数据集中加载数据

xiaoxingxing pytorch 551

原文标题How to load data from multiply datasets in pytorch

我有两个图像数据集 – 室内和室外,它们没有相同数量的示例。

每个数据集的图像包含一定数量的类(最少1个最多4个),这些类可以出现在两个数据集中,每个类有4个类别——红色、蓝色、绿色、白色。示例:室内——猫、狗、马户外- 狗,人类

我正在尝试训练一个模型,我告诉它,“这是一张包含猫的图像,告诉我它的颜色”,无论它是在哪里拍摄的(室内、室外、车内、月球上)

为此,我需要展示我的模型示例,以便每个批次只有一个类别(猫、狗、马或人),但我想从包含这些对象的所有数据集(在本例中为两个)中采样并将它们混合.我怎样才能做到这一点?

它必须考虑到每个数据集中的示例数量不同,并且某些类别出现在一个数据集中,而其他类别可以出现在多个数据集中。并且每个批次必须只包含一个类别。

我将不胜感激,我已经尝试解决这个问题几天了。

原文链接:https://stackoverflow.com//questions/71422639/how-to-load-data-from-multiply-datasets-in-pytorch

回复

我来回复
  • Matthew R.的头像
    Matthew R. 评论

    假设问题是:

    1. 将 2+ 个数据集与可能重叠的对象类别相结合(可通过标签区分)
    2. 每个对象对每种颜色都有 4 个“子类别”(可通过​​标签区分)
    3. 每个批次应该只包含一个对象类别

    第一步将是确保来自两个数据集的对象标签的一致性,如果不是已经一致的话。例如,如果狗类别在第一个数据集中的标签为 0,而在第二个数据集中的标签为 2,那么我们需要确保这两个狗类别正确合并。我们可以用一个简单的数据集包装器来做这个“翻译”:

    class TranslatedDataset(Dataset):
      """
      Args:
        dataset: The original dataset.
        translate_label: A lambda (function) that maps the original
          dataset label to the label it should have in the combined data set
      """
      def __init__(self, dataset, translate_label):
        super().__init__()
        self._dataset = dataset
        self._translate_label = translate_label
    
      def __len__(self):
        return len(self._dataset)
    
      def __getitem__(self, idx):
        inputs, target = self._dataset[idx]
        return inputs, self._translate_label(target)
    

    下一步是将翻译后的数据集组合在一起,这可以使用 ConcatDataset 轻松完成:

    first_original_dataset = ...
    second_original_dataset = ...
    
    first_translated = TranslateDataset(
      first_original_dataset, 
      lambda y: 0 if y is 2 else 2 if y is 0 else y, # or similar
    )
    second_translated = TranslateDataset(
      second_original_dataset, 
      lambda y: y, # or similar
    )
    
    combined = ConcatDataset([first_translated, second_translated])
    

    最后,我们需要将批量采样限制在同一个类中,这可以在创建数据加载器时使用自定义采样器来实现。

    class SingleClassSampler(torch.utils.data.Sampler):
      def __init__(self, dataset, batch_size):
        super().__init__()
        # We need to create sequential groups
        # with batch_size elements from the same class
        indices_for_target = {} # dict to store a list of indices for each target
        
        for i, (_, target) in enumerate(dataset):
          # converting to string since Tensors hash by reference, not value
          str_targ = str(target)
          if str_targ not in indices_for_target:
            indices_for_target[str_targ] = []
          indices_for_target[str_targ] += [i]
    
        # make sure we have a whole number of batches for each class
        trimmed = { 
          k: v[:-(len(v) % batch_size)] 
          for k, v in indices_for_target.items()
        }
    
        # concatenate the lists of indices for each class
        self._indices = sum(list(trimmed.values()))
      
      def __len__(self):
        return len(self._indices)
    
      def __iter__(self):
        yield from self._indices
    

    然后使用采样器:

    loader = DataLoader(
      combined, 
      sampler=SingleClassSampler(combined, 64), 
      batch_size=64, 
      shuffle=True
    )
    

    我没有运行这段代码,所以它可能并不完全正确,但希望它能让你走上正确的轨道。


    torch.utils.data Docs

    2年前 0条评论