如何覆盖我的 _getitem_ 函数?
pytorch 263
原文标题 :How can I overwrite my _getitem_ function?
现在我必须从我的训练集中读取 2 张图像作为我的 resnet-34 模型的输入。但是原始的 Dataset 和 DataLoader 无法完成我的任务。我必须自定义我的 Dataset 类,但我不知道如何覆盖它的功能。
我的数据集看起来像 ‘data_set/train/img1,img2,img3………’,我需要 2 2 获取图像,返回 2 图像及其标签(如 ‘C01’),如何制作我的数据集?
我试图创建一个这样的但失败了。
class MyDataSet(Dataset):
"""customize my dataset"""
def __init__(self):
self.images_path = None
def __len__(self):
return len(self.images_path)
def __getitem__(self, item):
img1 = Image.open(self.images_path[2*item])
img2 = Image.open(self.images_path[2*item+1])
label = self.images_class[item] # I don't know how to return the label of 2 images.
if self.transform is not None:
img1 = self.transform(img1)
img2 = self.transform(img2)
return img1,img2, label
回复
我来回复-
ki-ljl 评论
看完你的评论,我明白你的意思了:
class MyDataset(Dataset): def __init__(self): self.images_path = None def __len__(self): return len(self.images_path) def __getitem__(self, item): img1 = Image.open(self.images_path[2*item]) img2 = Image.open(self.images_path[2*item+1]) label = self.images_class[2*item] # for example, label = xxx1(and xxx2), you need return xxx # str label = label[:len(label) - 1] # return xxx if self.transform is not None: img1 = self.transform(img1) img2 = self.transform(img2) # we need to concatenate the two images img = torch.cat((img1, img2), dim=0) return img, label
例如,对于11和12这两张图片,经过处理后,会返回一个通道号为6的图片(11和12拼接在一起),这个图片的label是1(如果是21和22,则返回图像标签 2)。
2年前