强化学习环境升级 – 从gym到Gymnasium
作为强化学习最常用的工具,gym一直在不停地升级和折腾,比如gym[atari]变成需要要安装接受协议的包啦,atari环境不支持Windows环境啦之类的,另外比较大的变化就是2021年接口从gym库变成了gymnasium库。让大量的讲强化学习的书中介绍环境的部分变得需要跟进升级了。
不过,不管如何变,gym[nasium]作为强化学习的代理库的总的设计思想没有变化,变的都是接口的细节。
step和观察结果
总体来说,对于gymnasium我们只需要做两件事情:一个是初始化环境,另一个就是通过step函数不停地给环境做输入,然后观察对应的结果。
初始化环境分为两步。
第一步是创建gymnasium工厂中所支持的子环境,比如我们使用经典的让一个杆子不倒的CartPole环境:
import gymnasium as gym
env = gym.make("CartPole-v1")
第二步,我们就可以通过env的reset函数来进行环境的初始化:
observation, info = env.reset(seed=42)
我们可以将observation打印出来,它一个4元组,4个数值分别表示:
- 小车位置
- 小车速度
- 棍的倾斜角度
- 棍的角速度
如果角度大于12度,或者小车位置超出了2.4,就意味着失败了,直接结束。
小车的输入就是一个力,要么是向左的力,要么是向右的力。0是向左推小车,1是向右推小车。
下面我们让代码跑起来。
首先我们通过pip来安装gymnasium的包:
pip install gymnasium -U
安装成功之后,
import gymnasium as gym
env = gym.make("CartPole-v1")
print(env.action_space)
observation, info = env.reset(seed=42)
steps = 0
for _ in range(1000):
action = env.action_space.sample()
observation, reward, terminated, truncated, info = env.step(action)
print(observation)
if terminated or truncated:
print("Episode finished after {} steps".format(steps))
observation, info = env.reset()
steps = 0
else:
steps += 1
env.close()
env.action_space输出是Discrete(2)。也就是两个离散的值0和1。前面我们介绍了,这分别代表向左和向右推动小车。
observation输出的4元组,我们前面也讲过了,像这样:
[ 0.0273956 -0.00611216 0.03585979 0.0197368 ]
下面就是关键的step一步:
action = env.action_space.sample()
observation, reward, terminated, truncated, info = env.step(action)
刚才我们介绍了,CartPole的输入只有0和1两个值。我们采用随机让其左右动的方式来试图让小车不倒。
如果你觉得还是不容易懂的话,我们可以来个更无脑的,管它是什么情况,我们都一直往左推:
observation, reward, terminated, truncated, info = env.step(0)
基本上几步就完了:
[ 0.02699083 -0.16518621 -0.00058549 0.3023946 ] 1.0 False False {}
[ 0.0236871 -0.36029983 0.0054624 0.5948928 ] 1.0 False False {}
[ 0.01648111 -0.5554978 0.01736026 0.88929135] 1.0 False False {}
[ 0.00537115 -0.750851 0.03514608 1.1873806 ] 1.0 False False {}
[-0.00964587 -0.94641054 0.0588937 1.4908696 ] 1.0 False False {}
[-0.02857408 -1.1421978 0.08871109 1.8013463 ] 1.0 False False {}
[-0.05141804 -1.3381925 0.12473802 2.1202288 ] 1.0 False False {}
[-0.07818189 -1.534317 0.16714258 2.4487078 ] 1.0 False False {}
[-0.10886823 -1.7304213 0.21611674 2.7876763 ] 1.0 True False {}
Episode finished after 8 steps
下面我们解释下返回的5元组,observation就是位置4元组,reward是用于强化学习的奖励,在本例中只要是不死就是1. terminated就是是否游戏结束了。
Truncated在官方定义中用于处理比如超时等特殊结束的情况。
truncated, info对于CartPole来说没有用到。
搭建好了gymnasium环境之后,我们就可以进行策略的升级与迭代了。
比如我们写死一个策略,如果位置小于0则向右推,反之则向左推:
def action_pos(status):
pos, v, ang, va = status
#print(status)
if pos <= 0:
return 1
else:
return 0
或者我们根据角度来判断,如果角度大于0则左推,反之则右推:
def action_angle(status):
pos, v, ang, va = status
#print(status)
if ang > 0:
return 1
else:
return 0
角度策略的完整代码如下:
import gymnasium as gym
env = gym.make("CartPole-v1")
#env = gym.make("CartPole-v1",render_mode="human")
print(env.action_space)
#print(env.get_action_meanings())
observation, info = env.reset(seed=42)
print(observation,info)
def action_pos(status):
pos, v, ang, va = status
#print(status)
if pos <= 0:
return 1
else:
return 0
def action_angle(status):
pos, v, ang, va = status
#print(status)
if ang > 0:
return 1
else:
return 0
steps = 0
for _ in range(1000):
action = env.action_space.sample()
observation, reward, terminated, truncated, info = env.step(action_angle(observation))
print(observation, reward, terminated, truncated, info)
if terminated or truncated:
print("Episode finished after {} steps".format(steps))
observation, info = env.reset()
steps = 0
else:
steps += 1
env.close()
与老gym的主要区别
目前版本与之前gym的最主要区别在于step返回值从原来的4元组变成了5元组。
原来是observation, reward, done, info,而现在done变成了 terminated增加了truncated。
老版本的:
status, reward, done, info = env.step(0)
新版的:
observation, reward, terminated, truncated, info = env.step(0)
Atari游戏
我们通过gymnasium[atari]包来安装atari游戏的gymnasium支持。
pip install gymnasium[atari]
通过get_action_meanings来获取游戏支持的操作
之前的CartPole只知道是离散的两个值。而Atari游戏则可支持获取游戏支持的操作的含义:
['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']
rendor_mode
另外,针对于Atari游戏,render_mode现在是必选项了。要指定是显示成人类可看的human模式,还是只输出rgb_array的模式。
完整例子
我们以乒乓球游戏为例,组装让其运行起来:
import gymnasium as gym
env = gym.make("ALE/Pong-v5", render_mode="human")
observation, info = env.reset()
print(env.get_action_meanings())
scores = 0
for _ in range(1000):
action = env.action_space.sample() # agent policy that uses the observation and info
observation, reward, terminated, truncated, info = env.step(action)
#print(observation, reward, terminated, truncated, info)
if terminated or truncated:
print("Episode finished after {} steps".format(scores))
observation, info = env.reset()
scores = 0
else:
scores +=1
env.close()
完整的游戏支持列表可以在https://gymnasium.farama.org/environments/atari/ 官方文档中查到。
文章出处登录后可见!