Gradio~让你的机器学习模型~性感起来

gradio是一个快速构建机器学习Web展示页面的开源Python库。

只需要几行代码,就可以让你的机器学习模型从抽象晦涩的代码变成性感可爱的交互界面。

让没有任何编程技能的用户也能够轻松使用和体验模型。

它非常适合在模型迭代测试中快速获取用户反馈或者在汇报展示中进行使用,非常酷炫。

公众号算法美食屋后台回复关键词:gradio, 获取本文notebook源代码和 Bilibili视频演示教程~

相比另一个机器学习应用web展示库streamlit,gradio具有如下优势:

  • 便于分享:gradio可以在启动应用时设置share=True参数创建外部分享链接,可以直接在微信中分享给用户使用。

  • 方便调试:gradio可以在jupyter中直接展示页面,更加方便调试。

大多数的gradio应用一般由如下最常用的基础模块构成。

  • 应用界面:gr.Interface(简易场景), gr.Blocks(定制化场景)

  • 输入输出:gr.Image(图像), gr.Textbox(文本框), gr.DataFrame(数据框), gr.Dropdown(下拉选项), gr.Number(数字), gr.Markdown, gr.Files

  • 控制组件:gr.Button(按钮)

  • 布局组件:gr.Tab(标签页), gr.Row(行布局), gr.Column(列布局)

我们将由易到难通过5个范例来介绍gradio的使用方法。

  • hello world范例 (gr.Interface)

  • 文本分类 (gr.Interface)

  • 图片分类 (gr.Interface)

  • 目标检测  (gr.Blocks定制化)

  • 图片筛选器 (gr.Blocks)

参考资料:

  • 官方教程:https://gradio.app/

一,Hello World (难度系数: ⭐️)

import gradio as gr

def greet(name):
    return "Hello " + name + "!!"

demo = gr.Interface(fn=greet, inputs="text", outputs="text")
gr.close_all()
demo.launch(share=True)

二,文本分类 (难度系数: ⭐️⭐️)

#!pip install gradio, ultralytics, transformers, torchkeras
import gradio as gr 
from transformers import pipeline

pipe = pipeline("text-classification")

def clf(text):
    result = pipe(text)
    label = result[0]['label']
    score = result[0]['score']
    res = {label:score,'POSITIVE' if label=='NEGATIVE' else 'NEGATIVE': 1-score}
    return res 

demo = gr.Interface(fn=clf, inputs="text", outputs="label")
gr.close_all()
demo.launch(share=True)

三,图片分类 (难度系数: ⭐️⭐️⭐)

import gradio as gr 
import pandas as pd 
from ultralytics import YOLO
from skimage import data
from PIL import Image

model = YOLO('yolov8n-cls.pt')
def predict(img):
    result = model.predict(source=img)
    df = pd.Series(result[0].names).to_frame()
    df.columns = ['names']
    df['probs'] = result[0].probs
    df = df.sort_values('probs',ascending=False)
    res = dict(zip(df['names'],df['probs']))
    return res
gr.close_all() 
demo = gr.Interface(fn = predict,inputs = gr.Image(type='pil'), outputs = gr.Label(num_top_classes=5), 
                    examples = ['cat.jpeg','people.jpeg','coffee.jpeg'])
demo.launch()

四,目标检测 (难度系数: ⭐️⭐️⭐⭐️)

import gradio as gr 
import pandas as pd 
from skimage import data
from ultralytics.yolo.data import utils 

model = YOLO('yolov8n.pt')

#load class_names
yaml_path = str(Path(ultralytics.__file__).parent/'datasets/coco128.yaml') 
class_names = utils.yaml_load(yaml_path)['names']
def detect(img):
    if isinstance(img,str):
        img = get_url_img(img) if img.startswith('http') else Image.open(img).convert('RGB')
    result = model.predict(source=img)
    if len(result[0].boxes.boxes)>0:
        vis = plots.plot_detection(img,boxes=result[0].boxes.boxes,
                     class_names=class_names, min_score=0.2)
    else:
        vis = img
    return vis
with gr.Blocks() as demo:
    gr.Markdown("# yolov8目标检测演示")

    with gr.Tab("捕捉摄像头喔"):
        in_img = gr.Image(source='webcam',type='pil')
        button = gr.Button("执行检测",variant="primary")

        gr.Markdown("## 预测输出")
        out_img = gr.Image(type='pil')

        button.click(detect,
                     inputs=in_img, 
                     outputs=out_img)
        
    
    ...
gr.close_all() 
demo.queue(concurrency_count=5)
demo.launch()

五,图片筛选器 (难度系数: ⭐️⭐️⭐⭐️⭐️)

尽管gradio的设计初衷是为了快速创建机器学习用户交互页面。

但实际上,通过组合gradio的各种组件,用户可以很方便地实现非常实用的各种应用小工具。

例如:  数据分析展示dashboard,  数据标注工具, 制作一个小游戏界面等等。

本范例我们将应用 gradio来构建一个图片筛选器,从百度爬取的一堆猫咪表情包中刷选一些我们喜欢的出来。

#!pip install -U torchkeras
import torchkeras 
from torchkeras.data import download_baidu_pictures 
download_baidu_pictures('猫咪表情包',100)
import gradio as gr
from PIL import Image
import time,os
from pathlib import Path 
base_dir = '猫咪表情包'
selected_dir = 'selected'
files = [str(x) for x in 
         Path(base_dir).rglob('*.jp*g') 
         if 'checkpoint' not in str(x)]
def show_img(path):
    return Image.open(path)
def fn_before(done,todo):
    ...
    return done,todo,path,img
def fn_next(done,todo):
    ...
    return done,todo,path,img
def save_selected(img_path):
    ...
    return msg 
def get_default_msg():
    ...
    return msg
with gr.Blocks() as demo:
    with gr.Row():
        total = gr.Number(len(files),label='总数量')
        with gr.Row(scale = 1):
            bn_before = gr.Button("上一张")
            bn_next = gr.Button("下一张")
        with gr.Row(scale = 2):
            done = gr.Number(0,label='已完成')
            todo = gr.Number(len(files),label='待完成')
    path = gr.Text(files[0],lines=1, label='当前图片路径')
    feedback_button = gr.Button("选择图片",variant="primary")
    msg = gr.TextArea(value=get_default_msg,lines=3,max_lines = 5)
    img = gr.Image(value = show_img(files[0]),type='pil')
    
    bn_before.click(fn_before,
                 inputs= [done,todo], 
                 outputs=[done,todo,path,img])
    bn_next.click(fn_next,
                 inputs= [done,todo], 
                 outputs=[done,todo,path,img])
    feedback_button.click(save_selected,
                         inputs = path,
                         outputs = msg
                         )
demo.launch()

六,huggingface托管 (难度系数: ⭐️⭐️)

为了便于向合作伙伴永久展示我们的模型App,可以将gradio的模型部署到 HuggingFace的 Space托管空间中,完全免费的哦。

方法如下:

1,注册huggingface账号:https://huggingface.co/join

2,在space空间中创建项目:https://huggingface.co/spaces

3,创建好的项目有一个Readme文档,可以根据说明操作,也可以手工编辑app.py和requirements.txt文件。

参考范例:

https://huggingface.co/spaces/lyhue1991/yolov8_demo

公众号算法美食屋后台回复关键词:gradio,获取本文notebook源代码和B站视频讲解.

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年5月26日
下一篇 2023年5月26日

相关推荐