机械臂强化学习实战(stable baselines3+panda-gym)

今天参考知乎岳小飞的博客尝试用一下比较标准的机械臂+强化学习的实战项目。这篇博客主要记录一下实现过程,当做个人学习笔记。

在第一遍安装过程中遇到了panda-gym和stb3以及gym-robotics这三个包对gym版本依赖冲突的问题。

这里记录第二遍安装

原文链接:https://www.zhihu.com/people/shen-yue-79/posts
panda-gym的github链接:https://github.com/qgallouedec/panda-gym
panda-gym官方文档:
https://panda-gym.readthedocs.io/en/latest/index.html
机械臂强化学习实战(stable baselines3+panda-gym)

1.新建环境,安装stb3

这里我用anaconda新建了一个python 3.7的环境
安装gym:

pip install stable-baselines3==1.3.0

安装完成后运行如下代码,可检查是否正常:

import gym
from stable_baselines3 import A2C

env = gym.make('CartPole-v1')

model = A2C('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=10000)

obs = env.reset()
for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
      obs = env.reset()

运行了以下有个报错应该是没安pyglet包,安一下

pip install pyglet

机械臂强化学习实战(stable baselines3+panda-gym)
到这里应该是安装好stb3了。

2.安装panda-gym

panda-gym 基于 PyBullet 引擎开发,围绕 panda 机械臂封装了 reach、push、slide、pick&place、stack、flip 等 6 个任务,主要也是受 OpenAI Fetch 启发,发表在了 NeurIPS 2021 的 workshop 上。

有两种安装方式一种是直接安装,这里用方式一安装2.0.0版本
方式1:

pip install panda-gym

如果想修改现有代码,或者自定义更多的任务,可以下载源码后,以 -e 方式安装(这种方式我尝试了遇到了包版本冲突问题):
方式2:

git clone https://github.com/qgallouedec/panda-gym/
cd panda-gym
pip install -e .

测试panda-gym是否安装成功

import gym
import panda_gym

env = gym.make("PandaReach-v2", render=True)

obs = env.reset()
done = False

while not done:
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
    env.render()

env.close()

机械臂强化学习实战(stable baselines3+panda-gym)
现在应该是装好panda-gym了。

3.开始训练

以 PandaReach-v2 任务为例,训练 DDPG/TD3/SAC+HER 算法,方便做横向对比。

reach 任务比较简单,要求机械臂到达指定位置,误差在一定范围之内即代表成功,默认采用稀疏奖励。

训练之前安装一下tensorboard帮助我们之后看训练过程

pip install tensorboard

以 PandaReach-v2 任务为例(代码参考自岳小飞,这里加了第二行导入panda-gym包)

import gym
import panda_gym
from stable_baselines3 import DDPG, TD3, SAC, HerReplayBuffer

env = gym.make("PandaReach-v2")
log_dir = './panda_reach_v2_tensorboard/'

# DDPG
model = DDPG(policy="MultiInputPolicy", env=env, buffer_size=100000, replay_buffer_class=HerReplayBuffer, verbose=1, tensorboard_log=log_dir)
model.learn(total_timesteps=20000)
model.save("ddpg_panda_reach_v2")
# TD3
model = TD3(policy="MultiInputPolicy", env=env, buffer_size=100000, replay_buffer_class=HerReplayBuffer, verbose=1, tensorboard_log=log_dir)
model.learn(total_timesteps=20000)
model.save("td3_panda_reach_v2")
# SAC
model = SAC(policy="MultiInputPolicy", env=env, buffer_size=100000, replay_buffer_class=HerReplayBuffer, verbose=1, tensorboard_log=log_dir)
model.learn(total_timesteps=20000)
model.save("sac_panda_reach_v2")

运行上面代码,又遇到第一遍的问题
机械臂强化学习实战(stable baselines3+panda-gym)
问题是protobuf版本太高了

TypeError: Descriptors cannot not be created directly

谷歌查到了解决办法:https://discuss.streamlit.io/t/typeerror-descriptors-cannot-not-be-created-directly/25639/8
简而言之,终端运行下面两个命令

pip uninstall protobuf
pip install protobuf~=3.19.0

机械臂强化学习实战(stable baselines3+panda-gym)
重新运行代码,大概10分钟(这里我也不知道为啥用了cpu而不是cudn,我这台电脑有个3070卡,不过电脑的cpu配置也还行)
机械臂强化学习实战(stable baselines3+panda-gym)

4.查看训练过程

运行完毕后,终端输入命令

tensorboard --logdir panda_reach_v2_tensorboard  

这里如果有如下报错:tensorboard : 无法将“tensorboard”项识别为 cmdlet、函数、脚本文件或可运行程序的名称。可以参看这篇博客:https://blog.csdn.net/qq_47997583/article/details/125028675
这里没问题的话出现下面的结果:
机械臂强化学习实战(stable baselines3+panda-gym)
在浏览器粘贴链接
机械臂强化学习实战(stable baselines3+panda-gym)
从训练结果看在reach任务中,从训练效率DDPG和TD3接近,SAC稍慢,但是成功率都很快到了100%。后面也包括train过程中参数的变化值。

4.查看训练出的模型的真实效果

代码参考自岳小飞

import gym
from stable_baselines3 import DDPG, TD3, SAC, HerReplayBuffer

env = gym.make("PandaReach-v2", render=True)
model = DDPG.load('ddpg_panda_reach_v2', env=env)
# model = TD3.load('td3_panda_reach_v2', env=env)
# model = SAC.load('sac_panda_reach_v2', env=env)

obs = env.reset()
for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        print('Done')
        obs = env.reset()

机械臂强化学习实战(stable baselines3+panda-gym)
视频效果见:https://www.bilibili.com/video/BV1hB4y1X7k7?spm_id_from=333.999.0.0

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
上一篇 2022年5月31日 上午11:24
下一篇 2022年5月31日 上午11:27

相关推荐

本站注重文章个人版权,不会主动收集付费或者带有商业版权的文章,如果出现侵权情况只可能是作者后期更改了版权声明,如果出现这种情况请主动联系我们,我们看到会在第一时间删除!本站专注于人工智能高质量优质文章收集,方便各位学者快速找到学习资源,本站收集的文章都会附上文章出处,如果不愿意分享到本平台,我们会第一时间删除!