[github优秀AI项目]实现4K60帧视频人体实时抠图

项目地址:https://github.com/PeterL1n/RobustVideoMatting文章:Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!PyTorch、TensorFlow、TensorFlow中的强大视频抠图功能。js,ONNX,CoreML!稳定视频抠像 (RVM)论文Robust High-Resolution Video Matting with Te..

项目地址:

https://github.com/PeterL1n/RobustVideoMatting

文章:

Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!

PyTorch、TensorFlow、TensorFlow中的强大视频抠图功能。js,ONNX,CoreML!

稳定视频抠像 (RVM)

[github优秀AI项目]实现4K60帧视频人体实时抠图 论文 Robust High-Resolution Video Matting with Temporal Guidance 的官方 GitHub 库。RVM 专为稳定人物视频抠像设计。不同于现有神经网络将每一帧作为单独图片处理,RVM 使用循环神经网络,在处理视频流时有时间记忆。RVM 可在任意视频上做实时高清抠像。在 Nvidia GTX 1080Ti 上实现 4K 76FPS 和 HD 104FPS。此研究项目来自字节跳动

展示视频

观看展示视频 (YouTubeBilibili),了解模型能力。

[github优秀AI项目]实现4K60帧视频人体实时抠图

视频中的所有素材都提供下载,可用于测试模型:Google Drive

Demo

  • 网页: 在浏览器里看摄像头抠像效果,展示模型内部循环记忆值。
  • Colab: 用我们的模型转换你的视频。

下载

推荐在通常情况下使用 MobileNetV3 的模型。ResNet50 的模型大很多,效果稍有提高。我们的模型支持很多框架。详情请阅读推断文档

框架 下载 备注
PyTorch rvm_mobilenetv3.pth
rvm_resnet50.pth
官方 PyTorch 模型权值。文档
TorchHub 无需手动下载。 更方便地在你的 PyTorch 项目里使用此模型。文档
TorchScript rvm_mobilenetv3_fp32.torchscript
rvm_mobilenetv3_fp16.torchscript
rvm_resnet50_fp32.torchscript
rvm_resnet50_fp16.torchscript
若需在移动端推断,可以考虑自行导出 int8 量化的模型。文档
ONNX rvm_mobilenetv3_fp32.onnx
rvm_mobilenetv3_fp16.onnx
rvm_resnet50_fp32.onnx
rvm_resnet50_fp16.onnx
在 ONNX Runtime 的 CPU 和 CUDA backend 上测试过。提供的模型用 opset 12。文档导出
TensorFlow rvm_mobilenetv3_tf.zip
rvm_resnet50_tf.zip
TensorFlow 2 SavedModel 格式。文档
TensorFlow.js rvm_mobilenetv3_tfjs_int8.zip 在网页上跑模型。展示示范代码
CoreML rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel
rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel
rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel
rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel
CoreML 只能导出固定分辨率,其他分辨率可自行导出。支持 iOS 13+。s 代表下采样比。文档导出

所有模型可在 Google Drive 或百度网盘(密码: gym7)上下载。

PyTorch 范例

  1. 1 安装 Python 库:
pip install -r requirements_inference.txt
  1. 2 加载模型:
import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda()  # 或 "resnet50"
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
  1. 3 若只需要做视频抠像处理,我们提供简单的 API:
from inference import convert_video

convert_video(
    model,                           # 模型,可以加载到任何设备(cpu 或 cuda)
    input_source='input.mp4',        # 视频文件,或图片序列文件夹
    output_type='video',             # 可选 "video"(视频)或 "png_sequence"(PNG 序列)
    output_composition='com.mp4',    # 若导出视频,提供文件路径。若导出 PNG 序列,提供文件夹路径
    output_alpha="pha.mp4",          # [可选项] 输出透明度预测
    output_foreground="fgr.mp4",     # [可选项] 输出前景预测
    output_video_mbps=4,             # 若导出视频,提供视频码率
    downsample_ratio=None,           # 下采样比,可根据具体视频调节,或 None 选择自动
    seq_chunk=12,                    # 设置多帧并行计算
)
  1. 4 或自己写推断逻辑:
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter

reader = VideoReader('input.mp4', transform=ToTensor())
writer = VideoWriter('output.mp4', frame_rate=30)

bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # 绿背景
rec = [None] * 4                                       # 初始循环记忆(Recurrent States)
downsample_ratio = 0.25                                # 下采样比,根据视频调节

with torch.no_grad():
    for src in DataLoader(reader):                     # 输入张量,RGB通道,范围为 0~1
        fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio)  # 将上一帧的记忆给下一帧
        com = fgr * pha + bgr * (1 - pha)              # 将前景合成到绿色背景
        writer.write(com)                              # 输出帧
  1. 5 模型和 API 也可通过 TorchHub 快速载入。
# 加载模型
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # 或 "resnet50"

# 转换 API
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")

推断文档里有对 downsample_ratio 参数,API 使用,和高阶使用的讲解。

训练和评估

请参照训练文档(英文)

速度

速度用 inference_speed_test.py 测量以供参考。

GPU dType HD (1920×1080) 4K (3840×2160)
RTX 3090 FP16 172 FPS 154 FPS
RTX 2060 Super FP16 134 FPS 108 FPS
GTX 1080 Ti FP32 104 FPS 74 FPS
  • 注释1:HD 使用 downsample_ratio=0.25,4K 使用 downsample_ratio=0.125。 所有测试都使用 batch size 1 和 frame chunk 1。
  • 注释2:图灵架构之前的 GPU 不支持 FP16 推理,所以 GTX 1080 Ti 使用 FP32。
  • 注释3:我们只测量张量吞吐量(tensor throughput)。 提供的视频转换脚本会慢得多,因为它不使用硬件视频编码/解码,也没有在并行线程上完成张量传输。如果您有兴趣在 Python 中实现硬件视频编码/解码,请参考 PyNvCodec

复现使用

知乎有个大佬把它分别用python和C++复现了RobustVideoMatting🔥2021 ONNXRuntime C++工程化记录-实现篇 – 知乎

python代码:

import cv2
import time
import argparse
import numpy as np
import onnxruntime as ort


def normalize(frame: np.ndarray) -> np.ndarray:
    """
    Args:
        frame: BGR
    Returns: normalized 0~1 BCHW RGB
    """
    img = frame.astype(np.float32).copy() / 255.0
    img = img[:, :, ::-1]  # RGB
    img = np.transpose(img, (2, 0, 1))  # (C,H,W)
    img = np.expand_dims(img, axis=0)  # (B=1,C,H,W)
    return img


def infer_rvm_frame(weight: str = "rvm_resnet50_fp32.onnx",
                    img_path: str = "test.jpg",
                    output_path: str = "test_onnx.jpg"):
    sess = ort.InferenceSession(f'./checkpoint/{weight}')
    print(f"Load checkpoint/{weight} done!")

    for _ in sess.get_inputs():
        print("Input: ", _)
    for _ in sess.get_outputs():
        print("Input: ", _)

    frame = cv2.imread(img_path)
    src = normalize(frame)
    rec = [np.zeros([1, 1, 1, 1], dtype=np.float32)] * 4  # 必须用模型一样的 dtype
    downsample_ratio = np.array([0.25], dtype=np.float32)  # 必须是 FP32
    bgr = np.array([0.47, 1., 0.6]).reshape((3, 1, 1))

    fgr, pha, *rec = sess.run([], {
        'src': src,
        'r1i': rec[0],
        'r2i': rec[1],
        'r3i': rec[2],
        'r4i': rec[3],
        'downsample_ratio': downsample_ratio
    })

    merge_frame = fgr * pha + bgr * (1. - pha)  # (1,3,H,W)
    merge_frame = merge_frame[0] * 255.  # (3,H,W)
    merge_frame = merge_frame.astype(np.uint8)  # RGB
    merge_frame = np.transpose(merge_frame, (1, 2, 0))  # (H,W,3)
    merge_frame = cv2.cvtColor(merge_frame, cv2.COLOR_BGR2RGB)

    cv2.imwrite(output_path, merge_frame)

    print(f"infer done! saved {output_path}")


def infer_rvm_video(weight: str = "rvm_resnet50_fp32.onnx",
                    video_path: str = "./demo/1917.mp4",
                    output_path: str = "./demo/1917_onnx.mp4"):
    sess = ort.InferenceSession(f'./checkpoint/{weight}')
    print(f"Load checkpoint/{weight} done!")

    for _ in sess.get_inputs():
        print("Input: ", _)
    for _ in sess.get_outputs():
        print("Input: ", _)

    # 读取视频
    video_capture = cv2.VideoCapture(video_path)
    width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"Video Caputer: Height: {height}, Width: {width}, Frame Count: {frame_count}")

    # 写出视频
    fps = 25
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    print(f"Create Video Writer: {output_path}")

    i = 0
    rec = [np.zeros([1, 1, 1, 1], dtype=np.float32)] * 4  # 必须用模型一样的 dtype
    downsample_ratio = np.array([0.25], dtype=np.float32)  # 必须是 FP32
    bgr = np.array([0.47, 1., 0.6]).reshape((3, 1, 1))

    print(f"Infer {video_path} start ...")
    while video_capture.isOpened():
        success, frame = video_capture.read()

        if success:
            i += 1
            src = normalize(frame)
            # src 张量是 [B, C, H, W] 形状,必须用模型一样的 dtype
            t1 = time.time()
            fgr, pha, *rec = sess.run([], {
                'src': src,
                'r1i': rec[0],
                'r2i': rec[1],
                'r3i': rec[2],
                'r4i': rec[3],
                'downsample_ratio': downsample_ratio
            })
            t2 = time.time()
            print(f"Infer {i}/{frame_count} done! -> cost {(t2 - t1) * 1000} ms", end=" ")
            merge_frame = fgr * pha + bgr * (1. - pha)  # (1,3,H,W)
            merge_frame = merge_frame[0] * 255.  # (3,H,W)
            merge_frame = merge_frame.astype(np.uint8)  # RGB
            merge_frame = np.transpose(merge_frame, (1, 2, 0))  # (H,W,3)
            merge_frame = cv2.cvtColor(merge_frame, cv2.COLOR_BGR2RGB)
            merge_frame = cv2.resize(merge_frame, (width, height))

            video_writer.write(merge_frame)
            print(f"write {i}/{frame_count} done.")
        else:
            print("can not read video! skip!")
            break

    video_capture.release()
    video_writer.release()
    print(f"Infer {video_path} done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, default="video")
    parser.add_argument("--weight", type=str, default="rvm_resnet50_fp32.onnx")
    parser.add_argument("--input", type=str, default="./demo/1917.mp4")
    parser.add_argument("--output", type=str, default="./demo/1917_onnx.mp4")
    args = parser.parse_args()

    if args.mode == "video":
        infer_rvm_video(weight=args.weight, video_path=args.input, output_path=args.output)
    else:
        infer_rvm_frame(weight=args.weight, img_path=args.input, output_path=args.output)

    """
    rvm_resnet50_fp32.onnx
    rvm_mobilenetv3_fp32.onnx
    PYTHONPATH=. python3 ./inference_onnx.py --input ./demo/1917.mp4 --output ./demo/1917_onnx.mp4
    PYTHONPATH=. python3 ./inference_onnx.py --mode img --input test.jpg --output test_onnx.jpg
    python inference_onnx.py --input ./demo/1917.mp4 --output ./demo/1917_onnx.mp4
    """

[github优秀AI项目]实现4K60帧视频人体实时抠图

版权声明:本文为博主易小侠原创文章,版权归属原作者,如果侵权,请联系我们删除!

原文链接:https://blog.csdn.net/dwf1354046363/article/details/122555021

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2022年1月18日 上午10:34
下一篇 2022年1月18日 上午10:53

相关推荐