如何在图像分类框架中使用 Albumentations 添加数据增强?

青葱年少 pytorch 202

原文标题:How to add data augmentation with albumentation to image classification framework?

我正在使用 PyTorch 和来自 GitHub 的一份代码进行图像分类。在训练模型之前需要加入数据增强,我选择用 Albumentations 来实现。下面是加入 Albumentations 之后的代码:

# Albumentations pipelines for each data split ("train" / "val").
# An A.Compose object has to be called with keyword arguments, e.g.
# pipeline(image=array)["image"]; handing it to a dataset that calls
# transform(img) positionally is what triggers the KeyError quoted below.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training: random crop/flip plus photometric and mild geometric jitter.
train_steps = [
    A.RandomResizedCrop(224, 224),
    A.HorizontalFlip(p=0.5),
    A.RandomGamma(gamma_limit=(80, 120), eps=None, always_apply=False, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
    A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
    A.Normalize(imagenet_mean, imagenet_std),
    ToTensorV2(),
]

# Validation: deterministic resize + center crop only.
val_steps = [
    A.Resize(256, 256),
    A.CenterCrop(224, 224),
    A.Normalize(imagenet_mean, imagenet_std),
    ToTensorV2(),
]

data_transform = {
    "train": A.Compose(train_steps),
    "val": A.Compose(val_steps),
}

我收到了这个错误:

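KeyError: Caught KeyError in DataLoader worker process 0.(在 DataLoader 工作进程 0 中捕获 KeyError。)

KeyError: 'You have to pass data to augmentations as named arguments, for example: aug(image=image)'(您必须将数据作为命名参数传递给增强,例如:aug(image=image))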

原文链接:https://stackoverflow.com/questions/71476099/how-to-add-data-augmentation-with-albumentation-to-image-classification-framewor

回复

我来回复
  • Max D.的头像
    Max D. 评论

    Albumentations 的变换(A.Compose 对象)必须以关键字参数 image=... 调用,不能按位置传入图像;它返回一个字典,变换后的图像位于 "image" 键下。使用示例如下:

    # Albumentations pipeline; every step operates on an HxWxC NumPy image.
    transforms = A.Compose([
        A.Rotate(limit=15, p=0.5),  # public alias of A.augmentations.geometric.rotate.Rotate
        A.Perspective(scale=[0, 0.1], keep_size=False, fit_output=False, p=1),
        A.Resize(224, 224),
        A.HorizontalFlip(p=0.5),
        A.GaussNoise(var_limit=(10.0, 50.0), mean=0),
        A.RandomToneCurve(scale=0.5, p=1),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225]),
        ToTensorV2(),
    ])

    img = cv2.imread("dog.png")
    # cv2.imread signals a missing/unreadable file by returning None instead of
    # raising; without this check cvtColor fails with an opaque assertion.
    if img is None:
        raise FileNotFoundError("could not read image: dog.png")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR
    # The pipeline must be called with the keyword argument image=...;
    # it returns a dict and the transformed image is stored under "image".
    transformed_img = transforms(image=img)["image"]
    
    2年前 0条评论
  • I'mahdi的头像
    I'mahdi 评论

    可以编写一个自定义的 Dataset 类,在 __getitem__ 中自行读取图像并以 image= 关键字参数调用增强,从而避免上述错误:

    import albumentations as A
    import cv2
    from albumentations.pytorch import ToTensorV2  # used below; was not imported
    from torch.utils.data import Dataset  # base class; was not imported


    class ImageDataset(Dataset):
        """Map-style dataset that reads images with OpenCV and applies an Albumentations pipeline.

        Albumentations pipelines must be called with keyword arguments
        (``transform(image=...)``), so this dataset performs that call itself
        instead of relying on torchvision's positional ``transform(img)``.

        Args:
            images_filepaths: sequence of image file paths.
            transform: optional Albumentations pipeline (e.g. ``A.Compose``).
            labels: optional sequence of class ids aligned with
                ``images_filepaths``. When given, ``__getitem__`` returns
                ``(image, label)`` as a classification loop expects; when
                omitted it returns only the image.

        Raises:
            ValueError: if ``labels`` and ``images_filepaths`` differ in length.
        """

        def __init__(self, images_filepaths, transform=None, labels=None):
            if labels is not None and len(labels) != len(images_filepaths):
                raise ValueError("labels must have the same length as images_filepaths")
            self.images_filepaths = images_filepaths
            self.transform = transform
            self.labels = labels

        def __len__(self):
            return len(self.images_filepaths)

        def __getitem__(self, idx):
            image_filepath = self.images_filepaths[idx]
            image = cv2.imread(image_filepath)
            # cv2.imread returns None (it does not raise) when a file cannot be read.
            if image is None:
                raise FileNotFoundError(f"could not read image: {image_filepath}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR
            if self.transform is not None:
                image = self.transform(image=image)["image"]
            if self.labels is None:
                return image
            return image, self.labels[idx]


    train_transform = A.Compose([
        A.RandomResizedCrop(224, 224),
        A.HorizontalFlip(p=0.5),
        A.RandomGamma(gamma_limit=(80, 120), eps=None, always_apply=False, p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
        A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
        A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])


    val_transform = A.Compose([
        A.Resize(256, 256),
        A.CenterCrop(224, 224),
        A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    # For classification also pass labels=<class id per file> so that the
    # DataLoader yields (image, label) batches.
    train_dataset = ImageDataset(images_filepaths=train_images_filepaths, transform=train_transform)
    val_dataset = ImageDataset(images_filepaths=val_images_filepaths, transform=val_transform)
    
    2年前 0条评论
  • Saeedeh Alebooyeh的头像
    Saeedeh Alebooyeh 评论

    我是否正确使用了您的建议?我的数据集包含好、坏两类图像(水下图像)。

    import json
    import os
    import sys

    import albumentations as A
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from albumentations.pytorch import ToTensorV2
    from torchvision import datasets
    from tqdm import tqdm

    from model import resnet34

    # Channel mean/std (RGB order) used by both pipelines.
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]


    class AlbumentationsTransform:
        """Adapt an Albumentations pipeline to torchvision's ``transform(img)`` call.

        ``datasets.ImageFolder`` calls ``transform(pil_image)`` positionally, but an
        ``A.Compose`` pipeline only accepts named arguments (``pipeline(image=array)``).
        That mismatch is what raised "You have to pass data to augmentations as
        named arguments". This wrapper converts the PIL image to a NumPy array,
        calls the pipeline by keyword and returns the transformed image.

        It is a module-level class rather than a lambda so it can be pickled when
        DataLoader worker processes are started with ``spawn``.
        """

        def __init__(self, pipeline):
            self.pipeline = pipeline

        def __call__(self, img):
            # np.array (not asarray) gives a writable HxWxC uint8 copy of the image.
            return self.pipeline(image=np.array(img))["image"]


    def _build_transforms():
        """Return the (train, val) Albumentations pipelines.

        ``eps`` / ``always_apply`` from the original are omitted: they were passed
        their default values and are deprecated in newer Albumentations releases.
        """
        train_transform = A.Compose([
            A.RandomResizedCrop(224, 224),
            A.HorizontalFlip(p=0.5),
            A.RandomGamma(gamma_limit=(80, 120), p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.5),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
            A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
            A.Normalize(IMAGENET_MEAN, IMAGENET_STD),
            ToTensorV2(),
        ])
        val_transform = A.Compose([
            A.Resize(256, 256),
            A.CenterCrop(224, 224),
            A.Normalize(IMAGENET_MEAN, IMAGENET_STD),
            ToTensorV2(),
        ])
        return train_transform, val_transform


    def _train_one_epoch(net, loader, loss_function, optimizer, device, epoch, epochs):
        """Run one training pass over *loader* and return the summed batch loss."""
        net.train()
        running_loss = 0.0
        train_bar = tqdm(loader, file=sys.stdout)
        for images, labels in train_bar:
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss.item())
        return running_loss


    def _evaluate(net, loader, device, epoch, epochs):
        """Return the number of correctly classified samples in *loader*."""
        net.eval()
        correct = 0
        with torch.no_grad():
            val_bar = tqdm(loader, file=sys.stdout)
            for val_images, val_labels in val_bar:
                outputs = net(val_images.to(device))
                predict_y = torch.argmax(outputs, dim=1)
                correct += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)
        return correct


    def main():
        """Fine-tune a pretrained ResNet-34 on an ImageFolder dataset with Albumentations."""
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("using {} device.".format(device))

        # Dataset root on the mounted Google Drive; expects train/ and val/
        # sub-folders with one directory per class (e.g. bad/, good/).
        data_root = "/content/gdrive"
        image_path = os.path.join(data_root, "MyDrive", "totalimages")
        if not os.path.exists(image_path):
            raise FileNotFoundError("{} path does not exist.".format(image_path))

        train_transform, val_transform = _build_transforms()
        # Wrap the pipelines: ImageFolder calls transform(img) positionally.
        train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                             transform=AlbumentationsTransform(train_transform))
        validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                                transform=AlbumentationsTransform(val_transform))
        train_num = len(train_dataset)
        val_num = len(validate_dataset)

        # ImageFolder numbers classes 0..K-1 in sorted folder-name order,
        # e.g. {'bad': 0, 'good': 1}; save the inverse mapping for inference.
        cla_dict = {idx: name for name, idx in train_dataset.class_to_idx.items()}
        with open('class_indices.json', 'w', encoding='utf-8') as json_file:
            json.dump(cla_dict, json_file, indent=4)

        batch_size = 64
        # os.cpu_count() can return None, which would make min() raise TypeError.
        nw = min(os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8)  # number of workers
        print('Using {} dataloader workers every process'.format(nw))

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size, shuffle=True,
                                                   num_workers=nw)
        validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                      batch_size=batch_size, shuffle=False,
                                                      num_workers=nw)
        print("using {} images for training, {} images for  validation.".format(train_num,
                                                                               val_num))

        net = resnet34()
        # Pretrained weights, download url:
        # https://download.pytorch.org/models/resnet34-333f7ec4.pth
        model_weight_path = "/content/gdrive/MyDrive/resnet34-333f7ec4.pth"
        if not os.path.exists(model_weight_path):
            raise FileNotFoundError("file {} does not    exist.".format(model_weight_path))
        net.load_state_dict(torch.load(model_weight_path, map_location=device))

        # Replace the classifier head with one output per class folder; the
        # original hard-coded 5 outputs (from a 5-class flower dataset).
        net.fc = nn.Linear(net.fc.in_features, len(train_dataset.classes))
        net.to(device)

        loss_function = nn.CrossEntropyLoss()
        params = [p for p in net.parameters() if p.requires_grad]
        optimizer = optim.Adam(params, lr=0.0001)

        epochs = 10
        best_acc = 0.0
        save_path = './resNet34.pth'
        train_steps = len(train_loader)
        for epoch in range(epochs):
            running_loss = _train_one_epoch(net, train_loader, loss_function,
                                            optimizer, device, epoch, epochs)
            correct = _evaluate(net, validate_loader, device, epoch, epochs)

            val_accurate = correct / val_num
            print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
                  (epoch + 1, running_loss / train_steps, val_accurate))

            # Keep only the checkpoint with the best validation accuracy so far.
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)

        print('Finished Training')


    if __name__ == '__main__':
        main()
    
    2年前 0条评论