在 Pytorch 中将多类图像分类简化为二进制分类

社会演员多 pytorch 207

原文标题Reduce multiclass image classification to binary classification in Pytorch

我正在研究一个由 10 个不同类组成的 stl-10 图像数据集。我想将这个多类图像分类问题简化为二元类图像分类,例如 1 类 Vs 休息。我正在使用 PyTorch torchvision 下载和使用 stl 数据,但我无法像其他人一样做到这一点。

train_data=torchvision.datasets.STL10(root='data',split='train',transform=data_transforms['train'], download=True)
test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True)

train_dataloader = DataLoader(train_data,batch_size = 64,shuffle=True,num_workers=2)
test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)

原文链接:https://stackoverflow.com//questions/71889622/reduce-multiclass-image-classification-to-binary-classification-in-pytorch

回复

我来回复
  • ki-ljl的头像
    ki-ljl 评论

    您需要重新标记图像。最开始class 0对应label 0,class 1对应label 1,…,class 10对应label 9。如果要实现二分类,需要更改类别1的图片的label (或其他)为0,所有其他类别的图片为1。

    2年前 0条评论
  • asymptote的头像
    asymptote 评论

    一种方法是在运行时更新标签值,然后再将它们传递给训练循环中的损失函数。假设我们要将类 5 重新标记为 1,其余的标记为 0:

    my_class_id = 5
    for imgs, labels in train_dataloader:
        labels = torch.where(labels == my_class_id, 1, 0)
        ...
    

    您可能还需要对 test_dataloader 进行类似的重新标记。另外,我不确定labels的数据类型。如果它是浮动的,请相应地更改。

    2年前 0条评论
  • Balagopal Unnikrishnan的头像
    Balagopal Unnikrishnan 评论

    对于 torchvision 数据集,有一种内置的方法可以做到这一点。您需要定义一个转换函数或类,并在创建数据集时将其添加到target_transform中。

    torchvision.datasets.STL10(root: str, split: str = 'train', folds: Union[int, NoneType] = None, transform: Union[Callable, NoneType] = None, target_transform: Union[Callable, NoneType] = None, download: bool = False)
    

    这是一个可供参考的工作示例:

    
    import torchvision
    from torch.utils.data import DataLoader
    from torchvision import transforms
    
    
    class Multi2UniLabelTfm():
        def __init__(self,pos_label=5):
            if isinstance(pos_label,int) or isinstance(pos_label,float):
                pos_label = [pos_label,]
            self.pos_label = pos_label
    
        def __call__(self,y):
            # if y==self.pos_label:
            if y in self.pos_label:
                return 1
            else:
                return 0
    
    if __name__=='__main__':
    
        test_tfms = transforms.Compose([
            transforms.ToTensor()
        ])
        data_transforms = {'val':test_tfms}
    
    
        #Original Labels
        # target_transform = None   
    
        # Label 5 is converted to 1. Rest are 0.
        # target_transform = Multi2UniLabelTfm(pos_label=5)     
    
        # Labels 5,6,7 are converted to 1. Rest are 0.
        target_transform = Multi2UniLabelTfm(pos_label=[5,6,7])
        test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True, target_transform=target_transform)
        test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)
    
        for idx,(x,y) in enumerate(test_dataloader):
            print(idx,y)
    
            if idx == 5:
                break
    
    2年前 0条评论