BasicSR Getting Started Tutorial

1. Environment Setup

Since another installed environment already has PyTorch, the simplest way to create the new environment is to clone that one directly:

# clone an existing environment
conda create --name my-basicsr --clone mmediting
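
If there is no existing PyTorch environment to clone, a minimal alternative is to create one from scratch (the Python version and package names below are only examples; pick the PyTorch build matching your CUDA driver):

conda create --name my-basicsr python=3.8
conda activate my-basicsr
pip install torch torchvision   # choose the build matching your CUDA version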

Clone the project:

git clone https://github.com/XPixelGroup/BasicSR.git

Install the dependencies:

cd BasicSR
pip install -r requirements.txt

Install BasicSR from the BasicSR root directory:

python setup.py develop

Verify that BasicSR is installed successfully:

import basicsr
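
A slightly more informative one-line check (assuming the usual version attribute generated during installation is present) is:

python -c "import basicsr; print(basicsr.__version__, basicsr.__file__)"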

After installing successfully from the local clone, use the pip list command to check that the BasicSR entry (and its local path) appears:

pip list

2. Prepare the Datasets

Commonly used image super-resolution datasets are listed below:

Name | Dataset | Description | Download
2K Resolution | DIV2K | proposed in NTIRE17 (800 train and 100 validation) | official website
Classical SR Testing | Set5 | Set5 test dataset | Google Drive / Baidu Drive
Classical SR Testing | Set14 | Set14 test dataset | Google Drive / Baidu Drive

DIV2K download: https://data.vision.ee.ethz.ch/cvl/DIV2K/

Set5 download: https://drive.google.com/drive/folders/1B3DJGQKB6eNdwuQIhdskA64qUuVKLZ9u

Set14 download: https://drive.google.com/drive/folders/1B3DJGQKB6eNdwuQIhdskA64qUuVKLZ9u

The DIV2K images are 2K resolution (e.g. 2048×1080), but training rarely needs images that large (typical training patches are 128×128 or 192×192). We therefore first crop the 2K images into overlapping 480×480 sub-images; the dataloader then randomly crops 128×128 or 192×192 training patches from those 480×480 sub-images. Run the script extract_subimages.py:

cd BasicSR
python scripts/data_preparation/extract_subimages.py
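
For intuition, the two-stage cropping works roughly as in the sketch below. This is illustrative only, not the actual extract_subimages.py code; the crop size of 480 and patch size of 128 come from the description above, while the step of 240 (50% overlap) and the handling of image borders are simplifying assumptions.

# Illustrative sketch of the two-stage cropping (not the BasicSR implementation).
import random
import numpy as np

def extract_subimages(img, crop_size=480, step=240):
    """Cut a large HR image into overlapping crop_size x crop_size tiles."""
    h, w = img.shape[:2]
    tiles = []
    for y in range(0, max(h - crop_size, 0) + 1, step):
        for x in range(0, max(w - crop_size, 0) + 1, step):
            tiles.append(img[y:y + crop_size, x:x + crop_size])
    return tiles

def random_patch(tile, patch_size=128):
    """What the dataloader does at train time: random patch from a tile."""
    h, w = tile.shape[:2]
    y = random.randint(0, h - patch_size)
    x = random.randint(0, w - patch_size)
    return tile[y:y + patch_size, x:x + patch_size]

hr = np.zeros((1080, 2048, 3), dtype=np.uint8)  # stand-in for a 2K DIV2K image
tiles = extract_subimages(hr)
patch = random_patch(tiles[0])
print(len(tiles), patch.shape)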

If you want to use LMDB, you need to build the LMDB files first. Run the data preparation script:

python scripts/data_preparation/create_lmdb.py --dataset div2k
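
To spot-check the generated LMDB, a sketch like the following can be used. It assumes BasicSR's convention of storing PNG-encoded image bytes keyed by the sub-image filename without extension; the key below is hypothetical, so look up a real one in the meta_info.txt inside the .lmdb folder.

# Minimal sketch: read one entry back from the generated LMDB.
import cv2
import lmdb
import numpy as np

env = lmdb.open('datasets/DIV2K/DIV2K_train_HR_sub.lmdb',
                readonly=True, lock=False, readahead=False)
with env.begin() as txn:
    key = '0001_s001'  # hypothetical key; see meta_info.txt for actual names
    buf = txn.get(key.encode('ascii'))
    assert buf is not None, 'key not found; check meta_info.txt'
    img = cv2.imdecode(np.frombuffer(buf, np.uint8), cv2.IMREAD_UNCHANGED)
print(img.shape)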

The resulting dataset directory structure looks like the sketch below.
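
A rough layout based on the paths referenced in the configuration files below (which subfolders you actually have depends on whether you generated the disk sub-images, the LMDB files, or both):

datasets/
├── DIV2K/
│   ├── DIV2K_train_HR_sub/              # or DIV2K_train_HR_sub.lmdb
│   ├── DIV2K_train_LR_bicubic_X4_sub/   # or DIV2K_train_LR_bicubic_X4_sub.lmdb
│   ├── DIV2K_valid_HR/
│   └── DIV2K_valid_LR_bicubic/X4/
├── Set5/
│   ├── GTmod12/
│   └── LRbicx4/
└── Set14/
    ├── GTmod12/
    └── LRbicx4/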

3. Modify the Configuration Files

Create a new training configuration file options/train/SRResNet_SRGAN/my_train_MSRResNet_x4.yml with the following content:

# Modified SRResNet w/o BN from:
# Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

# ----------- Commands for running
# ----------- Single GPU with auto_resume
# PYTHONPATH="./:${PYTHONPATH}"  CUDA_VISIBLE_DEVICES=0 python basicsr/train.py -opt options/train/SRResNet_SRGAN/my_train_MSRResNet_x4.yml --auto_resume

# general settings
name: 001_MSRResNet_x4_f64b16_DIV2K_10k_B16G1_wandb_myfirst
model_type: SRModel
scale: 4
num_gpu: 1  # set num_gpu: 0 for cpu mode
manual_seed: 0

# dataset and data loader settings
datasets:
  train:
    name: DIV2K
    type: PairedImageDataset
    # (for disk)
    # dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub
    # dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub
    # meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt
    # (for lmdb)
    dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
    dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
    filename_tmpl: '{}'
    io_backend:
      # type: disk
      # (for lmdb)
      type: lmdb

    gt_size: 128
    use_hflip: true
    use_rot: true

    # data loader
    num_worker_per_gpu: 6
    batch_size_per_gpu: 16
    dataset_enlarge_ratio: 100
    prefetch_mode: ~

  val:
    name: Set5
    type: PairedImageDataset
    dataroot_gt: datasets/Set5/GTmod12
    dataroot_lq: datasets/Set5/LRbicx4
    io_backend:
      type: disk

  val_2:
    name: Set14
    type: PairedImageDataset
    dataroot_gt: datasets/Set14/GTmod12
    dataroot_lq: datasets/Set14/LRbicx4
    io_backend:
      type: disk

# network structures
network_g:
  type: MSRResNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 64
  num_block: 16
  upscale: 4

# path
path:
  pretrain_network_g: ~
  param_key_g: params
  strict_load_g: true
  resume_state: ~

# training settings
train:
  ema_decay: 0.999
  optim_g:
    type: Adam
    lr: !!float 2e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  scheduler:
    type: CosineAnnealingRestartLR
    periods: [250000, 250000, 250000, 250000]
    restart_weights: [1, 1, 1, 1]
    eta_min: !!float 1e-7

  # total_iter: 1000000
  total_iter: 10000
  warmup_iter: -1  # no warm up

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: 1.0
    reduction: mean

# validation settings
val:
  val_freq: !!float 5e3
  save_img: false

  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      crop_border: 4
      test_y_channel: false
      better: higher  # the higher, the better. Default: higher
    niqe:
      type: calculate_niqe
      crop_border: 4
      better: lower  # the lower, the better

# logging settings
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500

Now you can start training:

python basicsr/train.py -opt options/train/SRResNet_SRGAN/my_train_MSRResNet_x4.yml
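
If training is interrupted, it can be resumed from the latest checkpoint by adding the --auto_resume flag, as in the command shown in the comment header of the config file:

PYTHONPATH="./:${PYTHONPATH}" CUDA_VISIBLE_DEVICES=0 python basicsr/train.py -opt options/train/SRResNet_SRGAN/my_train_MSRResNet_x4.yml --auto_resume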

After training completes, the model checkpoints, logs, and training states are saved in the experiments folder, under experiments/001_MSRResNet_x4_f64b16_DIV2K_10k_B16G1_wandb_myfirst (the test configuration below loads its model from this location).

Create a new test configuration file options/test/SRResNet_SRGAN/my_test_MSRResNet_x4.yml with the following content:

# ----------- Commands for running
# ----------- Single GPU
# PYTHONPATH="./:${PYTHONPATH}"  CUDA_VISIBLE_DEVICES=0 python basicsr/test.py -opt options/test/SRResNet_SRGAN/my_test_MSRResNet_x4.yml

# general settings
name: 001_MSRResNet_x4_f64b16_DIV2K_10k_B16G1_wandb_myfirst
model_type: SRModel
scale: 4
num_gpu: 1  # set num_gpu: 0 for cpu mode
manual_seed: 0

# test dataset settings
datasets:
  test_1:  # the 1st test dataset
    name: Set5
    type: PairedImageDataset
    dataroot_gt: datasets/Set5/GTmod12
    dataroot_lq: datasets/Set5/LRbicx4
    io_backend:
      type: disk
  test_2:  # the 2nd test dataset
    name: Set14
    type: PairedImageDataset
    dataroot_gt: datasets/Set14/GTmod12
    dataroot_lq: datasets/Set14/LRbicx4
    io_backend:
      type: disk
  test_3: # the 3rd test dataset
    name: DIV2K100
    type: PairedImageDataset
    dataroot_gt: datasets/DIV2K/DIV2K_valid_HR
    dataroot_lq: datasets/DIV2K/DIV2K_valid_LR_bicubic/X4
    filename_tmpl: '{}x4'
    io_backend:
      type: disk

# network structures
network_g:
  type: MSRResNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 64
  num_block: 16
  upscale: 4

# path
path:
  pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_10k_B16G1_wandb_myfirst/models/net_g_10000.pth
  param_key_g: params
  strict_load_g: true

# validation settings
val:
  save_img: true
  suffix: ~  # add suffix to saved images, if None, use exp name

  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      crop_border: 4
      test_y_channel: false
      better: higher  # the higher, the better. Default: higher
    ssim:
      type: calculate_ssim
      crop_border: 4
      test_y_channel: false
      better: higher
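
Run the test with the following command (the option file is the one created above):

python basicsr/test.py -opt options/test/SRResNet_SRGAN/my_test_MSRResNet_x4.yml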

After testing completes, the output images and per-dataset metrics are saved in the results folder, under results/001_MSRResNet_x4_f64b16_DIV2K_10k_B16G1_wandb_myfirst.

4. Visualizing Training with TensorBoard

Enable TensorBoard in the training yml configuration file:

# logging settings
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true  # set this to true
  wandb:
    project: ~
    resume_id: ~

Run the following command on the server; the training curves can then be viewed in a browser:

tensorboard --logdir tb_logger --port 5500 --bind_all

TensorBoard is straightforward to use on a local machine, but when training on a remote server you need to set up port forwarding.

On Windows, install Xshell and go to File -> Properties -> SSH -> Tunneling -> Add. Set the type to Local, the source host to 127.0.0.1 (i.e. your local machine), and pick a listening port such as 12345. Set the destination host to the server and the destination port to 5500 (the TensorBoard port above; if 5500 is already taken, use another port).
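
If you prefer a plain SSH client over Xshell, the equivalent local port forward is (the user name and server address below are placeholders):

ssh -L 12345:127.0.0.1:5500 user@your-server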


Then open 127.0.0.1:12345 in your local browser.


That's all. Thanks, everyone, for following along!
