用Tensorflow Agents实现强化学习DQN

2023-11-01

在我之前的博客中强化学习笔记(4)-深度Q学习_gzroy的博客-CSDN博客,实现了用Tensorflow keras搭建DQN模型,解决小车上山问题。在代码里面,需要自己实现经验回放,采样等过程,比较繁琐。

Tensorflow里面有一个agents库,实现了很多强化学习的算法和工具。我尝试用agents来实现一个DQN模型来解决小车上山问题。Tensorflow网上的DQN教程是解决CartPole问题的,如果直接照搬这个代码来解决小车上山问题,则会发现模型无法收敛。经过一番研究,我发现原来是在agents里面,默认环境的回合步数是限制在200步,这样导致小车一直无法到达回合结束的位置,模型学习到的总回报一直保持不变。

以下代码是加载训练环境和评估环境,需要注意的是max_episode_steps需要设置为0,即不限制回合的最大步数:

from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.agents.dqn import dqn_agent
from tf_agents.networks import q_network
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.policies import random_tf_policy
from tf_agents.utils import common
from tf_agents.drivers import dynamic_step_driver
from tf_agents.policies import EpsilonGreedyPolicy
import tensorflow as tf
from tqdm import trange
from tf_agents.policies.q_policy import QPolicy
import seaborn as sns
from matplotlib.ticker import MultipleLocator
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib notebook

env_name = 'MountainCar-v0'
env = suite_gym.load(env_name)
train_py_env = suite_gym.load(env_name, max_episode_steps=0)
eval_py_env = suite_gym.load(env_name, max_episode_steps=0)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

然后我们建立一个DQN agent,这个agent包括了一个Q_net和一个target network,这两个network的结构是相同的,其中Q_net用于学习状态动作对的Q值,target network分享Q_net的权重,用于给定状态输入下找到最大Q值的动作。target_update_tau和target_update_period两个参数用于控制何时更新target network的权重,这里的设定是每一步更新target network的权重W_target = (1-0.005)*W_target + 0.005*W_q。gamma参数表示下一状态对应的Q值有多少计入到U值。epsilion_greedy用于控制有多少百分比的概率是随机挑选动作而不是根据Q值。

q_net = q_network.QNetwork(
    train_env.time_step_spec().observation,
    train_env.action_spec(),
    fc_layer_params=(64,))

target_q_net = q_network.QNetwork(
    train_env.time_step_spec().observation,
    train_env.action_spec(),
    fc_layer_params=(64,))

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    target_update_tau=0.005,
    target_update_period=1,
    gamma=0.99,
    epsilon_greedy=0.1,
    td_errors_loss_fn=common.element_wise_squared_loss,
    optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.001))

设置一个缓冲池,用于存放和回放历史经验数据

replay_buffer_capacity = 10000

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

# Add an observer that adds to the replay buffer:
replay_observer = [replay_buffer.add_batch]

先用一个随机动作策略来收集一些历史数据到缓冲池

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(), train_env.action_spec())
initial_driver = dynamic_step_driver.DynamicStepDriver(
      train_env,
      random_policy,
      observers=replay_observer,
      num_steps=1)
for _ in range(1):
    time_step = train_env.reset()
    step = 0
    while not time_step.is_last():
        step += 1
        if step>1000:
            break
        time_step, _ = initial_driver.run(time_step)

搜集数据之后,我们可以把replay_buffer转换为dataset来方便读取数据。这里的num_steps=2表示每次需要取两条相邻的经验数据,因为计算U值的时候需要用下一条数据的Q值来计算。

dataset = replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=128,
    num_steps=2).prefetch(3)

iterator = iter(dataset)

定义一个评估函数,用于评估训练效果:

def compute_avg_return(environment, policy, num_episodes=10):
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        step = 0
        while not time_step.is_last():
            step += 1
            if step>1000:
                break
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return
    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]

定义一个函数绘制训练过程中的每回合回报和评估回报:

class Chart:
    def __init__(self):
        self.fig, self.ax = plt.subplots(figsize = (8, 6))
        x_major_locator = MultipleLocator(1)
        self.ax.xaxis.set_major_locator(x_major_locator)
        self.ax.set_xlim(0.5, 50.5)

    def plot(self, data):
        self.ax.clear()
        sns.lineplot(data=data, x=data['episode'], y=data['reward'], hue=data['type'], ax=self.ax)
        self.fig.canvas.draw()

最后是训练和评估的代码。这里设置随着训练回合的增加,epsilion_greedy的值也逐渐减小,相当于在训练初期,随机寻找动作的概率较大,随着训练的增加,Q_net能更好的反映真实的Q值,因此随机动作的概率需要相应减小。另外要注意的是,由于初始回合里面需要通过一定的随机概率才能找到合适的动作结束回合,有可能会碰到回合经过很多步仍不能到达回合结束的条件,例如我曾经碰到第一回合运行了15000多步仍不能结束回合,这是可以重新进行训练。

train_episodes = 50
num_eval_episodes = 5
epsilon = 0.1
chart = Chart()

for episode in range(1,train_episodes):
    lr_step.assign(episode)
    learning_rate = learning_rate_fn(episode)
    episodes.append(episode)
    episode_reward = 0
    if epsilon>0.01:
        train_policy = EpsilonGreedyPolicy(agent.policy, epsilon=epsilon)
        train_driver = dynamic_step_driver.DynamicStepDriver(
              train_env,
              train_policy,
              observers=replay_observer,
              num_steps=1)
        epsilon -= 0.01
    time_step = train_env.reset()
    total_loss = 0
    step = 0
    while not time_step.is_last():
        step += 1
        time_step, _ = train_driver.run(time_step, _)
        experience, unused_info = next(iterator)
        train_loss = agent.train(experience).loss
        total_loss += train_loss
        episode_reward += time_step.reward.numpy()[0]
        if step%100==0:
            print("Epsiode_{}, step_{}, loss:{}".format(episode, step, total_loss/step))
    if episode==1:
        rewards_df = pd.DataFrame([[episode, episode_reward, 'train']], columns=['episode','reward','type'])
    else:
        rewards_df = rewards_df.append({'episode':episode, 'reward':episode_reward, 'type':'train'}, ignore_index=True)

    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    rewards_df = rewards_df.append({'episode':episode, 'reward':avg_return, 'type':'eval'}, ignore_index=True)
    chart.plot(rewards_df)

训练完成后,以下代码可以把训练后的策略在评估环境上运行,并生成视频,可以看到训练效果:

import imageio
import base64
import IPython

def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)

def create_policy_eval_video(policy, filename, num_episodes=1, fps=30):
  filename = filename + ".mp4"
  with imageio.get_writer(filename, fps=fps) as video:
    for _ in range(num_episodes):
      time_step = eval_env.reset()
      video.append_data(eval_py_env.render())
      while not time_step.is_last():
        action_step = policy.action(time_step)
        time_step = eval_env.step(action_step.action)
        video.append_data(eval_py_env.render())
  return embed_mp4(filename)

create_policy_eval_video(agent.policy, "trained-agent")

视频如下:

trained-agent

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

用Tensorflow Agents实现强化学习DQN 的相关文章

  • 如何避免使用 python 处理空的标准输入?

    The sys stdin readline 返回之前等待 EOF 或新行 所以如果我有控制台输入 readline 等待用户输入 相反 我想打印帮助并在没有需要处理的情况下退出并显示错误 而不是等待用户输入 原因 我正在寻找一个Pytho
  • 从文本文件中删除特定字符

    我对 Python 和编码都很陌生 我当时正在做一个小项目 但遇到了一个问题 44 1 6 23 2 7 49 2 3 53 2 1 68 1 6 71 2 7 我只需要从每行中删除第三个和第六个字符 或者更具体地说 从整个文件中删除 字符
  • 将非常大的Python列表输出保存到mysql表中

    我想将 python 生成的列表的输出保存在 mysql 数据库的表中 该表如下所示 mysql 中的 myapc8 表 https i stack imgur com 4B4Hz png这是Python代码 在此输入图像描述 https
  • 使用 Django 将文件异步上传到 Amazon S3

    我使用此文件存储引擎在上传文件时将文件存储到 Amazon S3 http code welldev org django storages wiki Home http code welldev org django storages w
  • 协程从未被等待

    我正在使用一个简单的上下文管理器 其中包含一个异步循环 class Runner def init self self loop asyncio get event loop def enter self return self def e
  • 运行 Python 单元测试,以便成功时不打印任何内容,失败时仅打印 AssertionError()

    我有一个标准单元测试格式的测试模块 class my test unittest TestCase def test 1 self tests def test 2 self tests etc 我的公司有一个专有的测试工具 它将作为命令行
  • python中basestring和types.StringType之间的区别?

    有什么区别 isinstance foo types StringType and isinstance foo basestring 对于Python2 basestring是两者的基类str and unicode while type
  • 如何查找或安装适用于 Python 的主题 tkinter ttk

    过去 3 个月我一直在制作一个机器人 仅用代码就可以完美运行 现在我的下一个目标是为它制作一个 GUI 但是我发现了一些障碍 主要的一个是能够看起来不像一个 30 年前的程序 我使用的是 Windows 7 我仅使用 Python 3 3
  • Ubuntu systemd 自定义服务因 python 脚本而失败

    希望获得有关 Ubuntu 中的 systemd 守护进程服务的一些帮助 我写了一个 python 脚本来禁用 Dell XPS 上的触摸屏 这更像是一个问题 而不是一个有用的功能 该脚本可以工作 但我不想一直启动它 这就是为什么我想到编写
  • 在骨架图像中查找线 OpenCV python

    我有以下图片 我想找到一些线来进行一些计算 平均长度等 我尝试使用HoughLinesP 但它找不到线 我能怎么做 这是我的代码 sk skeleton mask rows cols sk shape imgOut np zeros row
  • Python 中维基百科 API 中的 DisambiguationError 和 GuessedAtParserWarning

    我想获得维基百科与搜索词相关的可能且可接受的名称列表 在这种情况下是 电晕 当输入以下内容时 print wikipedia summary Corona 这给出了以下输出 home virej local lib python3 8 si
  • 在Raspberry pi上升级skimage版本

    我已经使用 Raspberry Pi 2 上的 synaptic 包管理器安装了 python 包 然而 skimage 模块版本 0 6 是 synaptic 中最新的可用版本 有人可以指导我如何将其升级到0 11 因为旧版本中缺少某些功
  • 具有不同尺寸图像的 Tensorflow 输入数据集

    我正在尝试使用不同大小的输入图像来训练完全卷积神经网络 我可以通过循环训练图像并在每次迭代时创建单个 numpy 输入来做到这一点 即 for image input label in zip image data labels train
  • 检测是否从psycopg2游标获取?

    假设我执行以下命令 insert into hello username values me 我跑起来就像 cursor fetchall 我收到以下错误 psycopg2 ProgrammingError no results to fe
  • 如何给URL添加变量?

    我正在尝试从网站收集数据 我有一个 Excel 文件 其中包含该网站的所有不同扩展名 F i www example com example2 我有一个脚本可以成功从网站中提取 HTML 但现在我想为所有扩展自动执行此操作 然而 当我说 s
  • 如何从namedtuple实例列表创建pandas DataFrame(带有索引或多索引)?

    简单的例子 from collections import namedtuple import pandas Price namedtuple Price ticker date price a Price GE 2010 01 01 30
  • python从二进制文件中读取16字节长的双精度值

    我找到了蟒蛇struct unpack 读取其他程序生成的二进制数据非常方便 问题 如何阅读16 字节长双精度数出二进制文件 以下 C 代码将 1 01 写入二进制文件三次 分别使用 4 字节浮点型 8 字节双精度型和 16 字节长双精度型
  • 用于插入或替换 URL 参数的 Django 模板标签

    有人知道 Django 模板标签可以获取当前路径和查询字符串并插入或替换查询字符串值吗 例如向 some custom path q how now brown cow page 3 filter person 发出请求 电话 urlpar
  • 定义在文本小部件中双击时选择哪些字符

    在 Windows 上 双击文本小部件中的单词也将选择连接的标点符号 有什么方法可以定义您想要选择的角色吗 tcl wordchars该变量的值是一个正则表达式 可以设置它来控制什么被视为 单词 字符 例如 通过双击 Tk 中的文本来选择单
  • 无法安装最新版本的 Numpy (1.22.3)

    我正在尝试安装最新版本的 numpy 即 1 22 3 但看起来 pip 无法找到最后一个版本 我知道我可以从源代码本地安装它 但我想了解为什么我无法使用 pip 安装它 PS 我有最新版本的pip 22 0 4 ERROR Could n

随机推荐

  • Oracle插入或修改数据怎么也不行的解决方法

    今天在公司操作数据库 在删除一条数据的时候忘记提交事务了 之后就去添加别的了 但是后来发现怎么也添加不上 所以觉的是事务锁住了 1 直接判断未提交事务引起的表的行锁 1 1判断哪个SESSION执行了DML Insert Update De
  • C语言-蓝桥杯-算法训练 印章

    问题描述 共有 n 种图案的印章 每种图案的出现概率相同 小A买了 m 张印章 求小A集齐 n 种印章的概率 输入格式 一行两个正整数n和m 输出格式 一个实数P表示答案 保留4位小数 样例输入 2 3 样例输出 0 7500 解题思路 共
  • PPTP穿透NAT之深入分析

    PPTP穿透NAT之深入分析 bytxl的专栏 CSDN博客大家好 现在是人静时分 我公司人员都以溜光 只有我还在面对computer 在经过不解 迷惑 结论之后 现与大家分享结果 感谢朋友Zyliday 见贤思齐的实验帮助 在研究技术原理
  • URP自定义后处理(相机滤镜)

    前言 之前做游戏一直想弄个可以实时触发相机滤镜的效果 自处找了教程和资料 想要做到自定义效果的话最好办法是在unity 内部实现 这个办法比较硬核 其实不适合我这样的小白 所以我在实现的过程中非常痛苦 我用的unity URP 模式其实自带
  • OMG!解释执行java字节码文件的命令

    美团一面 收到了HR的信息 通知我去面试 说实话真的挺紧张的 自己准备了近一个月的时间 很担心面试不过 到时候又后悔不该 裸辞 自我介绍 spring的IOC AOP原理 springmvc的工作流程 handlemapping接收的是什么
  • python中的list格式化输出

    在使用python时 我们经常会用到列表 list 由于它可以保存不同类型的数据 因此很多场景下我们都会使用它来保存数据 在写代码的过程中我们经常想要显示list的内容 直接调用print又会显得很丑 还会带着方括号 和逗号 这个太丑 又不
  • Hive数据库连接-连接池实现

    Hive数据库连接 连接池实现 通过HiveJDBC获取Hive的连接Connection 下面我们简单介绍HiveJDBC数据库连接实现 HiveJDBC配置文件 连接池配置文件hive jdbc properties 初始化连接池数 d
  • Linux运维跳槽必备的40道面试精华题

    1 什么是运维 什么是游戏运维 1 运维是指大型组织已经建立好的网络软硬件的维护 就是要保证业务的上线与运作的正常 在他运转的过程中 对他进行维护 他集合了网络 系统 数据库 开发 安全 监控于一身的技术 运维又包括很多种 有DBA运维 网
  • 鼠标点击获得opencv图像坐标和像素值

    目录 一 核心函数 二 在类中定义并且使用 1 将回调函数直接声明为友元函数 2 h 3 DW S OnMou cpp 4 main cpp 三 函数调用 1 OnMouse h 2 OnMouse cpp 一 核心函数 setMouseC
  • 如何在没有 USB 数据线的情况下使用 Android Studio 在手机中安装 Android

    背景 如何在没有 USB 数据线的情况下使用 Android Studio 在手机中安装 Android 应用程序 运行调式一个Android项目 写下必要的代码后 接下来的任务是在模拟器或手机上运行应用程序 测试应用程序是否正常 及deb
  • python numpy中对ndarry按照index(位置下标)增删改查

    在numpy中的ndarry是一个数组 因此index就是位置下标 注意下标是从0开始 增加 在插入时使用np insert 在末尾添加时使用np append 删除 需要使用np delete 修改 直接指定下标 查找 直接指定下标 示例
  • 【Shell】find文件查找

    语法格式 find 路径 选项 操作 选项参数对照表 常用选项 name 查找 etc目录下以conf结尾的文件ind etc nam iname 查找当前目录下文件名为aa的文件 不区分大小写 find iname aa user 查找文
  • [激光原理与应用-69]:激光焊接的10大常见缺陷及解决方法

    激光焊接是一种以高能量密度的激光束作为热源的高效精密焊接方法 如今 激光焊接已广泛应用于各个行业 如 电子零件 汽车制造 航空航天等工业制造领域 但是 在激光焊接的过程中 难免会出现一些缺陷或次品 只有充分了解这些缺陷并学习如何避免它们 才
  • 九轴传感器之测试篇

    关于九轴传感器的数据测试处理
  • CORS与CSRF

    本文首发于我的Github博客 本篇文章介绍了CORS和CSRF的概念 作者前几天在和带佬们聊天的时候把两个概念搞混了 所以才想要了解 简单来说 CORS Cross Origin Resource Sharing 跨域资源分享 是一种机制
  • (1)基础学习——图解pin、pad、port、IO、net 的区别

    本文内容有参考多位博主的博文 综合整理如下 仅做和人学习记录 如有专业性错误还请指正 谢谢 参考1 芯片资料中的pad和pin的区别 imxiangzi的博客 CSDN博客 pin和pad的区别 参考2
  • IntelliJ IDEA 运行卡顿解决方案

    IntelliJ IDEA 运行卡顿解决方案 1 开启IntelliJ IDEA缓慢 想要提升启动速度 则打开D JetBrains IntelliJ IDEA 2020 3 2 bin 依据实际安装路径 目录下对应文件idea64 exe
  • 对csv文件,又get了新的认知

    背景 在数据分析时 有时我们会碰到csv格式文件 需要先进行数据处理 转换成所需要的数据格式 然后才能进行分析 业务侧的同学可能对Excel文件比较熟悉 Excel可以把单个sheet直接保存为csv文件 也可以直接读取csv文件 变成Ex
  • Qt 进程间通信

    Qt进程间通信的方法 TCP IP Local Server Socket 共享内存 D Bus Unix库 QProcess 会话管理 TCP IP 使用套接字的方式 进行通信 之前介绍了 这里就不介绍了 Local Server Soc
  • 用Tensorflow Agents实现强化学习DQN

    在我之前的博客中强化学习笔记 4 深度Q学习 gzroy的博客 CSDN博客 实现了用Tensorflow keras搭建DQN模型 解决小车上山问题 在代码里面 需要自己实现经验回放 采样等过程 比较繁琐 Tensorflow里面有一个a