离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

【更新日志】

离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

论文信息:Scott Fujimoto, Shixiang Shane Gu: “A Minimalist Approach to Offline Reinforcement Learning”, 2021;arXiv:2106.06860.

本文是Google Brain团队和McGill大学合作,由TD3、BCQ的作者Fujimoto提出并发表在NeurIPS2020 顶会上的文章,本文方法最大的优点是:方法简单、无任何复杂数学公式、可实现性强(开源)、对比实验非常充分(满分推荐),正如标题一样(A minimalist approach)。

摘要: 相比于几篇博客讲过的BCQ(通过扰动网络生成动作,不断将学习策略和行为策略拉进)、BEAR(通过支撑集匹配避免分布匹配的问题)、BRAC(通过VP和PR两个方法正则化)以及REM(通过随机集成混合方法对多个值函数求取凸优化最优的鲁棒性)方法。本文作者提出的TD3+BC方法,结构简单,仅在值函数上添加一个行为克隆(BC)的正则项,并对state进行normalizing,简单的对TD3修改了几行代码就可以与前几种方法相媲美,结果表明:TD3+BC效果好,训练时间也比其他少很多。

1. Offline RL的一些挑战。

  • 实现和Tune的复杂性(Implementation and Tuning Complexities), 在强化学习中,算法的实现、论文的复现都是一个非常难的问题,很多算法并没法去复现,即使相同的seed有时候未必也能达到效果。同样在Offline中仍然存在,此外在Offline中还要解决分布偏移、OODd等之外的一些问题。
  • 额外算力需求(Extra Computation Requirement),由于过于复杂的数学优化、过多的超参数等算法的执行带来了很长的训练时间,导致不得不增加计算资源来学习算法使得其收敛。
  • 训练策略的不稳定性(Instability of Trained Policies),强化学习领域的不稳定性众所周知,所以Offline RL如何才能与Supervised leanring一样很稳定是一个重要的研究问题。
  • Offline RL改进问题(algorithmic/Coding/Optimization),包括了代码层次的优化改进和理论结构方面的改进等。

其实上述的这些问题并不是去解决offline RL中的一些诸如分布偏移、OOD、过估计以及等等这些问题,而是去解决如何简单、快速、高效的实现算法的实现与高效运行问题,因此作者面对这些问题,发出疑问并给出方法:
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

2. TD3+BC原理

2.1 TD3+BC相比于其他的优势

下图是TD3+BC算法相对于CQL、Fish-BRC算法的复杂性对比,从表中我们可以看到CQL和Fish-BRC在算法(algorithmic)上有了很多的变种,使用生成网络,近似离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)等,而TD3+BC仅仅添加了一个BC term和Normalized state,足够的简单。
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

2.2 理论部分

对于经典的DDPG、TD3等算法来讲, 策略梯度的计算根据David sliver提出的如下定义,即求解状态-动作值函数的期望值。

离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)
在本文中,作者为了使这两个动作尽可能接近,增加了一个正则项离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

个人看法: 有点像BCQ中的让学习策略和行为策略之间的距离减少那种意思,只不过添加到正则项里面.
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

另外一个技术点就是从代码执行层面的优化,即Normalize State,具体的Normalize过程如公式所示:

离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

其中的离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)表示一个normalization常量,作者在文中使用了离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇),离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)表示期望和标准差(standard deviation)。

实验效果(关于纵坐标Percent difference后文有说明,本部分只看效果)
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

最后一个技术点就是关于离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)的求解,作者给出了计算公式,并在后文中说取值为离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)的时候效果最好, 实验部分有作者做的ablation实验证明。

离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

最后贴出作者在TD3代码上的改动部分==》TD3+BC算法实现

离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

2.3 经典的Rebuttal场面

此外,我们看一下作者如何rebuttle这些OpenReview提出的审稿意见[1],[2]

离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

其实这部分蛮有意思的,我们发现大多数普通人的工作还是集中在对算法的小部分优化(数学大佬和代码大神略过),这里作者教你手把手给审稿人回复(建议收藏,热别是第2条)

审稿人:
(1)首先,该方法的新颖性似乎有点有限。作者似乎直接使 RL+BC 适应离线设置,只是他们添加了状态归一化,这也不是新的。作者也没有从理论上证明这种方法的合理性。例如,作者应该证明该方法可以保证安全的策略改进,并且享有可比或更好的策略改进保证 w.r.t.先前的方法。如果没有理论依据,并且考虑到该方法的当前形式,我认为该方法有点增量。
(2)此外,实证评估并不彻底。作者仅在 D4RL 中的简单 mujoco 环境中评估了该方法。目前尚不清楚该方法是否可以很好地执行更多无向多任务数据集,例如蚂蚁迷宫和厨房,以及更复杂的操作任务,例如 D4RL 中的 adroit。似乎该方法在随机数据集上表现不佳。这是一个主要限制吗?我还认为作者应该将状态归一化添加到所有基线以确保公平比较,因为状态归一化不是 RL 中的新技术。
(3)最后,我认为比较不完整。作者还应该将该方法与最近的无模型离线 RL 方法(如 [1])和基于模型的方法(如 [2,3])进行比较,后者在随机和中等重放数据集上获得了更好的性能。
总的来说,鉴于上述评论,我会投票支持弱拒绝。

下面就来看看作者神奇而巧妙的回复吧

作者回复:
(1)关于新颖性:我们完全不同意我们的算法在新颖性方面是递增的(我们在相关工作中强调了许多类似的算法)。然而,我们的主要主张/贡献与其说这是最好的离线 RL 算法,或者说它特别新颖,不如说是令人惊讶的观察,即使用非常简单的技术可以匹配/优于当前算法。希望 TD3+BC 可以用作易于实现的基线或其他添加(例如 S4RL)的起点,同时消除更复杂方法所需的许多不必要的复杂性、超参数调整或计算成本.
(2)关于经验评估:据我们所知,我们最强的基线 Fisher-BRC 被认为是无模型算法的 SOTA,最近在 ICML 上发表。
(3)由于 D4RL 结果的标准化,我们可以直接与建议的基线进行比较(我们会将这些结果包含在最终草案中)。我们在下面报告这些,但我们想说明两点:
(4)MOReL 和 MOPO 来自不同的方法系列(基于模型),并且都使用特定于环境的超参数。
S4RL 与我们的方法相切,只需将 CQL 替换为 TD3+BC,就可以很容易地将其添加到我们的方法中。我们的方法可以说更适合基础算法的这些类型的添加,因为超参数更少,这意味着我们不必担心变化之间的交互作用。
最终,我们没有发现添加状态归一化可以为基线提供相同水平的好处,可能是因为这些方法需要超参数调整来补偿额外的修改。

很有趣,学习收藏吧!

3. 实验及过程分析

3.1 实验超参数

这部分是笔者实验的基础,相当良心。它特定于每个实验环境的版本号。
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

这部分具体说明作者良心:代码版发布
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

3.2 衡量指标:百分比差异(Percent Difference)

这部分公式是作者实验的参考基准计算方式,其中在博客也提出了关于差距百分比的疑问,特意查了了一下计算过程[3](备注,有的地方可能用了绝对值):
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

3.3 实验验证与结果简要分析

说明:关于D4RL数据集的组成、安装和解释请参考博文离线强化学习(Offline RL)系列2: (环境篇)D4RL数据集简介、安装及错误解决

本实验参数
HC = HalfCheetah, Hop = Hopper, W = Walker, r = random, m = medium, mr = medium-replay, me = medium-expert, e = expert. While online algorithms (TD3) typically have small episode variances per trained policy (as they should at convergence),

3.3.1 D4RL验证讨论

离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

3.3.2 运行训练时间讨论

可以从实验结果中很直白的看到,CQL、FishBRC与TD3+BC(离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇))的运行时间, 其实这与算法的复杂性紧密相关,对于TD3来说只需要去根据超参数学习网络即可,但对于CQL等算法,需要学习一堆的参数。
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

3.3.3 消融(ablation)实验(如何确定离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)?)

这部分其实对比了vanillaBC方法和区别,同时就参数离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)做了对比得出了最好的离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)
离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

4. 代码实例分析

离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)

def train(self, replay_buffer, batch_size=256):
		self.total_it += 1

		# Sample replay buffer 
		state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

		with torch.no_grad():
			# Select action according to policy and add clipped noise
			noise = (
				torch.randn_like(action) * self.policy_noise
			).clamp(-self.noise_clip, self.noise_clip)
			
			next_action = (
				self.actor_target(next_state) + noise
			).clamp(-self.max_action, self.max_action)

			# Compute the target Q value
			target_Q1, target_Q2 = self.critic_target(next_state, next_action)
			target_Q = torch.min(target_Q1, target_Q2)
			target_Q = reward + not_done * self.discount * target_Q

		# Get current Q estimates
		current_Q1, current_Q2 = self.critic(state, action)

		# Compute critic loss
		critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

		# Optimize the critic
		self.critic_optimizer.zero_grad()
		critic_loss.backward()
		self.critic_optimizer.step()

		# Delayed policy updates
		if self.total_it % self.policy_freq == 0:

			# Compute actor loss
			pi = self.actor(state)
			Q = self.critic.Q1(state, pi)
			lmbda = self.alpha/Q.abs().mean().detach()

			actor_loss = -lmbda * Q.mean() + F.mse_loss(pi, action) 
			
			# Optimize the actor 
			self.actor_optimizer.zero_grad()
			actor_loss.backward()
			self.actor_optimizer.step()

			# Update the frozen target models
			for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
				target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

			for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
				target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)


`

```python=
def eval_policy(policy, env_name, seed, mean, std, seed_offset=100, eval_episodes=10):
	eval_env = gym.make(env_name)
	eval_env.seed(seed + seed_offset)

	avg_reward = 0.
	for _ in range(eval_episodes):
		state, done = eval_env.reset(), False
		while not done:
			state = (np.array(state).reshape(1,-1) - mean)/std
			action = policy.select_action(state)
			state, reward, done, _ = eval_env.step(action)
			avg_reward += reward

	avg_reward /= eval_episodes
	d4rl_score = eval_env.get_normalized_score(avg_reward) * 100

	print("---------------------------------------")
	print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}, D4RL score: {d4rl_score:.3f}")
	print("---------------------------------------")
	return d4rl_score

参考

[1]。Scott Fujimoto, Shixiang Shane Gu: “A Minimalist Approach to Offline Reinforcement Learning”, 2021;arXiv:2106.06860.
[2]. A Minimalist Approach to Offline Reinforcement Learning,OpenReview
[3]. percent-difference,percent-difference

OfflineRL推荐阅读

离线强化学习(Offline RL)系列3: (算法篇) REM(Random Ensemble Mixture)算法详解与实现
离线强化学习(Offline RL)系列3: (算法篇)策略约束 – BRAC算法原理详解与实现(经验篇)
离线强化学习(Offline RL)系列3: (算法篇)策略约束 – BEAR算法原理详解与实现
离线强化学习(Offline RL)系列3: (算法篇)策略约束 – BCQ算法详解与实现
离线强化学习(Offline RL)系列2: (环境篇)D4RL数据集简介、安装及错误解决
离线强化学习(Offline RL)系列1:离线强化学习原理入门

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2022年4月12日
下一篇 2022年4月12日