强化学习入门DQN详解

2023-05-16

Deep Q Network

参考资料:

B站莫烦:https://www.bilibili.com/video/BV13W411Y75P/?spm_id_from=333.337.search-card.all.click&vd_source=a8e8676617fb04db42af59b530b145fd

github(tensorflow):https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow

pytorch版本:https://github.com/ClownW/Reinforcement-learning-with-PyTorch

1、先验知识:

(1)Q-learning

    Q-learning使用Q表存储状态-动作-价值对应关系。

    用这种办法进行尝试遍历制作Q表再根据损失函数进行训练逼近,虽然非常直观但是当状态与动作种类非常多时,Q表将会非常巨大,不利于计算效率的提升甚至会无法完成。

    Deep-Qlearning引入神经网络通过训练得到Q值,将状态和动作组合作为神经网络的输入,无需构造Q表,借助神经网络生成相应的Q值。也可以指输入状态,通过神经网络生成采取相应动作后对应的价值,直接通过强化学习选取价值最大的动作。

(2)Deep Learning

  那么在强化学习中我们是如何训练神经网络的呢?

在这里插入图片描述

新 N N = 老 N N + α ∗ ( Q 现实 − Q 估计 ) 新NN=老NN+\alpha*(Q现实-Q估计) NN=NN+α(Q现实Q估计)

(3)Experience replay and fixed Q target

经验回放机制,DQN是一种离线学习法,它存在一个记忆库,用于存放之前的经历。训练的时候会随机抽取一些之前的经历,以此打乱经历之间的相关性使神经网络的更新更有效率。

​ Fixed Q target,在DQN中使用两个结构相同但是参数不同的神经网络,用一个不断更新的网络来预测Q估计的值,用一个比较老的神经网络来预测Q现实的值,这样可以加快算法的收敛速度,达到更好的效果。

2、 DQN算法更新循环思路

​    关键思路:DQN采用经验回放机制,所以一开始先不学习,首先建造一个记忆库,这里用step记录,当step大于200以后,再每五步学习一次(RL.learn)。每个回合的更新流程是

①从环境导入智能体此时的状态

②更新环境,根据e-greedy策略选择动作

③得到并存储该动作产生的下一状态以及获得的奖励(存储很重要,在RL-brain中编写)

④更新状态

⑤判断状态是否是最终状态,到达最终状态跳出循环

def run_maze():
	step = 0
	for episode in range(300):
		print("episode: {}".format(episode))
		observation = env.reset()
		while True:
			print("step: {}".format(step))
			env.render()
			action = RL.choose_action(observation)
			observation_, reward, done = env.step(action)
			RL.store_transition(observation, action, reward, observation_)
			if (step>200) and (step%5==0):
				RL.learn()
			observation = observation_
			if done:
				break
			step += 1
	print('game over')
	env.destroy()                  

3、神经网络搭建

    由于采用fixed Q target方法,我们需要搭建两个结构相同而参数不同的神经网络(这里分别记为eval和target,eval不断更新,target保持过去的参数),将两个网络得到的结果比较,加快算法的收敛速度,pytorch版的代码如下。这里莫烦老师教程采用的是tensorflow,可以在github获取。

(1)定义网络结构

class Net(nn.Module):
	def __init__(self, n_feature, n_hidden, n_output):
		super(Net, self).__init__()
		self.el = nn.Linear(n_feature, n_hidden)
		self.q = nn.Linear(n_hidden, n_output)

	def forward(self, x):
		x = self.el(x)
		x = F.relu(x)
		x = self.q(x)
		return x

(2)拿出两个网络中的Q值

def _build_net(self):
   self.q_eval = Net(self.n_features, self.n_hidden, self.n_actions)
   self.q_target = Net(self.n_features, self.n_hidden, self.n_actions)
   self.optimizer = torch.optim.RMSprop(self.q_eval.parameters(), lr=self.lr)

4、基本参数定义(DQN初始化部分)

​    定义并初始化可选动作、状态特征、神经网络层数、学习率、折扣因子、神经网络更新周期、记忆库大小、抽取样本大小等强化学习需要的基本参数。并设置learn_step_counter来记录目前学习的步数,方便观察结果和测试。
    初始化记忆库(pandas),我们首先将记忆库初始化为全零矩阵,高度是记忆库的大小memory_size。其宽度就是每条记忆的大小取决于run-maze中定义的结构,一共有几个变量需要存储。在这里就是RL.store_transition(observation, action, reward, observation_)结构,这里两个状态observation、observation_加上action和reward所以数据长度是n_features*2+2,注意这里action虽然上下左右四个选项,但是每次只选取一个。
    这个顺序也比较关键,在后面抽取数据的时候要记住上一步和下一步的状态分别在哪个位置,可以通过切片索引的方法拿到需要的数据。

class DeepQNetwork():
	def __init__(self, n_actions, n_features, n_hidden=20, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9,
				replace_target_iter=200, memory_size=500, batch_size=32, e_greedy_increment=None,
				):
		self.n_actions = n_actions
		self.n_features = n_features
		self.n_hidden = n_hidden
		self.lr = learning_rate
		self.gamma = reward_decay
		self.epsilon_max = e_greedy
		self.replace_target_iter = replace_target_iter
		self.memory_size = memory_size
		self.batch_size = batch_size
		self.epsilon_increment = e_greedy_increment
		self.epsilon = 0 if e_greedy_increment is not None else self.epsilon_max

		# total learning step
		self.learn_step_counter = 0

		# initialize zero memory [s, a, r, s_]
		self.memory = np.zeros((self.memory_size, n_features*2+2))

		self.loss_func = nn.MSELoss()
		self.cost_his = []

		self._build_net()

5、存储记忆store_transition

​     定义存储记忆函数,这里的输入参数分别是当前状态s,选取动作a,获得奖励r,下一状态s_。

     通过self.memory_counter进行索引,找到记忆库的某一行,再插入此时的这条数据。当counter到记忆库大小时,归零,返回上去,从头开始按条替换,保证记忆库是最近的数据,循环往复,不断覆盖更新。

	def store_transition(self, s, a, r, s_):
		if not hasattr(self, 'memory_counter'):
			self.memory_counter = 0
		transition = np.hstack((s, [a, r], s_))
		# replace the old memory with new memory
		index = self.memory_counter % self.memory_size
		self.memory[index, :] = transition 
		self.memory_counter += 1

6、动作选择部分

​    将当前状态observation输入神经网络,输出在该状态下选取每一个动作的Q值,根据e_greedy策略,在90%(可任意设置)情况下选取Q值最大的动作,剩下10%随机探索,避免局部最优。

	def choose_action(self, observation):
		observation = torch.Tensor(observation[np.newaxis, :])
		if np.random.uniform() < self.epsilon:
			actions_value = self.q_eval(observation)

			action = np.argmax(actions_value.data.numpy())
		else:
			action = np.random.randint(0, self.n_actions)
		return action

7、fixed Q target(神经网络更新)

​    该部分比较简单,在开始训练之前,我们第一步要判断此时的神经网络是否需要更新。target_net中的各参数是冻结的(Fixed Q target),当到达一个更新周期(replace_target_iter)后,将eval网络中的参数传给target网络,以此来实现网络的更新,这样可以避免神经网络的参数一直变化,不稳定,收敛困难的问题,加速其收敛速度。

		# check to replace target parameters
		if self.learn_step_counter % self.replace_target_iter == 0:
			self.q_target.load_state_dict(self.q_eval.state_dict())
			print("\ntarget params replaced\n")

8、训练过程

    神经网络训练时,每一次参数的更新所需要损失函数并不是由一个单独的数据{data:label}获得的,而是由一组数据加权得到的,这一组数据的数量就是[batch size]。我们用随机梯度下降法对数据集进行批量训练,每次随机抽取记忆库中的样本,这里设置当抽取数量大于记忆库大小时,即抽取已经存入的所有记忆。

def learn(self):
   # check to replace target parameters
   if self.learn_step_counter % self.replace_target_iter == 0:
      self.q_target.load_state_dict(self.q_eval.state_dict())
      print("\ntarget params replaced\n")

   # sample batch memory from all memory
   if self.memory_counter > self.memory_size:
      sample_index = np.random.choice(self.memory_size, size=self.batch_size)
   else:
      sample_index = np.random.choice(self.memory_counter, size=self.batch_size)
   batch_memory = self.memory[sample_index, :]

    接下来运行两个神经网络,分别得到他们的估计值,冻结网络target的输出作为Q现实,eval网络输出值作为Q估计,从而实现迭代更新。切片索引时注意这里target网络的输入是下一状态,eval网络输入是当前状态。

# q_next is used for getting which action would be choosed by target network in state s_(t+1)
		q_next, q_eval = self.q_target(torch.Tensor(batch_memory[:, -self.n_features:])), self.q_eval(torch.Tensor(batch_memory[:, :self.n_features]))
		# used for calculating y, we need to copy for q_eval because this operation could keep the Q_value that has not been selected unchanged,
		# so when we do q_target - q_eval, these Q_value become zero and wouldn't affect the calculation of the loss 
		q_target = torch.Tensor(q_eval.data.numpy().copy())

    最后计算损失、反向传播、更新参数,这里计算损失比较关键,由于两个Q值都是二维数组,如果直接矩阵相减,对应关系不正确,需要首先找到eval中的状态和对应动作,再去找到target中的对应值,才可以进行相减计算损失。

	batch_index = np.arange(self.batch_size, dtype=np.int32)
	eval_act_index = batch_memory[:, self.n_features].astype(int)
	reward = torch.Tensor(batch_memory[:, self.n_features+1])
	q_target[batch_index, eval_act_index] = reward + self.gamma*torch.max(q_next, 1)[0]

		loss = self.loss_func(q_eval, q_target)
		self.optimizer.zero_grad()
		loss.backward()
		self.optimizer.step()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

强化学习入门DQN详解 的相关文章

  • 分布式理论协议与算法 第二弹 ACID原则

    ACID 原则是在 1970年 被 Jim Gray 定义 xff0c 用以表示事务操作 xff1a 一个事务是指对数据库状态进行改变的一系列操作变成一个单个序列逻辑元操作 xff0c 数据库一般在启动时会提供事务机制 xff0c 包括事务
  • MongoDB:在 Java 中使用 MongoDB

    除了通过启动 mongo 进程进如 Shell 环境访问数据库外 xff0c MongoDB 还提供了其他基于编程语言的访问数据库方法 MongoDB 官方提供了 Java 语言的驱动包 xff0c 利用这些驱动包可使用多种编程方法来连接并
  • 分布式理论协议与算法 第三弹 BASE理论

    大部分人解释这 CAP 定律时 xff0c 常常简单的表述为 xff1a 一致性 可用性 分区容错性三者你只能同时达到其中两个 xff0c 不可能同时达到 实际上这是一个非常具有误导性质的说法 xff0c 而且在 CAP 理论诞生 12 年
  • Docker:独具魅力的开源容器引擎

    Docker 是一个开源的应用容器引擎 xff0c 让开发者可以打包他们的应用以及依赖包到一个可移植的镜像中 xff0c 然后发布到任何流行的 Linux 或 Windows操作系统的机器上 xff0c 也可以实现虚拟化 容器是完全使用沙箱
  • 在不同环境下 Docker 的安装部署

    本篇内容主要介绍了 xff1a Docker xff1a 不同环境下的安装部署 xff0c 包括 xff0c Docker 在 Centos7 下的安装 Docker 在 MacOS 下的安装 Docker 在 Windows 下的安装 以
  • Docker 应用实践-镜像篇

    一个 Docker 镜像往往是由多个镜像层 xff08 可读层 xff09 叠加而成 xff0c 每个层仅包含了前一层的差异部分 xff0c 单个镜像层也往往可以看作镜像使用 xff0c 当我们启动一个容器的时候 xff0c Docker
  • 如何通过限流算法防止系统过载

    限流算法 xff0c 顾名思义 xff0c 就是指对流量进行控制的算法 xff0c 因此也常被称为流控算法 我们在日常生活中 xff0c 就有很多限流的例子 xff0c 比如地铁站在早高峰的时候 xff0c 会利用围栏让乘客们有序排队 xf
  • Docker 应用实践-容器篇

    在 Docker 镜像篇中 xff0c 我们了解到 Docker 镜像类似于模板 xff0c 那么 Docker 容器就相当于从模板复制过来运行时的实例 xff0c Docker 容器可以被创建 复制 暂停和删除等 每一个 Docker 容
  • Java中Json字符串和Java对象的互转

    JSON xff08 JavaScript Object Notation xff09 是一种轻量级的数据交换格式 诞生于 2002 年 易于人阅读和编写 同时也易于机器解析和生成 JSON 是目前主流的前后端数据传输方式 JSON 采用完
  • 老板必看:1.初创业团队没有激情,咋办? 2.小股东的选择

    内容摘要 xff1a 本文有两个来自真实情况的案例 xff0c 因为涉及到 私隐 xff0c 部分内容经过处理 两个案例分别是 xff1a 1 xff09 新创业团队员工积极性差的问题 xff1b 2 xff09 小股东在两大股东的斗法中的
  • 持续事务管理过程中的事件驱动

    比较官方的定义 xff1a 事件驱动是指在持续事务管理过程中 xff0c 进行决策的一种策略 xff0c 即跟随当前时间点上出现的事件 xff0c 调动可用资源 xff0c 执行相关任务 xff0c 使不断出现的问题得以解决 xff0c 防
  • Docker 应用实践-仓库篇

    目前 Docker 官方维护了一个公共仓库 Docker Hub xff0c 用于查找和与团队共享容器镜像 xff0c 界上最大的容器镜像存储库 xff0c 拥有一系列内容源 xff0c 包括容器社区开发人员 开放源代码项目和独立软件供应商
  • 浅谈网络中接口幂等性设计问题

    所谓幂等性设计 xff0c 就是说 xff0c 一次和多次请求某一个资源应该具有同样的副作用 用数学的语言来表达就是 xff1a f x 61 f f x 在数学里 xff0c 幂等有两种主要的定义 在某二元运算下 xff0c 幂等元素是指
  • 分布式系统中的补偿机制设计问题

    我们知道 xff0c 应用系统在分布式的情况下 xff0c 在通信时会有着一个显著的问题 xff0c 即一个业务流程往往需要组合一组服务 xff0c 且单单一次通信可能会经过 DNS 服务 xff0c 网卡 交换机 路由器 负载均衡等设备
  • 关于基于标准库函数与基于HAL库函数的stm32编程方式的差异

    在之前的博客中 xff0c 我已经使用过通过标准库函数和HAL库函数对stm32进行编译工作 xff0c 在这篇博文里 xff0c 我将对之前的进行总结 关于标准库函数 由于stm32系列有着很多不同的芯片 xff0c 其所使用的寄存器也大
  • curl wget pip git-clone yum apt-get的区别

    在linux中 xff0c 会常用到这些命令进行文件下载 xff0c 软件安装以及url访问 xff0c 但总是分不清楚什么时候用什么命令去下载或者安装和访问 这里将这几个命令的用法和区别进行一个说明 xff0c 方便大家学习和记忆 1 首
  • CSS 三种样式

    本节我们要学习一下 CSS 样式的几种形式 xff0c 在实际应用中向 HTML 中引入 CSS 样式的方法有三种 xff0c 分别是行内样式 内部样式 外部样式 我们会依次学习这三种方式的优缺点以及应用场景 xff0c 本节我们先来讲一下
  • JavaFx-报错WindowsNativeRunloopThread

    问题 解决办法 需要卸载掉JDK1 8 并且将环境变量中的 34 JAVA HOME 34 指向改成JDK11的目录 点赞 收藏 关注 便于以后复习和收到最新内容 有其他问题在评论区讨论 或者私信我 收到会在第一时间回复 在本博客学习的技术
  • Gradle-JDK版本问题导致运行失败

    问题 解决办法 因为当前我们使用jdk8去运行Gradle 但是Gradle意思是必须使用11 43 的jdk版本 下面这个问题就是因为 我们默认是 改为 点赞 收藏 关注 便于以后复习和收到最新内容 有其他问题在评论区讨论 或者私信我 收
  • 权力的游戏,我是小股东,咋办?

    案例简述 xff1a 某初创业公司 xff0c 有A B两个大股东 xff0c 股份份额一样大 xff0c 另外还有一个小股东C A股东负责市场和销售 xff0c B股东负责研发和技术 xff0c B曾经是C的上司 xff0c 将C带入公司

随机推荐

  • Java-高版本没有jre的问题

    解决方案 jre 文件夹是可以用命令自动生成的 xff0c 在window环境中 xff0c 进入jdk目录所在的文件夹 xff0c 运行下面命令就会自动 生成jre文件夹 bin span class token punctuation
  • Java-ForkJoinPool(线程池-工作窃取算法)

    文章目录 概述工作窃取算法工作窃取算法的优缺点使用 ForkJoinPool 进行分叉和合并ForkJoinPool使用RecursiveActionRecursiveTask Fork Join 案例Demo 概述 Fork 就是把一个大
  • JavaFx-缺少JavaFX运行时组件,需要这些组件才能运行此应用程序

    问题 报错 缺少JavaFX运行时组件 需要这些组件才能运行此应用程序 解决办法 解决办法额外添加一个类似启动类的java文件 然后将需要启动的文件以class添加到launch里就行 span class token keyword pu
  • Mysql-解决Truncated incorrect DOUBLE value xxx

    问题 出现这种问题一般来说就是多表操作的时候 使用的字段类型不一致导致的 查询除外 我们来看下真实案例 在hd user表中parentId是binint类型 而在hd user increment copy1 96 表中parentId是
  • Mysql-解决创建存储函数This function has none of DETERMINISTIC

    问题 当二进制日志启用后 xff0c 这个变量就会启用 它控制是否可以信任存储函数创建者 xff0c 不会创建写入二进制日志引起不安全事件的存储函数 如果设置为0 xff08 默认值 xff09 xff0c 用户不得创建或修改存储函数 xf
  • JPA-ids for this class must be manually assigned before calling save (使用数据库的自增)

    问题 Spring Data JPA ids for this class must be manually assigned before calling save id的生成错误 xff0c 在调用 save 方法之前 xff0c 必须
  • Java-gradle编译忽略警告

    使用gradle打包的时候出现好多警告 如何忽略大部分的警告呢 使用如下配置即可 tasks span class token punctuation span span class token function withType span
  • JPA-排除实体类里不存在于数据库的字段

    在实体类与数据库表建立映射关系时添加 64 Table 注解 当表中不存在实体类中的某个属性的时候 就需要用到 64 Transient 注解 如果不好使那么在 64 Transient基础上在添加 64 Column updatable
  • SpringBoot-快速搭建一套JPA

    文章目录 结构Mavenapplication yml实体类daoservicecontroller测试 结构 Maven span class token tag span class token tag span class token
  • IntelliJ IDEA-Gradle-SpringBoot搭建

    前提条件 IntelliJ IDEA Gradle教学 Gradle 全局镜像配置和优先使用Maven 将Gradle进行安装和配置 创建项目 配置项目设置 指定自己的gradle的安装位置 以及仓库位置 用户主目录 用户主目录 Gradl
  • 我的喜马拉雅FM开播啦!

  • SpringBoot-JAP-JpaSpecificationExecutor详解

    文章目录 SpringBoot JAP JpaSpecificationExecutor详解使用方法接口介绍自定义工厂 SpringBoot JAP JpaSpecificationExecutor详解 JpaSpecificationEx
  • SwitchHosts-快速切换Hosts

    SwitchHosts是一个管理 快速切换Hosts小工具 xff0c 开源软件 xff0c 一键切换Hosts配置 xff0c 非常实用 xff0c 高效 其主要功能特性包括 xff1a 下载地址 https github com old
  • Java-新年抽奖-消息自动化发送脚本

    我们公司7点半开年会 然后大约8点半开始抽奖抢 使用腾讯会议的方式进行发关键字消息然后截图方式抽奖 然而我还在地铁上 手速慢的我只抽到了3等奖小米耳机一个 然后我回家后迫不及待第一时间赶紧使用java写一个机器人脚本 疯狂发消息 一言难尽啊
  • Java多线程-CompletableFuture(链式)

    线程池这个大家都知道 xff0c 是为了提高效率 xff0c 可以类比生活 xff0c 如果开个店 xff0c 需要几个员工 xff0c 正常的操作都是雇佣员工 xff0c 而不是每天使用临时工 xff0c 这样用完就解雇掉 xff0c 对
  • Java-Javassist(字节码修改)

    文章目录 开篇Javassist 常用类Javassist 的使用依赖代码示例 如何实现类似 AOP 的功能 开篇 说起 AOP 小伙伴们肯定很熟悉 xff0c 无论是 JDK 动态代理或者是 CGLIB 等 xff0c 其底层都是通过操作
  • Java多线程-Pip管道

    管道的意思 就是向一个管子一样从一端到另一端 只支持单方向的数据传输 需要注意的不能在同一个线程使用管道否则会导致死锁的情况 发生和接收必须在不同线程 通过使用管道 xff0c 实现不同线程间的通信 xff0c 而无需借助于临时文件之类的东
  • 新版本代码自动生成(MybatisPlus-generator) 代码示例+问题解决

    虽然MybatisPlus官网上已经给出了新版本代码生成器的核心依赖和核心代码 xff0c 但对于没接触过的小伙伴还是比较困难上手 x1f62d xff0c 本文将展现如何使用MybatisPlus generator快速生成代码 目录 1
  • 虚拟机如何使用共享文件夹传文件

    项目场景 xff1a 在使用VMware平台 xff0c ubuntu操作系统时 xff0c ftp文件传输一直报错 问题描述 xff1a 尝试了多种 xff0c 更改电脑设置 xff0c 甚至重装虚拟机 xff0c 始终如下图报错 解决方
  • 强化学习入门DQN详解

    Deep Q Network 参考资料 xff1a B站莫烦 xff1a https www bilibili com video BV13W411Y75P spm id from 61 333 337 search card all cl