Dyna-Q算法的理论基础及其代码实践【CliffWalking-v0】

2023-11-11

Dyna-Q 理论基础

强化学习中,“模型”通常指与智能体交互的环境模型,即对环境的状态转移概率和奖励函数进行建模。根据是否具有环境模型,强化学习算法分为两种:

  • 基于模型的强化学习(model-based):无模型的强化学习根据智能体与环境交互采样到的数据直接进行策略提升或者价值估计,比如 Sarsa 和 Q-learning 算法,便是两种无模型的强化学习方法。
  • 无模型的强化学习(model-free):在基于模型的强化学习中,模型可以是事先知道的,也可以是根据智能体与环境交互采样到的数据学习得到的,然后用这个模型帮助策略提升或者价值估计。动态规划算法中的策略迭代和价值迭代,则是基于模型的强化学习方法,在这两种算法中环境模型是事先已知的。

本文的主角 Dyna-Q 算法也是非常基础的基于模型的强化学习算法,不过它的环境模型是通过采样数据估计得到的。

强化学习算法有两个重要的评价指标:

  • 一个是算法收敛后的策略在初始状态下的期望回报
  • 另一个是样本复杂度,即算法达到收敛结果需要在真实环境中采样的样本数量。

基于模型的强化学习算法由于具有一个环境模型,智能体可以额外和环境模型进行交互,对真实环境中样本的需求量往往就会减少,因此通常会比无模型的强化学习算法具有更低的样本复杂度。但是,环境模型可能并不准确,不能完全代替真实环境,因此基于模型的强化学习算法收敛后其策略的期望回报可能不如无模型的强化学习算法。

说到这里,我们思考如下两个问题:

  • 像之前我们讨论的大量强化学习方法(DQN, Double DQN, 等等)都是基于model-free的,这也是RL学习的主要优势之一,因为大部分情况下智能体所处的环境会非常复杂,很难获得一个确定的模型。但是如果现在有一个已知模型的环境,该如何利用这个环境来加快智能体的学习进程呢?
  • 由于不可能精确和完美的拟合真正环境,纯基于模型的强化学习效果往往很差。那有没有什么办法可以在一定程度上避免这一点呢?

Dyna-Q选手给出了它的答案,把基于模型 + 不基于模型的强化学习结合起来,它既在模型中学习,也在交互中学习。下面就来看看Dyna-Q算法的一些思想吧。

Dyna-Q 算法是一个经典的基于模型的强化学习算法。Dyna-Q 使用一种叫做 Q-planning 的方法来基于模型生成一些模拟数据,然后用模拟数据和真实数据一起改进策略。Q-planning 每次选取一个曾经访问过的状态 s s s,采取一个曾经在该状态下执行过的动作 a a a,通过模型得到转移后的状态 s ′ s' s以及奖励 r r r,并根据这个模拟数据 ( s , a , r , s ′ ) (s,a,r,s') (s,a,r,s),用 Q-learning 的更新方式来更新动作价值函数。

Dyna-Q算法简单明了,在算法的前面和普通的Q-learning算法一模一样,只有后面有所不同,见下面的Dyna-Q算法流程。下面的红框中的步骤①的前提是环境的模型是基于确定环境下的假设(对于非确定的环境或者是非常复杂的环境根据特定的情况来做假设),后面的步骤可以被概括为使用已经学习到的模型来更新Q函数 n n n次。最后的Q函数更新和前面的一模一样。此外,在Dyna-Q中同样的强化学习方法既可以用于从实际经验中学习也可以用于从模拟经验中进行规划,因此该强化学习方法是学习和规划的最终共同道路
image.png

可以看到,在每次与环境进行交互执行一次 Q-learning 之后,Dyna-Q 会做 n n n次 Q-planning。其中 Q-planning 的次数 N N N是一个事先可以选择的超参数,当其为 0 时就是普通的 Q-learning。值得注意的是,上述 Dyna-Q 算法是执行在一个离散并且确定的环境中,所以当看到一条经验数据 ( s , a , r , s ′ ) (s,a,r,s') (s,a,r,s)时,可以直接对模型做出更新,即 M ( s , a ) ← r , s ′ M(s, a) \leftarrow r, s^{\prime} M(s,a)r,s

image.png
上图为Dyna-Q的结构。我们不难发现有两个向上的箭头指向 Policy/value functions,也就是我们这篇文章中说的Q function,左边的箭头是RL直接从实际的经验对Q进行更新,右边的更新Q箭头是从模拟经验进行的规划更新。由此可见,每当agent采取一个action时,学习的进程同时通过实际的选择的action和环境模型的模拟来更新,这样就能够加快我们智能体的学习速度。

Dyna-Q 代码实践

代码实现如下:(其实和QL差不多,只是多了一点步骤)

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import gym


class DynaQ:
    def __init__(self, ncol, nrow, epsilon, alpha, gamma, n_planning, n_action=4) -> None:
        self.q_table = np.zeros([nrow * ncol, n_action])  # 初始化Q(s,a)表格
        self.n_action = n_action
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.n_planning = n_planning
        self.model = dict()  # 环境模型

    def select_actions(self, state):
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.n_action)
        else:
            action = np.argmax(self.q_table[state])
        return action

    def QL_update(self, s, a, r, s_):
        '''Q-learning算法更新步骤'''
        td_error = r + self.gamma * self.q_table[s_].max() - self.q_table[s, a]
        self.q_table[s, a] += self.alpha * td_error

    def update(self, s, a, r, s_):
        self.QL_update(s, a, r, s_)
        ################################################################
        self.model[(s, a)] = r, s_  # 将数据添加到模型中
        for _ in range(self.n_planning):   # Q-planning循环
            # 随机选择曾经遇到过的状态动作对
            (s, a), (r, s_) = random.choice(list(self.model.items()))
            self.QL_update(s, a, r, s_)
        ################################################################

    def DynaQ_CliffWalking_running(self, num_episodes):
        return_list = []
        for i in range(10):  # 显示10个进度条
            with tqdm(total=num_episodes//10, desc=f"Iteration {i}") as pbar:
                for ep in range(num_episodes//10):  # 每个进度条的序列数
                    episode_return = 0
                    state = env.reset()
                    done = False
                    while not done:
                        action = agent.select_actions(state)
                        next_state, reward, done, _ = env.step(action)
                        episode_return += reward
                        agent.update(state, action, reward, next_state)
                        state = next_state

                    return_list.append(episode_return)
                    if (ep + 1) % 10 == 0:
                        pbar.set_postfix({
                            "episode": f"{i / 10 * i + ep + 1}",
                            "return": f"{np.mean(return_list[-10:])}"
                        })
                    pbar.update(1)
        return return_list


def moving_average(a, window_size):
    """滑动平均"""
    cumulative_sum = np.cumsum(np.insert(a, 0, 0))
    middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
    r = np.arange(1, window_size - 1, 2)
    begin = np.cumsum(a[:window_size - 1])[::2] / r
    end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
    return np.concatenate((begin, middle, end))


if __name__ == '__main__':
    env = gym.make("CliffWalking-v0")
    n_row, n_col = env.shape
    epsilon = 0.01
    alpha = 0.1
    gamma = 0.9
    num_episodes = 300

    n_planning_list = [0, 2, 20]

    for n_planning in n_planning_list:
        print('Q-planning步数为:%d' % n_planning)
        agent = DynaQ(n_col, n_row, epsilon, alpha, gamma, n_planning)
        return_list = agent.DynaQ_CliffWalking_running(num_episodes)
        episodes_list = list(range(len(return_list)))
        episodes_list = moving_average(episodes_list, 19)
        plt.plot(episodes_list, return_list, label=str(
            n_planning) + ' planning steps')

    plt.legend()
    plt.ylim(-250, 0)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('Dyna-Q on {}'.format('Cliff Walking'))
    plt.show()

代码运行结果如下:
image.png
从上述结果中我们可以很容易地看出,随着 Q-planning 步数的增多,Dyna-Q 算法的收敛速度也随之变快。当然,并不是在所有的环境中,都是 Q-planning 步数越大则算法收敛越快,这取决于环境是否是确定性的,以及环境模型的精度。在上述悬崖漫步环境中,状态的转移是完全确定性的,构建的环境模型的精度是最高的,所以可以通过增加 Q-planning 步数来直接降低算法的样本复杂度。

\quad
\quad
\quad


参考:


持续更新…

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Dyna-Q算法的理论基础及其代码实践【CliffWalking-v0】 的相关文章

随机推荐

  • 软件开发是一门艺术还是工程

    软件开发是一门艺术 艺术是没有具体形象的 一名艺术家必须要擅长创新 工程则是循规蹈矩的 一名工程师则必须要守规矩 而软件开发之所以可以称为一门艺术而不是工程师因为软件要满足用户的需求并不是循规蹈矩的 不同的软件开发者对做一个相同的软 件开发
  • 制作ubuntu server启动盘

    Mac 查看磁盘列表 gt gt gt diskutil list 格式化磁盘 gt gt gt diskutil partitionDisk dev disk2 MBR FAT32 UNTITLED 0b 推出磁盘 gt gt gt di
  • 练习--输出一个7行的菱形

    练习 输出一个7行的菱形 对于菱形的输出 在编程的时候需要注意空格的输出和 号的输出 define CRT SECURE NO WARNINGS 1 include
  • Redis简介以及和其他缓存数据库的区别

    转载 https blog csdn net xlgen157387 article details 60761232 Redis简介 Redis 是一个开源的内存中的数据结构存储系统 它可以用作数据库 缓存和消息中间件 它支持多种类型的数
  • python二维数组切片规则_详解Python二维数组与三维数组切片的方法

    如果对象是二维数组 则切片应当是x 的形式 里面有一个冒号 冒号之前和之后分别表示对象的第0个维度和第1个维度 如果对象是三维数组 则切片应当是x 里面有两个冒号 分割出三个间隔 三个间隔的前 中和后分别表示对象的第0 1 2个维度 x n
  • HTML中让表单input等文本框为只读不可编辑但可以获取value值的方法;让文本域前面的内容显示在左上角,居中...

    HTML中让表单input等文本框为只读不可编辑的方法 有时候 我们希望表单中的文本框是只读的 让用户不能修改其中的信息 如使input text的内容 中国两个字不可以修改 有时候 我们希望表单中的文本框是只读的 让用户不能修改其中的信息
  • 预加重、去加重和均衡总结

    1 定义 由于在信号通路中 相对于低频分量 信号的高频分量有很大的衰减 均衡的作用就是在接收端口对信号处理 根据信号经过的基板的衰减特性 将信号的高频成分适当增强 这样就可以得到低频成分与高频成分被 均衡 到一个水平的信号 增强了发送到接收
  • linux 可视化分区,可视化linux块设备的工具(分区,LVM PV,LV,mdadm设备……)

    我正在寻找一种能够扫描我的无GUI服务器的工具 并以一些丰富的可视化格式创建一个易于理解的所有块设备及其关系 磁盘分区 mdadm设备 LVM PV和LV等 的粗略概述 html pdf svg png 这是一个简单的示例可视化 sda1
  • Angular和RxJS的一些应用场景

    2019独角兽企业重金招聘Python工程师标准 gt gt gt 最近一直看有关rxjs的东西 想学会响应式编程思想 但这种东西没点实操根本不能融会贯通 现在只能借鉴别人的东西分析一下 先上两篇帖子都是关于rxjs在NG上的实际应用 使用
  • 解决CuDNN runtime版本和编译版本不同的问题

    在编译安装好TensorFlow后 可下载示例代码运行 但在执行run all sh时 出现如下错误 该错误意思就是CuDNN的runtime版本和编译时指定的版本不同 2018 05 08 09 00 18 042137 E tensor
  • linux 返回上一级目录 和 返回根目录

    返回上一级目录 cd 返回根目录 cd
  • DRM几个重要的结构体及panel开发

    一 DRM Linux下的DRM框架内容众多 结构复杂 本文将简单介绍下开发过程中用到的几个结构体 这几个结构体都在之前文章里面开发DRM驱动时用到的 未用到的暂不介绍 DRM中的KMS包含Framebuffer CRTC ENCODER
  • 机器智能的未来

    ChatGPT丨小智ai丨chatgpt丨人工智能丨OpenAI丨聊天机器人丨AI语音助手丨GPT 3 5丨OpenAI ChatGPT GPT 4 GPT 3 人机对话 ChatGPT应用 小智ai 小智ai 小智ai 小智ai 小智AI
  • MySql使用全记录4 -----设置root口令(即修改默认口令)

    设置MySql的root用户口令 本文由CSDN 蚍蜉撼青松 主页 http blog csdn net howeverpf 整理编辑 转载请注明出处 参考链接 http wenku baidu com view 73ab05737fd53
  • html取出单元格中的数值_简单爬取html页面的表格中的数据

    关于爬虫方面本人小白一个 通过无所不能的度娘 从中汲取营养 得到一个简单的能用的例子 在这分享一下 供大家一起汲取 首先说一下 你想从一个页面中获取到你想要的数据 首先你要先得到这个页面 然后把获取到的页面 使用Jsoup解析成 Docum
  • 如何使用挂载磁盘和windows服务器进行文件传输?

    如何远程连接windows服务器 相信对于使用过windows服务器的朋友来说这都是非常简单的事情 但是对于如何以及为什么挂载本地磁盘到windows服务器 很多新手就不明白为什么了 那么今天行云管家赵博士就来教大家怎样将本地磁盘挂载到到w
  • Windows10下配置Jmeter环境变量

    安装之后配置环境变量的步骤如下 1 点 此电脑 右键选 属性 2 选择 高级系统设置 环境变量 如下图 3 新建环境变量JMETER HOME 如下截图 4 点击确定之后 编辑 CLASSPATH 的变量 在变量值最后追加内容 JMETER
  • 你要的住宅地产行业数据化解决方案来啦!

    传统标准化复制品和服务越来越难以应付市场需求与行业竞争格局的改变 众多房地产企业寻求数字化转型 在转型过程中 会遇到各种各样的挑战 而一套合适的住宅地产行业数据化解决方案会解决很多难题 助力房企顺利实现转型 我推荐帆软住宅地产行业数据化解决
  • 记一次JAVA自定义@interface中方法定义诡异问题

    诡异问题描述 使用IDEA工具 正常不报错但是执行mvn install的时候 出现了大量的方法和属性不存在提示错误 实际上都要是存在 但无论如何都编译不通过 这种场景有点类似于在一个类中少了个大括号 然后真个类报错的那种感觉 问题查找 排
  • Dyna-Q算法的理论基础及其代码实践【CliffWalking-v0】

    Dyna Q 理论基础 强化学习中 模型 通常指与智能体交互的环境模型 即对环境的状态转移概率和奖励函数进行建模 根据是否具有环境模型 强化学习算法分为两种 基于模型的强化学习 model based 无模型的强化学习根据智能体与环境交互采