如何覆盖我的 _getitem_ 函数?

乘风 pytorch 202

原文标题How can I overwrite my _getitem_ function?

现在我必须从我的训练集中读取 2 张图像作为我的 resnet-34 模型的输入。但是原始的 Dataset 和 DataLoader 无法完成我的任务。我必须自定义我的 Dataset 类,但我不知道如何覆盖它的功能。

我的数据集看起来像 ‘data_set/train/img1,img2,img3………’,我需要 2 2 获取图像,返回 2 图像及其标签(如 ‘C01’),如何制作我的数据集?

我试图创建一个这样的但失败了。

class MyDataSet(Dataset):
"""customize my dataset"""

def __init__(self):
    self.images_path = None

def __len__(self):
    return len(self.images_path)

def __getitem__(self, item):
    img1 = Image.open(self.images_path[2*item])
    img2 = Image.open(self.images_path[2*item+1])
    label = self.images_class[item] # I don't know how to return the label of 2 images.

    if self.transform is not None:
        img1 = self.transform(img1)
        img2 = self.transform(img2)

    return img1,img2, label

enter image description here

原文链接:https://stackoverflow.com//questions/71887496/how-can-i-overwrite-my-getitem-function

回复

我来回复
  • ki-ljl的头像
    ki-ljl 评论

    看完你的评论,我明白你的意思了:

    class MyDataset(Dataset):
        def __init__(self):
            self.images_path = None
    
        def __len__(self):
            return len(self.images_path)
    
        def __getitem__(self, item):
            img1 = Image.open(self.images_path[2*item])
            img2 = Image.open(self.images_path[2*item+1])
            label = self.images_class[2*item]
            # for example, label = xxx1(and xxx2), you need return xxx
            # str
            label = label[:len(label) - 1]  # return xxx
            if self.transform is not None:
                img1 = self.transform(img1)
                img2 = self.transform(img2)
            # we need to concatenate the two images
            img = torch.cat((img1, img2), dim=0)
            return img, label
    

    例如,对于11和12这两张图片,经过处理后,会返回一个通道号为6的图片(11和12拼接在一起),这个图片的label是1(如果是21和22,则返回图像标签 2)。

    2年前 0条评论