离线强化学习初调研

起源

传统强化学习是在线(online)过程,通过智能体和环境不停交互获取数据来评估和改进策略。但往往实时交互成本较高,还存在安全风险(如自动驾驶等),并且需要消耗较长的训练改进策略时间。
因此,离线强化学习就是利用已有数据集(历史数据集),尽可能达到和传统在线强化学习的效果。
在某些问题上,离线强化学习通常趋于保守,如自动驾驶中,数据集通常是次优数据集,学习策略在探索未知上明显概率更低,而更倾向于数据集内的动作,数据集可以认定为次优数据集,也就是相对保证了安全。因此在高速度行驶、变道等动作上的选择趋于保守。尽管如此,在自动驾驶等场景,仍需要添加专门的策略去确保安全,而不是像游戏一样允许试错。

概念对比

Online 和 offline 是相对的,指是否和环境交互。On-policy 和 off-policy 则都是 online 下的两种情况,区别是算法是否利用同一个策略来评价自己,通俗理解就是亲自下棋和看别人下棋。Offline 必须使用 off-policy 的数据,因为 offline 数据集显然不是其策略采集的。

Model-based 和 model-free 的最大区别就是是否需要知道状态间转移概率,前者关心 MDP 四元组,能够对下一步回报和状态做出预测,或者说有对环境建模(知己知彼,百战不殆),而后者只关心 Q 函数(一心只读圣贤书)。

深度学习中的损失函数的目的是使预测值和真实值之间的差距尽可能小,而强化学习中的损失函数的目的是使总奖励的期望尽可能大。

核心问题

分布偏移,即数据集不能覆盖所有的(s,a)情况,而 Q 函数对这种数据集外情况的估计往往是不精确的,因此可能选择到实际收益很低的动作,导致整体效果很差。
离线强化学习本质上也是权衡,要基于数据集学习一个优于行为策略的策略,同时也要最小化和行为策略的误差,避免分布偏移。

数据集

开源基准数据集

D4RL: A collection of reference environments for offline reinforcement learning

非必要或者找不痛快请勿在 Windows 上尝试配置环境运行训练。

相关研究

核心

探讨数据集特征对离线学习效果的影响。

五种数据集生成策略

两个评价指标

结论

核心

探讨样本复杂性,消除分布偏移,增大 SACo,应该如何采集样本,确定需要样本量。
样本复杂性,更多地关注算法性能关于数据集大小的敏感性。
个人意见,本文主要通过充分完备的实验证明了验证集对离线 RL 算法评估的重要意义,并没有显著的亮点。

评价指标

使用专家行动和策略行动之间的均方误差(MSE)来衡量行为者和专家之间的偏差。
使用 MSE 而不使用 KL 散度是因为大多数离线 RL 算法如 BCQ,均是基于确定性策略的。

结论

算法改进

两大思路

离线学习 vs. 模仿学习

离线学习中通常包含了行为克隆(模仿学习的一种)思想。
模仿学习是针对专家数据集,目标就是尽可能去贴近专家策略,可以认为专家数据集已经是最优,只需要学习出和专家策略相似的策略即可。而离线强化学习则增加了 reward,也就是增大能够获得更多奖励的轨迹概率,通常是从次优的数据集中学习策略。

通用实现

CORL: Research-oriented Deep Offline Reinforcement Learning Library
八种离线强化学习算法,当下工作集中于已有算法的改进,核心都是解决数据分布偏移导致的 Q 错误高估计:

BCQ

Off-Policy Deep Reinforcement Learning without Exploration

改进

Offline Reinforcement Learning for Autonomous Driving with Safety and Exploration Enhancement
image.png

CQL

Conservative Q-Learning for Offline Reinforcement Learning
image.png

IQL

Offline Reinforcement Learning with Implicit Q-Learning
image.png

TD3+BC

A Minimalist Approach to Offline Reinforcement Learning
image.png

MCQ

Mildly Conservative Q-Learning for Offline Reinforcement Learning



正在加载今日诗词....

📌 Powered by Obsidian Digital Garden and Vercel
载入天数...载入时分秒...