Deep Ensemble Bootstrapped Q-Learning (Deep-EBQL)【代码复现】

2023-11-18

Deep-EBQL理论基础

原文链接:Ensemble Bootstrapping for Q-Learning

Deep-EBQL是EBQL的深度学习版本,也即是在DQN的基础上,引入集成的思想,解决DQN过估计的问题。深度版本的EBQL在Atari环境下有着非常好的表现。

EBQL的理论基础与代码复现可以看这篇文章:Ensemble Bootstrapping for Q-Learning(EBQL)【论文复现】

这是原文中在Atari环境下做的实验,可以看到EBQL算法确实有不俗的表现。
image.png

Deep-EBQL代码实现

下面介绍Deep-EBQL算法的代码实现,这个算法在DQN的基础上加以改进就可以了。环境是基于Pendulum-v0的,因为这个环境能很好的观察到DQN的过估计现象。

经验池

利用Deque数据结构实现一个经验池:

class ReplayBuffer:
    """经验回放池"""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)  # 队列,先进先出

    # 将数据加入buffer
    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    # 从buffer中采样数据,数量为batch_size
    def sample(self, batch_size):
        transitions = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done

    # 目前buffer中数据的数量
    def size(self):
        return len(self.buffer)

网络设置

本文使用的环境是Pendulum-v0,在这个环境中状态和动作维数很小,所以一层隐藏层就行了。

class Qnet(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(Qnet, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, s):
        s = self.layer(s)
        return s

Deep-EBQL实现

这里的思想和标准的EBQL一模一样,只是把Q表格换成了神经网络!

class Deep_EBQL:
    def __init__(self, args):
        self.args = args

        self.K = args.K  # 使用Qnet的个数
        self.hidden_dim = args.hidden_dim
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.gamma = args.gamma  # 折扣因子
        self.epsilon = args.epsilon  # epsilon-贪婪策略
        self.target_update = args.target_update  # 目标网络更新频率
        self.count = 0  # 计数器,记录更新次数
        self.num_episodes = args.num_episodes
        self.minimal_size = args.minimal_size
        self.env = gym.make(args.env_name)

        random.seed(args.seed)
        np.random.seed(args.seed)
        self.env.seed(args.seed)
        torch.manual_seed(args.seed)

        self.replay_buffer = ReplayBuffer(args.buffer_size)
        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = 11  # 将连续动作分成11个离散动作

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.q_net = []
        self.optimizer = []
        for i in range(self.K):
            self.q_net.append(Qnet(self.state_dim, self.hidden_dim, self.action_dim).to(self.device))
            self.optimizer.append(Adam(self.q_net[i].parameters(), lr=self.lr))
        self.target_q_net = copy.deepcopy(self.q_net)

    def select_action(self, state):  # epsilon-贪婪策略采取动作
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            state = torch.tensor([state], dtype=torch.float).to(self.device)

            action_Q_values = torch.zeros(state.shape[0], self.action_dim)
            for k in range(self.K):
                action_Q_values += self.q_net[k](state)
            action_Q_values = action_Q_values / self.K

            action = action_Q_values.argmax().item()
        return action

    def max_q_value(self, state):  # 为了显示算法的过估计现象
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        for i in range(self.K):
            return self.q_net[i](state).max().item()

    def update(self, transition):
        states = torch.tensor(transition["states"], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition["actions"]).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition["rewards"], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition["next_states"], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition["dones"], dtype=torch.float).view(-1, 1).to(self.device)

        ######################################################################
        kt = np.random.randint(0, self.K)
        q_values = self.q_net[kt](states).gather(1, actions)  # Q value

        max_next_q_values = torch.zeros(self.batch_size, 1).to(self.device)
        for k in range(self.K):
            if k != kt:
                max_next_q_values += self.target_q_net[k](next_states).max(1)[0].view(-1, 1)  # 下个状态的最大Q值
        max_next_q_values = max_next_q_values / (self.K - 1)
        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)  # TD error
        ######################################################################

        loss = torch.mean(F.mse_loss(q_values, q_targets))  # 均方误差损失函数
        self.optimizer[kt].zero_grad()  # PyTorch中默认梯度会累积,这里需要显式将梯度置为0
        loss.backward()  # 反向传播更新参数
        self.optimizer[kt].step()

        if self.count % self.target_update == 0:  # 更新目标网络
            for k in range(self.K):
                self.target_q_net[k].load_state_dict(self.q_net[k].state_dict())

        self.count += 1

参数设置

def define_args():
    parser = argparse.ArgumentParser(description='Deep EBQN parametes settings')

    parser.add_argument('--batch_size', type=int, default=64, metavar='N', help='batch size')
    parser.add_argument('--lr', type=float, default=1e-2, help='Learning rate for the net.')
    parser.add_argument('--num_episodes', type=int, default=200, help='the num of train epochs')
    parser.add_argument('--seed', type=int, default=0, metavar='S', help='Random seed.')

    parser.add_argument('--gamma', type=float, default=0.9, metavar='S', help='the discount rate')
    parser.add_argument('--epsilon', type=float, default=0.01, metavar='S', help='the epsilon rate')

    parser.add_argument('--K', type=int, default=5, metavar='S', help='the number of Qnet used to algorithm')

    parser.add_argument('--target_update', type=float, default=10, metavar='S', help='the frequency of the target net')
    parser.add_argument('--buffer_size', type=float, default=5000, metavar='S', help='the size of the buffer')
    parser.add_argument('--minimal_size', type=float, default=500, metavar='S', help='the minimal size of the learning')

    parser.add_argument('--hidden_size', type=float, default=128, metavar='S', help='the size of the hidden layer')
    parser.add_argument('--env_name', type=str, default="Pendulum-v0", metavar='S', help='the name of the environment')
    args = parser.parse_args()
    return args

代码运行结果

在这里插入图片描述

Deep-EBQL、Double-DQN和DQN的对比

不是说Deep-EBQL可以解决Q值过估计问题吗,那我们就拿DQN和Double-DQN来做一个比较,看一下Deep-EBQL对于Q值估计的效果。

其中,DQN与Douvble-DQN的理论基础与代码实现看如下链接:

在实际对比中,我们尽量使一样的参数保持相同,比如经验池大小之类参数。在Pendulum-v0环境下,运行200个episodes,得到的结果如下:

在这

在这里插入图片描述

可以看到,Deep-EBQL算法可以很好的解决Q值过估计问题。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Deep Ensemble Bootstrapped Q-Learning (Deep-EBQL)【代码复现】 的相关文章

  • 如何使 Django ManyToMany “直通”查询更加高效?

    我使用的是 ManyToManyField 和 through 类 这会在获取事物列表时产生大量查询 我想知道是否有更有效的方法 例如 这里有一些描述书籍及其几位作者的简化类 它们通过角色类 定义 编辑器 插画家 等角色 class Per
  • SQLAlchemy:检查给定值是否在列表中

    问题 在 PostgreSQL 中 检查某个字段是否在给定列表中是使用IN操作员 SELECT FROM stars WHERE star type IN Nova Planet SQLAlchemy 的等价物是什么INSQL查询 我尝试过
  • 如何充分释放函数中使用的GPU内存

    我在用着cupy在接收一个函数numpy数组 将其推到 GPU 上 对其进行一些操作并返回cp asnumpy它的副本 问题 函数执行后内存没有被释放 如ndidia smi 我知道内存的缓存和重用cupy 但是 这似乎仅适用于每个用户 当
  • __getitem__、__setitem__ 如何处理切片?

    我正在运行 Python 2 7 10 我需要拦截列表中的更改 我所说的 更改 是指在浅层意义上修改列表的任何内容 如果列表由相同顺序的相同对象组成 则列表不会更改 无论这些对象的状态如何 否则 它会更改 我不需要找出来how列表已经改变
  • 为什么我不能“string”.print()?

    我的理解print 在 Python 和 Ruby 以及其他语言 中 它是字符串 或其他类型 上的方法 因为它的语法非常常用 打印 嗨 works 那么为什么不呢 hi print 在 Python 中或 hi print在红宝石工作 当你
  • 使用 OpenCV 进行相机校准 - 如何调整棋盘方块大小?

    我正在使用 OpenCV Python 示例开发相机校准程序 来自 OpenCV 教程 http opencv python tutroals readthedocs io en latest py tutorials py calib3d
  • 什么时候用==,什么时候用is?

    奇怪的是 gt gt gt a 123 gt gt gt b 123 gt gt gt a is b True gt gt gt a 123 gt gt gt b 123 gt gt gt a is b False Seems a is b
  • 直接打开Spyder还是通过Pythonxy打开?

    之前 我一直在运行PythonSpyder 我总是开始Spyder直接双击其图标 今天突然发现我还有一个东西叫Python x y 我注意到我也可以开始Spyder通过它 这两种方法有什么区别吗 如果不是的话 有什么意义Python x y
  • 获取 HTML 代码的结构

    我正在使用 BeautifulSoup4 我很好奇是否有一个函数可以返回 HTML 代码的结构 有序标签 这是一个例子 h1 Simple example h1 p This is a simple example of html page
  • 如何将字符串方法应用于数据帧的多列

    我有一个包含多个字符串列的数据框 我想使用对数据帧的多列上的系列有效的字符串方法 我希望这样的事情 df pd DataFrame A 123f 456f B 789f 901f df Out 15 A B 0 123f 789f 1 45
  • Python `concurrent.futures`:根据完成顺序迭代 future

    我想要类似的东西executor map 除了当我迭代结果时 我想根据完成的顺序迭代它们 例如首先完成的工作项应该首先出现在迭代中 等等 这样 当且仅当序列中的每个工作项尚未完成时 迭代就会阻塞 我知道如何使用队列自己实现这一点 但我想知道
  • 在Python中确定句子中2个单词之间的邻近度

    我需要确定 Python 句子中两个单词之间的接近度 例如 在下面的句子中 the foo and the bar is foo bar 我想确定单词之间的距离foo and bar 确定之间出现的单词数foo and bar 请注意 该词
  • 如何将列表中的每个项目转换为字符串,以便连接它们? [复制]

    这个问题在这里已经有答案了 我需要加入一个项目列表 列表中的许多项目都是从函数返回的整数值 IE myList append munfunc 我应该如何将返回的结果转换为字符串以便将其加入列表 我是否需要对每个整数值执行以下操作 myLis
  • Python:使用for循环更改变量后缀

    我知道这个问题被问了很多 但到目前为止我无法使用 理解答案 我想改变for循环中变量的后缀 我尝试了 stackoverflow 搜索提供的所有答案 但很难理解提问者经常提出的具体代码 因此 为了清楚起见 我使用一个简单的示例 这并不意味着
  • 使用 plone.api 创建文件的 Python 脚本在设置文件时出现错误 WrongType

    Dears 我正在创建一个脚本python来在Plone站点中批量上传文件 安装是UnifiedInstaller Plone 4 3 10 该脚本读取了一个txt 并且该txt以分号分隔 在新创建的项目中设置文件时出现错误 下面是脚本 f
  • 如何使用 Ajax 在 Flask 中发布按钮值而不刷新页面?

    我有一个问题 当我单击 Flask 应用程序中的按钮时 我想避免重新加载 我知道有 Ajax 解决方案 但我想知道如何将我的按钮链接到 ajax 函数以发布按钮值并运行链接到其值的 python 函数 这是我的 html 按钮 div di
  • 无法在 Windows 服务器上使 SVN 预提交脚本失败

    我正在编写一个 SVN pre commit bat 文件 该文件调用 Python 脚本来查询我们的问题跟踪系统 以确定用户提供的问题跟踪 ID 是否处于正确的状态 例如 打开 状态 并与正确的关联项目 SVN 服务器运行 Windows
  • Python模糊字符串匹配作为相关样式表/矩阵

    我有一个文件 其中包含 x 个字符串名称及其关联的 ID 本质上是两列数据 我想要的是一个格式为 x by x 的相关样式表 将相关数据作为 x 轴和 y 轴 但我想要 fuzzywuzzy 库的函数 fuzz ratio x y 作为输出
  • 使用Python的线程模块调用ctypes函数比使用多处理更快?

    我一生都无法找出这个问题的答案 我编写了一个可以执行数百次繁重计算的脚本 我有一个绝妙的主意 将这些计算任务编写为 C 然后使用 Python 的 ctypes 与它们交互 我心想 我什至可以使用并行性进一步优化它 我最初的方法是使用线程
  • 将自定义属性添加到 Tk 小部件

    我的主要目标是向小部件添加隐藏标签或字符串之类的内容 以在其上保存简短信息 我想到创建一个新的自定义 Button 类 在本例中我需要按钮 它继承所有旧选项 这是代码 form tkinter import class NButton Bu

随机推荐

  • 数据链路层相关协议

    网络类型 根据数据链路层协议进行划分 MA 多点接入网络 BMA广播型 NBMA非广播型 P2P 点到点的网络 以太网协议 需要使用MAC地址对不同的主机设备进行区分和标识 主要因为利用以太网组件的二层网络可以包含 两个和两个以上 的接口
  • 学完责任链之后,逻辑思维上升了一个段位,我马上写了一个月薪3万的简历,HR看了让我去上班

    经过上一篇的文章 我们学习了责任链模式和策略模式 设计模式相对重要 对架构 项目拓展性 移植性要求比较高 下面我会说到简历 对于开发来说 简历是程序员的第二生命 技术是第一生命 简历第二生命 学历第三生命 简历到底是什么 简历是你的第二生命
  • js密码验证

    js密码验证
  • Paper Reading:《LISA: Reasoning Segmentation via Large Language Model》

    目录 简介 目标 创新点 方法 训练 实验 总结 简介 LISA Reasoning Segmentation via Large Language Model 基于大型语言模型的推理分割 日期 2023 8 1 v1 单位 香港中文大学
  • python函数参数里面带*是什么意思

    文章参考 https blog csdn net jiangkejkl article details 121346940 1 函数参数定义中使用独立的符号 在函数定义时 使用了一个独立的符号 这表示在符号后面的参数 调用函数时 必须使用k
  • NAPI机制分析

    NAPI机制分析 NAPI 的核心在于 在一个繁忙网络 每次有网络数据包到达时 不需要都引发中断 因为高频率的中断可能会影响系统的整体效率 假象一个场景 我们此时使用标准的 100M 网卡 可能实际达到的接收速率为 80MBits s 而此
  • 解决 IDEA中springboot项目 修改页面无法生效问题

    解决 IDEA中springboot项目 修改页面无法生效问题 之前网上找了很多解决办法 都是无效的 所以找到解决办法后 先发个博客说一下 至此就完成了springboot 无需重启则对html修改生效 如出现偶尔无效时 请刷新浏览器 之前
  • Linux下使用Git上传和更新代码

    一 上传代码 1 去github上根据网站的提示来创建自己的远程Repository 仓库 2 建立本地git仓库 git init 注意 此指令本地源码根目录执行 执行成功后 会在当前目录生成一个隐藏的名字为 git 的目录 所有对本地仓
  • 【ClickHouse数据库】如何在Win10的Ubuntu上通过ClickHouse存取行情数据

    如何在Win10的Ubuntu上通过ClickHouse存取行情数据 前言 一 ClickHouse是什么 二 如何在Ubuntu上安装ClickHouse 三 添加用户并设置密码 四 使用 1 使用DBeaver操作数据库 2 向Clic
  • 计算机图形学方向和前景&&3D

    我是刚入坑计算机图形学的小菜鸟 在百度上搜索计算机图形学方向和前景和3D 几乎不能搜到什么有用的东西 google还能搜到些有用的 但是需要翻墙 恰好前几天山大承办的games 北京大学陈宝权老师提出了图形学的新疆界 10个左右的国内图形学
  • vue 如何获取input中光标位置,并且点击按钮在当前光标后追加内容

    1 第一步 监听输入框的鼠标失焦事件
  • (原创)c++11中的日期和时间库

    c 11提供了日期时间相关的库chrono 通过chrono相关的库我们可以很方便的处理日期和时间 c 11还提供了字符串的宽窄转换功能 也提供了字符串和数字的相互转换的库 有了这些库提供的便利的工具类 我们能方便的处理日期和时间相关的转换
  • linux服务器管理与维护,linux服务器管理与维护速训..ppt

    linux服务器管理与维护速训 入门级命令 1990年秋天 Linus在芬兰首都赫尔辛基大学学习操作系统课程 因为上机需要排队等待 Linus买了台PC机 开发了第一个程序 程序包括两个进程 分别向屏幕上写字母A和B 然后用定时器来切换进程
  • mysql必考知识_可能是全网最好的MySQL重要知识点 !面试必备

    标题有点标题党的意思 但希望你在看了文章之后不会有这个想法 这篇文章是作者对之前总结的 MySQL 知识点做了完善后的产物 可以用来回顾MySQL基础知识以及备战MySQL常见面试问题 Python资源共享群 484031800 什么是My
  • 在GPU上实现光线跟踪

    include cuda h include book h include cpu bitmap h define DIM 1024 生成图像的大小 DIM DIM define SPHERES 20 生成的图像中球体的个数 define
  • Laplace smoothing in Naïve Bayes algorithm(拉普拉斯平滑)

    在这里转载只是为了让不能够科学搜索的同学们看到好文章而已 个人无收益只是分享知识 顺手做个记录罢了 原网址 https towardsdatascience com laplace smoothing in na C3 AFve bayes
  • 走进计算机的0和1

    一 计算机的产生在历史的长河中 人类发明和创造了许多算法与计算工具 在我国商朝时期就有算珠 春秋战国时期的算表 唐宋时期的算盘 欧洲在16世纪也发明了许多的计算工具 经过一系列的发展 知道1946年一月 世界上第一台计算机诞生 计算机比较笨
  • matlab 神经网络设计多层隐含层_【MATLAB深度学习】多层神经网络

    多层神经网络 对于多层神经网络的训练 delta规则是无效的 因为应用delta规则训练必须要误差 但在隐含层中没有定义 输出节点的误差是指标准输出和神经网络输出之间的差别 但训练数据不提供隐藏层的标准输出 真正的难题在于怎么定义隐藏节点的
  • 【Python】`*args` 和 `**kwargs`的用法【最全详解】

    args 和 kwargs的用法 猛滴打开博客 发现实在有段时间没更新了 又刚好用到了 kwargs 遂想起了许久之前总结的这篇博客 夸张点说也算是自己的一个呕心沥血之作吧 相信不少同学在看大神的程序时 总会看见 args kwargs 这
  • Deep Ensemble Bootstrapped Q-Learning (Deep-EBQL)【代码复现】

    Deep EBQL理论基础 原文链接 Ensemble Bootstrapping for Q Learning Deep EBQL是EBQL的深度学习版本 也即是在DQN的基础上 引入集成的思想 解决DQN过估计的问题 深度版本的EBQL