对数据进行预处理,实现K折交叉验证

在深度学习中经常会用到K折交叉验证,本文主要介绍在实际应用中,如何实现图像数据的K折划分,为后面的模型训练与验证做好前期的数据准备工作。

本文主要实现将文件夹中的图像数据进行K折划分,最后将划分好的数据信息以多个子表写入到一个统一的Excel文件中。完整的github代码链接在文章底部。

  • 假设图片已经按照文件夹分类,文件夹名的第一个字母代表分类号

对数据进行预处理,实现K折交叉验证

  • 创建Excel文件,读取图片信息,将图片名称及对应的标签存入到Excel文件中
    def getKFoldData(self):
        imgDataInfo = []    # 统计图像文件名及对应的label,格式:{'img':img1.tif, 'label': 0}
        for category in os.listdir(self.imgRootPath):
            categoryPath = join(self.imgRootPath, category)
            for img in os.listdir(categoryPath):
                imgInfo = {'img':img, 'label': category[0]}
                imgDataInfo.append(imgInfo)
        # 将数据信息写入到excel文件中
        df = pd.DataFrame(imgDataInfo)
        df.to_excel(excelFileSavePath, index=False, sheet_name='originalFile')
  • 分别统计各类别信息,按类别依次进行K折划分,并将划分好的数据写入到上面创建好的Excel文件的不同子表中。
        category_0 = []
        category_1 = []
        category_2 = []
        for dataInfo in imgDataInfo:
            if dataInfo['label'] == '0':
                category_0.append(dataInfo)
            elif dataInfo['label'] == '1':
                category_1.append(dataInfo)
            elif dataInfo['label'] == '2':
                category_2.append(dataInfo)
            else:
                Exception("数据{}的类别异常,为{}:".format(dataInfo["label"], dataInfo["label"]))

        category0_foldSize = int(len(category_0) / self.kFold)
        category1_foldSize = int(len(category_1) / self.kFold)
        category2_foldSize = int(len(category_2) / self.kFold)

       for kthFold in range(self.kFold):
            category_0_fold_train, category_0_fold_valid = self.get_k_fold_index(category_0, category0_foldSize, kthFold)

            category_1_fold_train, category_1_fold_valid = self.get_k_fold_index(category_1, category1_foldSize, kthFold)

            category_2_fold_train, category_2_fold_valid = self.get_k_fold_index(category_2, category2_foldSize, kthFold)


            fold_train = category_0_fold_train + category_1_fold_train + category_2_fold_train
            fold_valid = category_0_fold_valid + category_1_fold_valid + category_2_fold_valid

            train_sheetName = 'train_' + 'fold' + str(kthFold)
            self.createExcelFile(fold_train, train_sheetName)

            valid_sheetName = 'valid_' + 'fold' + str(kthFold)
            self.createExcelFile(fold_valid, valid_sheetName)
  • K折划分具体实现
    def get_k_fold_index(self, data: List, everyFoldSize: int, kthFold: int) -> List:
        """
        :param foldSize: 每折的数量
        :param kthFold: [0 , self.kFold - 1]
        :return: 返回对应的图片名称
        # 如果 (dataLength % kthFold) != 0,则除不尽多余的数据统一存放到最后一折数据中
        """
        assert kthFold <= self.kFold - 1, "输入的折数:{}超出范围!".format(kthFold)
        dataLength = len(data)
        train = []
        valid = []
        for j in range(self.kFold):
            idx = slice(j * everyFoldSize, (j + 1) * everyFoldSize)
            data_part = data[idx]
            if j == kthFold:
                # 属于验证集的那一折数据
                if kthFold == self.kFold - 1:
                    # 最后一折数据,在(dataLength % kthFold) != 0的情况下,将多余的数据也包含进去
                    index = slice(j * everyFoldSize, dataLength)
                    valid = data[index]
                else:
                    valid = data_part
            elif len(train) == 0:
                train = data_part
            elif j == (self.kFold - 1):
                # 最后一折数据,在(dataLength % kthFold) != 0的情况下,将多余的数据也包含进去
                index = slice(j * everyFoldSize, dataLength)
                data_part = data[index]
                train.extend(data_part)
            else:
                train.extend(data_part)

        return [train, valid]
  • Excel文件写入具体实现
    def createExcelFile(self, dataInfo: List, sheet_name: str):
        excelData = {'img': [item["img"] for item in dataInfo],
                     'label': [item["label"] for item in dataInfo]}
        df = pd.DataFrame(data=excelData)
        with pd.ExcelWriter(self.excelFileSavePath, mode='a', engine='openpyxl') as writer:
            df.to_excel(writer, sheet_name=sheet_name,  index=False)
        writer.save()
        writer.close()
  • 最终结果显示

对数据进行预处理,实现K折交叉验证完整的代码github链接

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2022年3月29日 下午6:11
下一篇 2022年3月29日 下午6:20

相关推荐