写在前面一些无关紧要的话:印象中,这个专栏已经快五个月没更新过了,How time flies!当时本来应该把TD-learning这一块写完再停笔,但不知被什么事所打扰,遂忘却以致搁置到今日。至今仍然觉得写帖子不失为一种有效的学习方式,不仅方便他人浏览,而且每当自己忘记一些细节之时,重温起来亦很方便。故今天毅然决然写起了来岛国之后的第一篇技术帖。
言归正传,回到今天的主角时间差分法(TD-learning),这是一个大类,包含我们耳熟能详的Sarsa、Q-learning等,以及各种拓展变形:expected Sarsa、n-step Sarsa、double Q-learning等等。
Sarsa算法
Sarsa是典型的时间差分法,TD-learning结合了MC的sampling方法和DP的bootstrap方法,是空间复杂度和时间复杂度都最低的算法。与蒙特卡洛方法的相似之处在于二者均通过与环境交互得到的序列来估计值函数,不同之处在于蒙特卡洛方法在估计值函数时用了完整序列的长期回报,而TD法使用的是非完整序列的回报,对于一步TD法,则使用的是 使用的是当前回报和下一时刻的估计。
TD法对值函数更新的框架一般如下:
。其中,
被称作TD target项,
被称作TD error项。
不同的TD法主要体现在两点:一是行为策略和评估策略(改进策略)是否一致,即属于on-policy的控制问题还是off-policy的控制问题;二是TD target项不同。Sarsa算法的TD target项为
,动作值函数的更新公式为
。所以Sarsa算法的得名也很直接,当前状态(S)、当前动作(A)、下一状态(S')、下一动作(A')和回报(R)的英文首字母组合而成,伪码流程如下:
Sarsa伪码流程
Q-learning
Q-learning是典型的off-policy control问题,其行动策略为
策略,该策略保证算法的探索(exploration);其评估策略为贪婪策略,该策略保证算法的利用(exploitation)。Q-learning算法在选择下一个动作时,直接贪婪地选择Q值最大的其TD target项为:
。伪码如下:
Q-learning伪码流程
Sarsa与Q-learning的图示及代码
从直观的图形比较二者,
Sarsa图示
Q-learning图示
算法部分代码如下,完整代码在我github仓库中:https://github.com/zhengsizuo/RL_exercize/blob/master/TD_learning.py
def sarsa_eval(self):
"""使用Sarsa算法评估策略"""
for i in range(self.k):
episode = []
s = initial_state
a = self.policy_behavior(s)
while True:
episode.append((s, a))
s_, r, done = self.step(s, a)
a_ = self.policy_behavior(s_)
x, y = self.turple2id((s, a))
next_x, next_y = self.turple2id((s_, a_))
if s_ in self.terminal_states:
self.q_values[next_x][next_y] = 0
td_target = r + gamma*self.q_values[next_x][next_y]
self.q_values[x][y] += alpha * (td_target - self.q_values[x][y])
s = s_
a = a_
if done:
break
# 如果又回到之前经历过的状态
if (s, a) in episode:
break
def q_learning(self):
"""使用Q-learning算法对策略进行评估"""
for i in range(self.k):
episode = []
s = initial_state
while True:
a = self.policy_behavior(s)
if (s, a) in episode:
break
episode.append((s, a))
s_, r, done = self.step(s, a)
x, y = self.turple2id((s, a))
q_max = np.max(self.q_values[s_])
td_target = r + gamma * q_max
episode.append((s, a, td_target))
self.q_values[x][y] += alpha * (td_target - self.q_values[x][y])
s = s_
if done:
break
还有一个Udacity作业版本:
https://github.com/zhengsizuo/DRL_udacity/tree/master/Temporal_Differencegithub.com