Fixing the PyTorch DataLoader error: "Trying to resize storage that is not resizable"

TL;DR

If you hit this error, especially when code that normally runs fine starts failing only after switching datasets, the dataset itself is most likely at fault. Debug along that line.

Problem description

With num_workers set to any value greater than 0, the DataLoader raises the following error:

Traceback (most recent call last):
  File "/home/username/distort/main.py", line 131, in <module>
    model, perms, accs = train_model(dinfos, args.mid, args.pretrained, args.num_classes, args.treps, args.testep, args.test_dist, device, args.distort)
  File "/home/username/distort/main.py", line 65, in train_model
    for img, y in train_dataloader:
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
    data = self._next_data()
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1376, in _next_data
    return self._process_data(data)
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1402, in _process_data
    data.reraise()
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in default_collate
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in <listcomp>
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 140, in default_collate
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
RuntimeError: Trying to resize storage that is not resizable

With num_workers set to 0, a different error appears instead:

Traceback (most recent call last):
  File "/home/username/distort/main.py", line 130, in <module>
    model, perms, accs = train_model(dinfos, args.mid, args.pretrained, args.num_classes, args.treps, args.testep, args.test_dist, device, args.distort)
  File "/home/username/distort/main.py", line 64, in train_model
    for img, y in train_dataloader:
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
    data = self._next_data()
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 721, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in default_collate
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in <listcomp>
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 141, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [3, 64, 64] at entry 0 and [1, 64, 64] at entry 32
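The root cause of the second traceback can be reproduced in isolation: default_collate ends up calling torch.stack on the batch, and torch.stack refuses tensors of different shapes. A minimal sketch (the 3-channel and 1-channel shapes mirror the error message above):

```python
import torch

# default_collate stacks every sample in a batch into one tensor;
# torch.stack requires all tensors to share exactly the same shape.
rgb = torch.zeros(3, 64, 64)   # a normal 3-channel image tensor
gray = torch.zeros(1, 64, 64)  # a grayscale image hiding in the dataset

try:
    batch = torch.stack([rgb, gray], 0)
except RuntimeError as err:
    print(f"RuntimeError: {err}")
```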

Troubleshooting

The second error is fairly easy to track down. In the custom dataset class's __getitem__() method, add a check: whenever the loaded tensor's shape[0] is 1, print the path of the underlying data file.

This revealed that the dataset (I was using tiny-imagenet-200) really does contain single-channel images; I had not expected the dataset itself to be the culprit.
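As a sketch of that debugging step (the wrapper class and its attribute names are hypothetical; the same check can go directly inside any custom Dataset that keeps a list of file paths):

```python
import torch
from torch.utils.data import Dataset

class ChannelDebugDataset(Dataset):
    """Hypothetical wrapper that reports samples whose image tensor
    does not have 3 channels, together with the source file path."""

    def __init__(self, base, paths):
        self.base = base    # the original dataset, returns (tensor, label)
        self.paths = paths  # file path of each sample, same order as base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, label = self.base[idx]
        if img.shape[0] == 1:  # single-channel (grayscale) image found
            print(f"1-channel image: {self.paths[idx]}")
        return img, label
```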

Solution

In __getitem__(), use the tensor's expand method: for any tensor with the wrong channel count, call expand(3, -1, -1). After that, the dataset loads correctly whether num_workers is 0 or any positive number.
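A minimal sketch of that fix (the helper name is mine; expand() broadcasts the single channel to three without copying memory):

```python
import torch

def ensure_rgb(img: torch.Tensor) -> torch.Tensor:
    """Expand a 1-channel CHW image tensor to 3 channels.
    expand() returns a broadcast view, so no extra memory is allocated."""
    if img.shape[0] == 1:
        img = img.expand(3, -1, -1)
    return img
```

Calling this on the tensor at the end of __getitem__() makes every sample the same shape, so default_collate can stack the batch.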

One more note: some blog posts claim num_workers must match the number of GPU cores, which makes no sense. The first traceback above shows the failure has nothing to do with the CUDA libraries, so it cannot be a GPU issue. At least with the usual way of loading datasets, num_workers simply sets the maximum number of CPU worker processes the DataLoader spawns.
