强化学习原理与python实现原理pdf_深度强化学习笔记——DQN原理与实现(pytorch+gym)...

2023-11-12

概要

本文主要总结深度强化学习中无模型基于值方法的DQN算法,说明其算法原理并用该算法在gym提供的cartpole上进行实现。

有任何不准确或错误的地方望指正!

1. DQN(Deep Q-Network)基本原理

DQN算法相当于对传统Q-learning算法的改进,与之不同的是,DQN使用了神经网络(结构可以自行设计)对action value(即Q值)进行估计。

1.1 DQN算法的基本组成元素

DQN的伪代码如下,从中可以看出几个关键步骤:目标网络,

-greedy选择动作和经验重放机制。 (摘自《Human-level control through deep reinforcement learning》)

A. 目标网络(Target Network)

首先,目标网络解决的是一个「回归问题」(与分类问题中网络产生一个分布不同),其输入是环境的状态,输出是多个动作产生的不同值,也就是动作值。(实际过程中,我们需要通过索引来获取这个Q值,即

,这里的
代表网络中的参数)。确定了网络的输入输出之后,就需要解决如何更新网络中的这些参数的问题。

其思路就是基于贝尔曼方程,并利用temporal difference的方法,让target network和用于训练的网络(这里就简记为agent网络)的差值尽可能近似于收益值,该收益值指的是从当前状态经过决策之后到达下一个状态所获取的收益。需要注明的是,DQN中的target network的参数就是直接拷贝agent网络的参数,使用的是一样的网络结构。但是在实际训练中,只能通过固定target network的输出来训练agent,而固定该网络的输出的方法就是延迟更新target network的参数,使其在固定步骤内输出不变,这样能够有效化agent网络的参数更新过程。

B. 经验重放(Experience Replay Buffer)

如果agent每次更新参数的时候都要与环境互动,这就大大降低了模型参数更新的效率,所以经验重放机制被提出。该机制就类似于一个有固定空间大小的存储器,把agent与环境互动所产生的部分结果(

)进行存储,其每一行的维数就是
(每次只能选取一个动作,得到的收益值也是一个标量)。等到了训练阶段的时候,每一次训练过程都会从该存储器中
「均匀采样」出一批 (batch) 数量的样本(总量远小于存储器的最大容量),用于agent网络模型参数的更新。

C.

-greedy(策略的选择)

Q-learning中策略的选择(假设这里是确定性策略)就是选取能够使动作值达到最大的那个动作,用数学形式表示就是:

-greedy方法是贪心算法的一个变体。具体实现的方法就是先让程序由均匀分布生成一个
区间内的随机数,如果该数值小于预设的
,则选取能够最大化动作值的动作,否则随机选取动作。

1.2 常用的提升DQN算法的技巧

A. Double DQN

DQN的实践过程中会出现一些问题,比如高估了动作值(overestimation),这时候研究人员就提出了Double DQN的技术。从下图可以看出,原先的DQN选用的target值其实还是由同一个网络生成的值,只是说这个网络所选用的参数是之前的参数。而Double DQN中将target的值做了小的改变,能够达到它是由“两个网络”生成的效果。从第二行的表达式可以看出,尽管这里依旧用的是agent含有旧参数的网络,但是这里的动作索引是通过agent当前参数网络得到的,取得该值的方法就是最大化agent当前参数的网络所输出的动作值(其输入值是环境返回的下一个状态),显然这样就解耦了动作的选取和动作值的计算,动作的选取(产生的是一系列大小为(batch_size, 1)的索引)是由新参数的agent网络获取,动作值的估算是由旧参数的agent网络所得到。

B. Dueling DQN

Dueling DQN最重要的一点就是改进了DQN中的网络结构,将Q值

拆分成状态值
和优势函数(Advantage Function)
。该方法能够更有效率地对Q值进行更新,因为每一次V值更新之后,都要加在A函数的所有维度上(相当于一个bias),相当于其他动作的值也同时被更新了。

C. Prioritized Experience Replay

与经验重放机制不同的是,该技巧将有主次地对存储器中的经验进行采样,使参数更新过程更有效率。这个priority是基于target值与当前agent网络输出的差值(既为TD error),该误差越大,那么产生这个较大误差所对应的经验(

)就有更高的概率被采样到。这里先把论文中《Prioritized Experience Replay》的伪代码放在这里,等实践完策略网络之后再来仔细学习一下~

2. DQN的pytorch实现

2.1 所需要的环境配置

gym

windows下gym的安装非常简单,conda activate到某个环境下使用pip install gym安装即可。

pytorch

windows下的快速pytorch安装可以参考我的这篇博客(简单来说就是找到版本所对应的.whl文件,然后本地进行pip install的安装)

2.2 DQN代码及详细注释

代码部分我就按莫烦pytorch教程重新码了一遍,并详细写明了注释(比如比较重要的变量的维度等)

基本流程就是:(第一步导入相关功能包就省略不写了)

  1. 定义超参数(batch_size,learning rate,discount等)
  2. 构建用于agent和target的网络架构(全连接或卷积或循环神经网络等)
  3. 开始构建DQN算法(初始化memory空间,定义损失函数和优化器,神经网络中的参数初始化;根据gym环境返回的状态信息选择动作,将得到的收益值和下一个状态的信息存储起来;对memory中的experiences进行采样,对agent网络参数进行更新,将agent网络输出的q值与目标值比较产生的均方差作为损失用梯度下降(Adam优化器)进行反向传播;在固定步数之后更新target网络。

(这里gym返回的「状态信息」是位置、角度和对应的速度和角速度,维度是4;其返回的动作只有两个,左或右)

# -*- coding: utf-8 -*-
# import the necessary packages
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import gym

# 1. Define some Hyper Parameters
BATCH_SIZE = 32     # batch size of sampling process from buffer
LR = 0.01           # learning rate
EPSILON = 0.9       # epsilon used for epsilon greedy approach
GAMMA = 0.9         # discount factor
TARGET_NETWORK_REPLACE_FREQ = 100       # How frequently target netowrk updates
MEMORY_CAPACITY = 2000                  # The capacity of experience replay buffer

env = gym.make("CartPole-v0") # Use cartpole game as environment
env = env.unwrapped
N_ACTIONS = env.action_space.n  # 2 actions
N_STATES = env.observation_space.shape[0] # 4 states
ENV_A_SHAPE = 0 if isinstance(env.action_space.sample(), int) else env.action_space.sample().shape     # to confirm the shape

# 2. Define the network used in both target net and the net for training
class Net(nn.Module):
    def __init__(self):
        # Define the network structure, a very simple fully connected network
        super(Net, self).__init__()
        # Define the structure of fully connected network
        self.fc1 = nn.Linear(N_STATES, 10)  # layer 1
        self.fc1.weight.data.normal_(0, 0.1) # in-place initilization of weights of fc1
        self.out = nn.Linear(10, N_ACTIONS) # layer 2
        self.out.weight.data.normal_(0, 0.1) # in-place initilization of weights of fc2
        
        
    def forward(self, x):
        # Define how the input data pass inside the network
        x = self.fc1(x)
        x = F.relu(x)
        actions_value = self.out(x)
        return actions_value
        
# 3. Define the DQN network and its corresponding methods
class DQN(object):
    def __init__(self):
        # -----------Define 2 networks (target and training)------#
        self.eval_net, self.target_net = Net(), Net()
        # Define counter, memory size and loss function
        self.learn_step_counter = 0 # count the steps of learning process
        self.memory_counter = 0 # counter used for experience replay buffer
        
        # ----Define the memory (or the buffer), allocate some space to it. The number 
        # of columns depends on 4 elements, s, a, r, s_, the total is N_STATES*2 + 2---#
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2)) 
        
        #------- Define the optimizer------#
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        
        # ------Define the loss function-----#
        self.loss_func = nn.MSELoss()
        
    def  choose_action(self, x):
        # This function is used to make decision based upon epsilon greedy
        
        x = torch.unsqueeze(torch.FloatTensor(x), 0) # add 1 dimension to input state x
        # input only one sample
        if np.random.uniform() < EPSILON:   # greedy
            # use epsilon-greedy approach to take action
            actions_value = self.eval_net.forward(x)
            #print(torch.max(actions_value, 1)) 
            # torch.max() returns a tensor composed of max value along the axis=dim and corresponding index
            # what we need is the index in this function, representing the action of cart.
            action = torch.max(actions_value, 1)[1].data.numpy()
            action = action[0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)  # return the argmax index
        else:   # random
            action = np.random.randint(0, N_ACTIONS)
            action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)
        return action
    
        
    def store_transition(self, s, a, r, s_):
        # This function acts as experience replay buffer        
        transition = np.hstack((s, [a, r], s_)) # horizontally stack these vectors
        # if the capacity is full, then use index to replace the old memory with new one
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memory_counter += 1
        
    
    def learn(self):
        # Define how the whole DQN works including sampling batch of experiences,
        # when and how to update parameters of target network, and how to implement
        # backward propagation.
        
        # update the target network every fixed steps
        if self.learn_step_counter % TARGET_NETWORK_REPLACE_FREQ == 0:
            # Assign the parameters of eval_net to target_net
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1
        
        # Determine the index of Sampled batch from buffer
        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE) # randomly select some data from buffer
        # extract experiences of batch size from buffer.
        b_memory = self.memory[sample_index, :]
        # extract vectors or matrices s,a,r,s_ from batch memory and convert these to torch Variables
        # that are convenient to back propagation
        b_s = Variable(torch.FloatTensor(b_memory[:, :N_STATES]))
        # convert long int type to tensor
        b_a = Variable(torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int)))
        b_r = Variable(torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2]))
        b_s_ = Variable(torch.FloatTensor(b_memory[:, -N_STATES:]))
        
        # calculate the Q value of state-action pair
        q_eval = self.eval_net(b_s).gather(1, b_a) # (batch_size, 1)
        #print(q_eval)
        # calculate the q value of next state
        q_next = self.target_net(b_s_).detach() # detach from computational graph, don't back propagate
        # select the maximum q value
        #print(q_next)
        # q_next.max(1) returns the max value along the axis=1 and its corresponding index
        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1) # (batch_size, 1)
        loss = self.loss_func(q_eval, q_target)
        
        self.optimizer.zero_grad() # reset the gradient to zero
        loss.backward()
        self.optimizer.step() # execute back propagation for one step
 
'''
--------------Procedures of DQN Algorithm------------------
'''
# create the object of DQN class
dqn = DQN()

# Start training
print("nCollecting experience...")
for i_episode in range(400):
    # play 400 episodes of cartpole game
    s = env.reset()
    ep_r = 0
    while True:
        env.render()
        # take action based on the current state
        a = dqn.choose_action(s)
        # obtain the reward and next state and some other information
        s_, r, done, info = env.step(a)
        
        # modify the reward based on the environment state
        x, x_dot, theta, theta_dot = s_
        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
        r = r1 + r2
        
        # store the transitions of states
        dqn.store_transition(s, a, r, s_)
        
        ep_r += r
        # if the experience repaly buffer is filled, DQN begins to learn or update
        # its parameters.       
        if dqn.memory_counter > MEMORY_CAPACITY:
            dqn.learn()
            if done:
                print('Ep: ', i_episode, ' |', 'Ep_r: ', round(ep_r, 2))
        
        if done:
            # if game is over, then skip the while loop.
            break
        # use next state to update the current state. 
        s = s_  

2.3 训练结果

这里总共要训练400个episodes, 在300左右的时候已经可以训练的很好了,这里录了一小段结果,如下。

---待更新---

(DQN中的常见tips的代码实现)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

强化学习原理与python实现原理pdf_深度强化学习笔记——DQN原理与实现(pytorch+gym)... 的相关文章

  • Windows7 Python3 搭建Scrapy 爬虫框架

    Windows7 64位 Python3 7 安装Scrapy 提示如下错误信息 解决办法 1 在python库中下载twisted相应的包 whl文件 官网地址 https www lfd uci edu gohlke pythonlib
  • Android底部导航栏的三种风格实现

    一 效果图展示 如果动图没有动的话 也可以看下面这个静态图 以下挨个分析每个的实现 这里只做简单的效果展示 大家可以基于目前代码做二次开发 二 BottomNavigationView 这是 Google 给我们提供的一个专门用于底部导航的
  • JWT 登录认证及 token 自动续期方案解读

    欢迎关注方志朋的博客 回复 666 获面试宝典 方志朋 号主为CSDN博客之星 博客访问量突破一千万 著有畅销书 深入理解SpringCloud与微服务构建 主要分享Java 后端架构等技术 用大厂程序员的视角来探讨技术进阶 面试指南 职业
  • 昨天看了一本c#的教程

    昨天看了一本c 的教程 昨天看了一本c 的教程 那是本很早前就买了的书 虽然也不是没看过 但是昨天重新看了下 感觉收获还是不小的 从c 的类型 到它的方法 还有就是面向对象的一些概念 覆盖 继承 我不敢说我学到了多少 但是我很喜欢 post
  • 2024年计算机专业毕业设计题目大全-吊炸天的2024届计算机毕业设计选题推荐参考

    作者 计算机源码社 个人简介 本人七年开发经验 擅长Java Python PHP NET 微信小程序 爬虫 大数据等 大家有这一块的问题可以一起交流 学习资料 程序开发 技术解答 文档报告 JavaWeb项目 微信小程序项目 Python
  • openwrt编译ipk包提示缺少feeds.mk文件

    问题具体表现如下 这个问题困扰了我两个多星期 总算解决了 解决方案如下 首先 先应该把配置菜单调好 我的硬件是7620a 要编译的ipk包为helloworld 所以应该使用 make menuconfig命令进入配置菜单 进入后 将1号框
  • TCP/IP基础&pysocket

    TCP IP基础 pysocket 1 网络简述 网络 计算机网络功能主要包括实现资源共享 实现数据信息的快速传递 网络协议 在网络数据传输中 都遵循的执行规则 规范 C S 服务器 Server 向客户端提供资源 保存客户端数据 处理客户
  • 【华为OD机试真题 python】支持优先级的队列【2023 Q2

    题目描述 支持优先级的队列 实现一个支持优先级的队列 高优先级先出队列 同优先级时先进先出 如果两个输入数据和优先级都相同 则后一个数据不入队列被丢弃 队列存储的数据内容是一个整数 输入描述 一组待存入队列的数据 包含内容和优先级 输出描述
  • PHP SQL实现公司数据库的增删改查

    文末附文件 题目要求 Use the following SQL DDL statements to create the six tables required for this project Note that you need to
  • python之celery

    Celery是由Python开发的一个简单 灵活 可靠的处理大量任务的分发系统 可以实时处理任务 也可以定时异步处理任务 每次分发任务后得到一个ID 然后根据这个ID查询任务执行情况 安装 pip install celery eventl
  • sqllabs详解与知识点汇总(内含代码审计)

    sqllabs 1 65 详解 关于注释符的详解 SQL注入注释符 使用条件及其他注释方式的探索 impulse 博客园 cnblogs com HTTP请求方法 GET 对比 POST HTTP 方法 GET 对比 POST 菜鸟教程 r
  • docker基本操作

    Docker官方建议在Ubuntu中安装 建议安装在CentOS7 X以上版本 1 安装Docker 1 yum包更新到最新 sudo yum update 2 安装需要的软件包 yum util提供yum config manager功能
  • java.math.BigDecimal用法

    Java在java math包中提供的API类BigDecimal 用来对超过16位有效位的数进行精确的运算 双精度浮点型变量double可以处理16位有效数 在实际应用中 需要对更大或者更小的数进行运算和处理 float和double只能
  • 继承和多态的内存图解

    今天被继承和多态困扰 在CSDN上找了好几个内存分配讲解 个人感觉不全吧 就把他们做了个整合 讲解的是多态的方法和成员调用和继承中的方法和变量的调用 什么是多态 同一个对象 在不同时刻表现出来的不同形态 多态的前提 要有继承或实现关系要有方
  • web robotframework xpath元素定位

    1 定位购买按钮 在这里 我写的是 td class text center button class ng isolate scope span text 购买 提示找不到元素 原因是button的class值 我把他改成class bt
  • 调试osgEarth(七)地图map图层的构建过程-添加layer(4)--打开ImageLayer

    继续调试 创建空影像 建了个1x1x1的空图片 这个也比较简单 ImageLayer建立了一个1x1x1的空图片
  • spring boot 2.x 应用可视化监控

    来源 简书 内容 应用可视化监控 prometheus grafana https www jianshu com p 7ecb57a3f326 修改为spring boot 2 0时 1 首先 添加依赖如下依赖
  • E: Unable to locate package kubelet 解决

    昨天搭建k8s集群环境时 安装报错 显示无法找到 1 打开vim etc apt sources list 写入阿里云的源 deb https mirrors aliyun com kubernetes apt kubernetes xen
  • aiVMS----CentOS7.6安装RabbitMQ安装

    entOS7 6安装RabbitMQ安装 安装一 快速的安装方法是使用Package Cloud提供的脚本 Package Cloud也可以用于通过yum安装最新的Erlang版本 使用PackageCloud安装RabbitMQ 官网参考

随机推荐

  • table问题总结

    前景 最近开发需要原生table 之前使用很少用 了解比较少 这次对于样式和功能要求也比较高 对与遇到的问题做下总结和分享 问题与解决方案 行高不定问题 描述 表格每一行的高度不确定 会自动适配 设置行高和高度均无效 产生原因 表格设置了固
  • R语言用ROCR包出现载入程辑包:‘gplots’ The following object is masked from ‘package:stats’错误

    谢谢点进来 如果你觉得有帮助 麻烦点个赞 假如在R studio运行的代码是这样的 library ROCR 首先看到这个问题的时候 我认为没有安装gplots包 可以按下图所示看是否有该包 如果没有则点击install输入包名安装 奇怪的
  • Ledger of Harms

    Under immense pressure to prioritize engagement and growth technology platforms have created a race for human attention
  • JavaScript快速排序算法

  • C#单线程和多线程端口扫描器

    C 单线程和多线程端口扫描器 一 项目创建以及页面设计 一 项目新建 二 页面设计 二 单线程实现端口扫描 一 代码实现 二 运行结果 三 多线程实现端口扫描 一 程序实现 二 运行结果 四 总结 五 参考资料 一 项目创建以及页面设计 一
  • JCenter下载太慢?教你修改Maven仓库地址为国内镜像

    转载自 http www yrom net blog 2015 02 07 change gradle maven repo url 近来迁移了一些项目到Android Studio 采用Gradle构建确实比原来的Ant方便许多 但是编译
  • StyleCLIP学习笔记

    https github com orpatashnik StyleCLIP The main inferece script is placed in mapper scripts inference py Inference argum
  • 安装librocksdb.so.4.1的共享库

    安装librocksdb so 4 1的共享库 注 以下命令需在root模式下进行 1 clone rocksDB 命令行运行git clone https github com facebook rocksdb git 2 切换到4 1
  • Java调试原理初探

    对于所有程序员 程序调试是一项必备的技能 在java程序中 最简单的就是通过 System out println 来打印输出各种变量来发现问题 而用的最多的莫过于通过各种调试器来进行调试 如图一所示的eclipse调试器 甚至还可以进行远
  • 微信号正则校验

    由于最近有朋友做微信开发 让我帮其找一个微信号正则校验 代码 本来以为网上会有很多 但一搜才发现 没有一个可用的校验微信号的正则 所以只好自己写一个了 废话不多说 直接贴结果 首先我们要明确微信号规则 微信账号仅支持6 20个字母 数字 下
  • linux内核分析笔记----内核同步

    内核同步讲的比较多了 我也就不太啰嗦了 先说一些概念 然后就是方法 同步就是避免并发和防止竞争条件 有关临界区的例子我就不举了 随便一本操作系统的书上都有 锁机制的提出也算解决了一些问题 我们待会再说 现在只要知道锁的使用是自愿的 非强制的
  • 【机器学习】鸢尾花Iris数据集进行线性分类

    目录 一 实验准备 二 线性分类 1 原始数据 2 训练模型 3 绘制决策边界 4 设置参数C 三 鸢尾花数据集分类 1 取萼片的长宽作特征分类 2 取花瓣的长宽作特征分类 四 参考 一 实验准备 安装python3 6 3 7 Anaco
  • crypto++加密算法库的编译和在项目中的使用

    简述 Crypto Library是一个免费的C 类加密方案库 该库包含以下算法 算法 名称 认证的加密方案 GCM CCM EAX 高速流密码 ChaCha 8 12 20 Panama Sosemanuk Salsa20 8 12 20
  • QT Modbus RTU调试助手(包含算法实现CRC MODBUS16校验)

    QT Modbus RTU调试助手 在类构造函数中将UI初始化和串口对象定义以及查找串口 串口设置 串口接受 QT延时函数 CRC校验 发送串口数据函数 总结 在类构造函数中将UI初始化和串口对象定义以及查找串口 foreach const
  • ElementUI el-table组件 树形数据不对齐的解决方案

    ElementPlus的el table组件在展示树状数据时 左侧的展开小箭头在部分情况下会导致第一列数据起始位置不对齐 添加一段css即可解决 环境 Vue3 0 Element Plus 1 0 2 beta 55 先看默认效果 效果图
  • Exception starting filter struts2 java.lang.NullPointerException 解决方法

  • springboot线程池ThreadPoolTaskExecutor使用

    https mp weixin qq com s 3DRBX9Wb OA NIfPXZjcw 前言 程池ThreadPoolExecutor 而用的是Spring Boot项目 可以用Spring提供的对ThreadPoolExecutor
  • Github搭建个人博客(2019最新版,亲测)

    版权声明 本文为徐代龙原创文章 未经徐代龙允许不得转载 https blog csdn net xudailong blog article details 78762262 敲黑板 如何写一个自己的小程序并上线 一 前言 建议 慢慢看 也
  • Windows Java环境变量设置 & Maven环境变量设置 & 常用环境问题设置

    Windows Java环境变量设置 Maven环境变量设置 常用环境问题设置 1 Java环境变量设置 Java8环境变量设置 1 进入环境变量设置界面 我的电脑 gt 属性 gt 高级系统设置 gt 环境变量 2 创建系统变量JAVA
  • 强化学习原理与python实现原理pdf_深度强化学习笔记——DQN原理与实现(pytorch+gym)...

    概要 本文主要总结深度强化学习中无模型基于值方法的DQN算法 说明其算法原理并用该算法在gym提供的cartpole上进行实现 有任何不准确或错误的地方望指正 1 DQN Deep Q Network 基本原理 DQN算法相当于对传统Q l