[1141]基于MODnet无绿幕抠图

前言

MODNet由香港城市大学和商汤科技于2020年11月首次提出,用于实时抠图任务

MODNet特性:

  • 轻量级(light-weight )
  • 实时性高(real-time)
  • 预测时不需要额外的背景输入(trimap-free)
  • 准确度高(hight performance)
  • 单模型(single model instead of a complex pipeline)
  • 泛化能力强(better generalization ability)

论文地址 : https://arxiv.org/pdf/2011.11961.pdf
git地址: https://github.com/ZHKKKe/MODNet
https://github.com/xuebinqin/U-2-Net

复现代码

基于onnx推理代码

官方给出了基于torch和onnx推理代码,这里用的是关于onnx模型的推理代码.

import os
import cv2
import argparse
import numpy as np
from PIL import Image

import onnx
import onnxruntime


if __name__ == '__main__':
    # define cmd arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--image-path', default= 'test.jpeg',type=str, help='path of the input image (a file)')
    parser.add_argument('--output-path',default= 'result.png', type=str, help='paht for saving the predicted alpha matte (a file)')
    parser.add_argument('--model-path', default='hrnet.onnx', type=str, help='path of the ONNX model')
    args = parser.parse_args()

    # check input arguments
    if not os.path.exists(args.image_path):
        print('Cannot find the input image: {0}'.format(args.image_path))
        exit()
    if not os.path.exists(args.model_path):
        print('Cannot find the ONXX model: {0}'.format(args.model_path))
        exit()

    ref_size = 512

    # Get x_scale_factor & y_scale_factor to resize image
    def get_scale_factor(im_h, im_w, ref_size):

        if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
            if im_w >= im_h:
                im_rh = ref_size
                im_rw = int(im_w / im_h * ref_size)
            elif im_w < im_h:
                im_rw = ref_size
                im_rh = int(im_h / im_w * ref_size)
        else:
            im_rh = im_h
            im_rw = im_w

        im_rw = im_rw - im_rw % 32
        im_rh = im_rh - im_rh % 32

        x_scale_factor = im_rw / im_w
        y_scale_factor = im_rh / im_h

        return x_scale_factor, y_scale_factor

    ##############################################
    #  Main Inference part
    ##############################################

    # read image
    im = cv2.imread(args.image_path)
    img = im.copy()
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    # unify image channels to 3
    if len(im.shape) == 2:
        im = im[:, :, None]
    if im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)
    elif im.shape[2] == 4:
        im = im[:, :, 0:3]

    # normalize values to scale it between -1 to 1
    im = (im - 127.5) / 127.5   

    im_h, im_w, im_c = im.shape
    x, y = get_scale_factor(im_h, im_w, ref_size) 

    # resize image
    im = cv2.resize(im, None, fx = x, fy = y, interpolation = cv2.INTER_AREA)

    # prepare input shape
    im = np.transpose(im)
    im = np.swapaxes(im, 1, 2)
    im = np.expand_dims(im, axis = 0).astype('float32')

    # Initialize session and get prediction
    session = onnxruntime.InferenceSession(args.model_path, None)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    result = session.run([output_name], {input_name: im})

    # refine matte
    matte = (np.squeeze(result[0]) * 255).astype('uint8')
    matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation = cv2.INTER_AREA)

    cv2.imwrite(args.output_path, matte)


    # 保存彩色图片
    # b,g,r = cv2.split(img)
    # rbga_img = cv2.merge((b, g, r, matte))
    rbga_img = cv2.merge((img, matte))
    cv2.imwrite('rbga_result.png',rbga_img)

抠图效果

测试图片

image.png

测试结果

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-B3JzUG16-1652968229233)(https://upload-images.jianshu.io/upload_images/12504508-d87c7f4b721020d9.png)]

可以发现抠图已经达到了丝发级别,对于清晰的图片抠图还是很准确的.

基于demo.image_matting.colab.inference推理代码

预训练模型在这里 :modnet_photographic_portrait_matting.ckpt

模型百度网盘:在这里
密码:gchf

把模型下载到目录:MODNet/pretrained,下面运行需要加载此模型。
现在,工作目录是MODNet,在其目录下建立输入图片和输出图片的目录:
input-img, output-img
把需要抠图的图片放到input-img
MODNet目录下,运行

python -m demo.image_matting.colab.inference-1   \
                   --input-path input-img  \
                   --output-path output-img  \
                   --ckpt-path pretrained/modnet_photographic_portrait_matting.ckpt

现在可以从output-img中找到已经抠好的图片xxx_fg.png,遮罩图片xxx_matte.png
看看MODNet模型的抠图效果

image.png

python程序如下。原作者的程序中只给出遮罩matte,没有抠图结果。鄙人不才,添加了抠出的前景图片,供参考。

import os
import sys
import argparse
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from src.models.modnet import MODNet

if __name__ == '__main__':
    # define cmd arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, help='path of input images')
    parser.add_argument('--output-path', type=str, help='path of output images')
    parser.add_argument('--ckpt-path', type=str, help='path of pre-trained MODNet')
    args = parser.parse_args()

    # check input arguments
    if not os.path.exists(args.input_path):
        print('Cannot find input path: {0}'.format(args.input_path))
        exit()
    if not os.path.exists(args.output_path):
        print('Cannot find output path: {0}'.format(args.output_path))
        exit()
    if not os.path.exists(args.ckpt_path):
        print('Cannot find ckpt path: {0}'.format(args.ckpt_path))
        exit()
    # define hyper-parameters
    ref_size = 512
    # define image to tensor transform
    im_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )

    # create MODNet and load the pre-trained ckpt
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet).cuda()
    modnet.load_state_dict(torch.load(args.ckpt_path))
    modnet.eval()
# 注:程序中的数字仅表示某张输入图片尺寸,如1080x1440,此处只为记住其转换过程。
    # inference images
    im_names = os.listdir(args.input_path)
    for im_name in im_names:
        print('Process image: {0}'.format(im_name))
        # read image
        im = Image.open(os.path.join(args.input_path, im_name))
        # unify image channels to 3
        im = np.asarray(im)
        if len(im.shape) == 2:
            im = im[:, :, None]
        if im.shape[2] == 1:
            im = np.repeat(im, 3, axis=2)
        elif im.shape[2] == 4:
            im = im[:, :, 0:3]
        im_org = im                                # 保存numpy原始数组 (1080,1440,3)
        # convert image to PyTorch tensor
        im = Image.fromarray(im)
        im = im_transform(im)
        # add mini-batch dim
        im = im[None, :, :, :]
        # resize image for input
        im_b, im_c, im_h, im_w = im.shape
        if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
            if im_w >= im_h:
                im_rh = ref_size
                im_rw = int(im_w / im_h * ref_size)
            elif im_w < im_h:
                im_rw = ref_size
                im_rh = int(im_h / im_w * ref_size)
        else:
            im_rh = im_h
            im_rw = im_w
        im_rw = im_rw - im_rw % 32
        im_rh = im_rh - im_rh % 32
        im = F.interpolate(im, size=(im_rh, im_rw), mode='area')

        # inference
        _, _, matte = modnet(im.cuda(), True)    # 从模型获得的 matte ([1,1,512, 672])

        # resize and save matte,foreground picture
        matte = F.interpolate(matte, size=(im_h, im_w), mode='area')  #内插,扩展到([1,1,1080,1440])  范围[0,1]
        matte = matte[0][0].data.cpu().numpy()    # torch 张量转换成numpy (1080, 1440)
        matte_name = im_name.split('.')[0] + '_matte.png'
        Image.fromarray(((matte * 255).astype('uint8')), mode='L').save(os.path.join(args.output_path, matte_name))
        matte_org = np.repeat(np.asarray(matte)[:, :, None], 3, axis=2)   # 扩展到 (1080, 1440, 3) 以便和im_org计算
        
        foreground = im_org * matte_org + np.full(im_org.shape, 255) * (1 - matte_org)         # 计算前景,获得抠像
        fg_name = im_name.split('.')[0] + '_fg.png'
        Image.fromarray(((foreground).astype('uint8')), mode='RGB').save(os.path.join(args.output_path, fg_name))

参考:
https://blog.csdn.net/weixin_44238733/article/details/114457650
https://blog.csdn.net/small_wu/article/details/124041904
https://blog.csdn.net/jacke121/article/details/110774623
https://aijishu.com/a/1060000000162206
https://blog.csdn.net/missyoudaisy/article/details/111085552

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
上一篇 2022年5月21日 上午11:31
下一篇 2022年5月21日 上午11:36

相关推荐

本站注重文章个人版权,不会主动收集付费或者带有商业版权的文章,如果出现侵权情况只可能是作者后期更改了版权声明,如果出现这种情况请主动联系我们,我们看到会在第一时间删除!本站专注于人工智能高质量优质文章收集,方便各位学者快速找到学习资源,本站收集的文章都会附上文章出处,如果不愿意分享到本平台,我们会第一时间删除!