Correct use of a custom Sampler in PyTorch


Original title: Custom Sampler correct use in Pytorch

I have a map-style dataset that I use for an instance segmentation task. The dataset is highly imbalanced: some images contain only 10 objects, while others have up to 1200.

How can I limit the number of objects per batch?

A minimal reproducible example is:

import math
import torch
import random
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler


np.random.seed(0)
random.seed(0)
torch.manual_seed(0)


W = 700
H = 1000

def collate_fn(batch) -> tuple:
    return tuple(zip(*batch))

class SyntheticDataset(Dataset):
    def __init__(self, image_ids):
        self.image_ids = torch.tensor(image_ids, dtype=torch.int64)
        self.num_classes = 9

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx: int):
        """
            returns single sample
        """
        # print("idx: ", idx)

        # deliberately left dangling
        # id = self.image_ids[idx].item()
        # image_id = self.image_ids[idx]
        image_id = torch.as_tensor(idx)
        image = torch.randint(0, 255, (H, W))

        num_objects = random.randint(10, 1200)
        image = torch.randint(0, 255, (3, H, W))
        masks = torch.randint(0, 255, (num_objects, H, W))

        target = {}
        target["image_id"] = image_id

        areas = torch.randint(100, 20000, (1, num_objects), dtype=torch.int64)
        boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
        labels = torch.randint(1, self.num_classes, (1, num_objects), dtype=torch.int64)
        iscrowd = torch.zeros(len(labels), dtype=torch.int64)

        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = areas
        target["iscrowd"] = iscrowd
        target["masks"] = masks

        return image, target, image_id


class BalancedObjectsSampler(BatchSampler):
    """Samples either batch_size images or batches num_objs_per_batch objects.

    Args:
        data_source (list): contains tuples of (img_id).
        batch_size (int): batch size.
        num_objs_per_batch (int): number of objects in a batch.
    Return
        yields the batch_ids/image_ids/image_indices

    """

    def __init__(self, data_source, batch_size, num_objs_per_batch, drop_last=False):
        self.data_source = data_source
        self.sampler = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_objs_per_batch = num_objs_per_batch
        self.batch_count = math.ceil(len(self.data_source) / self.batch_size)

    def __iter__(self):

        obj_count = 0
        batch = []
        batches = []
        counter = 0
        for i, (k, s) in enumerate(self.data_source.iteritems()):
            if (
                obj_count <= obj_count + s
                and len(batch) <= self.batch_size - 1
                and obj_count + s <= self.num_objs_per_batch
                and i < len(self.data_source) - 1
            ):
                # because of https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler
                batch.append(i)
                obj_count += s
            else:
                batches.append(batch)
                yield batch
                obj_count = 0
                batch = []
            counter += 1


obj_sums = {}
batch_size = 10
workers = 4
fake_image_ids = np.random.randint(1600000, 1700000, 100)

# assign a random in-range object count to each image
for i, k in enumerate(fake_image_ids):
    obj_sums[k] = random.randint(10, 1200)

obj_counts = pd.Series(obj_sums)

train_dataset = SyntheticDataset(image_ids=fake_image_ids)

balanced_sampler = BalancedObjectsSampler(
    data_source=obj_counts,
    batch_size=batch_size,
    num_objs_per_batch=1500,
    drop_last=False,
)

data_loader_sampler = torch.utils.data.DataLoader(
    train_dataset,
    num_workers=workers,
    collate_fn=collate_fn,
    sampler=balanced_sampler,
)

data_loader_iter = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
    collate_fn=collate_fn,
)

Iterating over the balanced_sampler

for i, bal_batch in enumerate(balanced_sampler):
    print(f"batch_{i}: ", bal_batch)

yields

batch_0:  [0]
batch_1:  [2, 3]
batch_2:  [5]
batch_3:  [7]
batch_4:  [9, 10]
batch_5:  [12, 13, 14, 15]
batch_6:  [17, 18]
batch_7:  [20, 21, 22]
batch_8:  [24, 25]
batch_9:  [27]
batch_10:  [29]
batch_11:  [31]
batch_12:  [33]
batch_13:  [35, 36, 37]
batch_14:  [39, 40]
batch_15:  [42, 43]
batch_16:  [45, 46]
batch_17:  [48, 49, 50]
batch_18:  [52, 53, 54]
batch_19:  [56]
batch_20:  [58, 59]
batch_21:  [61, 62]
batch_22:  [64]
batch_23:  [66]
batch_24:  [68]
batch_25:  [70, 71]
batch_26:  [73]
batch_27:  [75, 76, 77]
batch_28:  [79, 80]
batch_29:  [82, 83, 84, 85, 86, 87]
batch_30:  [89]
batch_31:  [91]
batch_32:  [93, 94]
batch_33:  [96]
batch_34:  [98]

The values shown above are image indices, but they could just as well be batch indices or even image ids.

By running

for i, batch in enumerate(data_loader_sampler):
    print("__sample__: ", i, len(batch[0]))

one can see that each batch contains a single sample rather than the expected number of samples.

__sample__:  0 1
__sample__:  1 1
__sample__:  2 1
__sample__:  3 1
__sample__:  4 1
__sample__:  5 1
__sample__:  6 1
__sample__:  7 1
__sample__:  8 1
__sample__:  9 1
__sample__:  10 1
__sample__:  11 1
__sample__:  12 1
__sample__:  13 1
__sample__:  14 1
__sample__:  15 1
__sample__:  16 1
__sample__:  17 1
__sample__:  18 1
__sample__:  19 1
__sample__:  20 1
__sample__:  21 1
__sample__:  22 1
__sample__:  23 1
__sample__:  24 1
__sample__:  25 1
__sample__:  26 1
__sample__:  27 1
__sample__:  28 1
__sample__:  29 1
__sample__:  30 1
__sample__:  31 1
__sample__:  32 1
__sample__:  33 1
__sample__:  34 1
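
For comparison: when a BatchSampler is passed through the DataLoader's sampler argument, the default batch_size=1 wraps it once more, so each list it yields is passed to __getitem__ as a single index and comes back as a batch of one sample, which matches the output above. The usual way to make each yielded list become a whole batch is the batch_sampler argument (mutually exclusive with batch_size, shuffle, sampler, and drop_last). A minimal sketch of that wiring, reusing the objects defined above (data_loader_batched is just an illustrative name):

data_loader_batched = torch.utils.data.DataLoader(
    train_dataset,
    batch_sampler=balanced_sampler,  # each yielded index list becomes one full batch
    num_workers=workers,
    collate_fn=collate_fn,
)

for i, batch in enumerate(data_loader_batched):
    print("__batch_sampler__: ", i, len(batch[0]))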

What I really want to prevent is the following behavior

for i, batch in enumerate(data_loader_iter):
    print("__iter__: ", i, sum([k["masks"].shape[0] for k in batch[1]]))

which is

__iter__:  0 2510
__iter__:  1 2060
__iter__:  2 2203
__iter__:  3 2815
ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/usr/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "/blip/venv/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 328, in reduce_storage
    fd, size = storage._share_fd_()
RuntimeError: falseINTERNAL ASSERT FAILED at "../aten/src/ATen/MapAllocator.cpp":300, please report a bug to PyTorch. unable to write to file </torch_431207_56>
Traceback (most recent call last):
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 990, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/usr/lib/python3.8/multiprocessing/queues.py", line 107, in get
    if not self._poll(timeout):
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 424, in _poll
    r = wait([self], timeout)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 431257) is killed by signal: Bus error. It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "so.py", line 170, in <module>
    for i, batch in enumerate(data_loader_iter):
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1186, in _next_data
    idx, data = self._get_data()
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1152, in _get_data
    success, data = self._try_get_data()
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1003, in _try_get_data
    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 431257) exited unexpectedly

This always happens when the number of objects per batch is larger than roughly 2500.

An immediate workaround would be to set batch_size to a low value; what I am looking for is a more optimal solution.

Original link: https://stackoverflow.com//questions/71500629/custom-sampler-correct-use-in-pytorch

Replies

  • fmolina199 commented:

    If the problem you are actually trying to solve is:

    ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
    

    you can try resizing the allocated shared memory with

    # mount -o remount,size=<whatever_is_enough>G /dev/shm
    
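    (If the code runs inside a Docker container, the equivalent is usually to start the container with a larger --shm-size, since a container's /dev/shm defaults to only 64 MB.)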

    However, since that is not always possible, one workaround for your problem would be the following

    import gc  # needed for the gc.collect() calls below

    class SyntheticDataset(Dataset):
    
        def __init__(self, image_ids):
            self.image_ids = torch.tensor(image_ids, dtype=torch.int64)
            self.num_classes = 9
    
        def __len__(self):
            return len(self.image_ids)
    
        def __getitem__(self, indices):
            worker_info = torch.utils.data.get_worker_info()
    
            batch = []
            for i in indices:
                sample = self.get_sample(i)
                batch.append(sample)
            gc.collect()
            return batch
    
        def get_sample(self, idx: int):
    
            image_id = torch.as_tensor(idx)
            image = torch.randint(0, 255, (H, W))
    
            num_objects = idx
            image = torch.randint(0, 255, (3, H, W))
            masks = torch.randint(0, 255, (num_objects, H, W))
    
            target = {}
            target["image_id"] = image_id
    
            areas = torch.randint(100, 20000, (1, num_objects), dtype=torch.int64)
            boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
            labels = torch.randint(1, self.num_classes, (1, num_objects), dtype=torch.int64)
            iscrowd = torch.zeros(len(labels), dtype=torch.int64)
    
            target["boxes"] = boxes
            target["labels"] = labels
            target["area"] = areas
            target["iscrowd"] = iscrowd
            target["masks"] = masks
    
            return image, target, image_id
    
    

    class BalancedObjectsSampler(BatchSampler):
        """Samples either batch_size images or batches num_objs_per_batch objects.
    
        Args:
            data_source (list): contains tuples of (img_id).
            batch_size (int): batch size.
            num_objs_per_batch (int): number of objects in a batch.
        Return
            yields the batch_ids/image_ids/image_indices
    
        """
    
        def __init__(self, data_source, batch_size, num_objs_per_batch, drop_last=False):
            self.data_source = data_source
            self.sampler = data_source
            self.batch_size = batch_size
            self.drop_last = drop_last
            self.num_objs_per_batch = num_objs_per_batch
            self.batch_count = math.ceil(len(self.data_source) / self.batch_size)
    
            obj_count = 0
            batch = []
            batches = []
            batches_sums = []
            for i, (k, s) in enumerate(self.data_source.iteritems()):
    
                if (
                    len(batch) < self.batch_size
                    and obj_count + s < self.num_objs_per_batch
                    and i < len(self.data_source) - 1
                ):
                    batch.append(s)
                    obj_count += s
                else:
                    batches.append(len(batch))
                    batches_sums.append(obj_count)
                    obj_count = 0
                    batch = []
    
            self.batches = batches
            self.batch_count = len(batches)
    
        def __iter__(self):
            batch = []
            img_counts_id = 0
            for idx, (k, s) in enumerate(self.data_source.iteritems()):
                if len(batch) < self.batches[img_counts_id] and idx < len(self.data_source):
                    batch.append(s)
                elif len(batch) == self.batches[img_counts_id]:
                    gc.collect()
                    yield batch
                    batch = []
                    if img_counts_id < self.batch_count - 1:
                        img_counts_id += 1
                    else:
                        break
    
            if len(batch) > 0 and not self.drop_last:
                yield batch
    
        def __len__(self) -> int:
            if self.drop_last:
                return len(self.data_source) // self.batch_size
            else:
                return (len(self.data_source) + self.batch_size - 1) // self.batch_size
    

    Since SyntheticDataset's __getitem__ now receives a list of indices, the simplest solution is just to iterate over the indices and retrieve a list of samples. You may only have to collate the output differently in order to feed it to your model.

    For the BalancedObjectsSampler, I compute each batch's size in __init__ and use it in __iter__ to assemble the batches.
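
    For completeness, one possible way to wire the modified dataset and sampler together is to disable automatic batching with batch_size=None, so that every list yielded by the sampler is passed straight to __getitem__. A minimal sketch (variable names are illustrative):

    train_dataset = SyntheticDataset(image_ids=fake_image_ids)
    balanced_sampler = BalancedObjectsSampler(
        data_source=obj_counts,
        batch_size=batch_size,
        num_objs_per_batch=1500,
        drop_last=False,
    )

    loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=balanced_sampler,  # yields lists (here: the per-image object counts)
        batch_size=None,           # disable automatic batching; each list goes straight to __getitem__
        num_workers=0,             # see the note below about num_workers > 0
    )

    for i, batch in enumerate(loader):
        # each element of batch is an (image, target, image_id) tuple built by get_sample
        print("batch", i, "objects:", sum(t["masks"].shape[0] for _, t, _ in batch))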

    Note: this will still fail if, with num_workers > 0, you try to pack as many as 1500 objects into a single batch, since a worker typically loads one batch at a time. You must therefore re-evaluate your num_objs_per_batch when considering multiprocessing.

    2 years ago · 0 comments