在 Pytorch 中将多类图像分类简化为二进制分类
pytorch 207
原文标题 :Reduce multiclass image classification to binary classification in Pytorch
我正在研究一个由 10 个不同类组成的 stl-10 图像数据集。我想将这个多类图像分类问题简化为二元类图像分类,例如 1 类 Vs 休息。我正在使用 PyTorch torchvision 下载和使用 stl 数据,但我无法像其他人一样做到这一点。
train_data=torchvision.datasets.STL10(root='data',split='train',transform=data_transforms['train'], download=True)
test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True)
train_dataloader = DataLoader(train_data,batch_size = 64,shuffle=True,num_workers=2)
test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)
回复
我来回复-
ki-ljl 评论
您需要重新标记图像。最开始class 0对应label 0,class 1对应label 1,…,class 10对应label 9。如果要实现二分类,需要更改类别1的图片的label (或其他)为0,所有其他类别的图片为1。
2年前 -
asymptote 评论
一种方法是在运行时更新标签值,然后再将它们传递给训练循环中的损失函数。假设我们要将类 5 重新标记为 1,其余的标记为 0:
my_class_id = 5 for imgs, labels in train_dataloader: labels = torch.where(labels == my_class_id, 1, 0) ...
您可能还需要对 test_dataloader 进行类似的重新标记。另外,我不确定
labels
的数据类型。如果它是浮动的,请相应地更改。2年前 -
Balagopal Unnikrishnan 评论
对于 torchvision 数据集,有一种内置的方法可以做到这一点。您需要定义一个转换函数或类,并在创建数据集时将其添加到
target_transform
中。torchvision.datasets.STL10(root: str, split: str = 'train', folds: Union[int, NoneType] = None, transform: Union[Callable, NoneType] = None, target_transform: Union[Callable, NoneType] = None, download: bool = False)
这是一个可供参考的工作示例:
import torchvision from torch.utils.data import DataLoader from torchvision import transforms class Multi2UniLabelTfm(): def __init__(self,pos_label=5): if isinstance(pos_label,int) or isinstance(pos_label,float): pos_label = [pos_label,] self.pos_label = pos_label def __call__(self,y): # if y==self.pos_label: if y in self.pos_label: return 1 else: return 0 if __name__=='__main__': test_tfms = transforms.Compose([ transforms.ToTensor() ]) data_transforms = {'val':test_tfms} #Original Labels # target_transform = None # Label 5 is converted to 1. Rest are 0. # target_transform = Multi2UniLabelTfm(pos_label=5) # Labels 5,6,7 are converted to 1. Rest are 0. target_transform = Multi2UniLabelTfm(pos_label=[5,6,7]) test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True, target_transform=target_transform) test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2) for idx,(x,y) in enumerate(test_dataloader): print(idx,y) if idx == 5: break
2年前