【强化学习】

2023-11-08

强化学习DQN


提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档


DQN算法的简介

提示:这里可以添加本文要记录的大概内容:

DQN算法可以看作是Q_learning算法的改进,可以用来解决连续动作和离散动作的场景,当场景中的状态数量级很大的时候,计算机中存储Q表格不现实,因此采用函数拟合的方法来估计Q值,即将复杂的Q值表格视作数据,使用一个参数化的Q*来拟合这些数据。
这种函数拟合的方法存在一定的精度损失,因此被称为近似方法,下面的例子是DQN用来解决连续状态下离散动作的问题,将连续的状态进行离散化。


提示:以下是本篇文章正文内容,下面案例可供参考

一、环境的介绍

构建强化学习的环境是非常重要的,其中主要的两个因素是智能体的状态和智能体的动作。
  使用的是 CartPole 环境,它的状态是连续的,动作是离散的,它的场景可以描述为:在车杆环境中,有一辆小车,智能体的任务是通过左右移动保持车上的杆竖直,若杆的倾斜度数过大,或者车子离初始位置左右的偏离程度过大,或者坚持时间到达 200 帧,则游戏结束。智能体的状态是一个维数为 4 的向量,每一维都是连续的,其动作是离散的,动作空间大小为 2。
在这里插入图片描述
在这里插入图片描述

在游戏中每坚持一帧,智能体能获得分数为 1 的奖励,坚持时间越长,则最后的分数越高,坚持 200 帧即可获得最高的分数。

二、DQN算法

1、DQN算法的关键技术

现在我们想在类似车杆的环境中得到动作价值函数,由于状态每一维度的值都是连续的,无法使用表格记录,因此一个常见的解决方法便是使用函数拟合(function approximation)的思想。由于神经网络具有强大的表达能力,因此我们可以用一个神经网络来表示函数Q。若动作是连续(无限)的,神经网络的输入是状态和动作,然后输出一个标量,表示在状态下采取动作能获得的价值。若动作是离散(有限)的,除了可以采取动作连续情况下的做法,我们还可以只将状态输入到神经网络中,使其同时输出每一个动作的Q值。通常 DQN(以及 Q-learning)只能处理动作离散的情况,因为在函数Q的更新过程中有MAXa这一操作。假设神经网络用来拟合函数w的参数是 ,即每一个状态s下所有可能动作a的值我们都能表示为Q(s,a)。我们将用于拟合函数Q函数的神经网络称为Q 网络,如图 7-2 所示。
在这里插入图片描述
Q_Learing 的更新规则:

在这里插入图片描述
在这里插入图片描述

我们可以将Q_Learing拓展到神经网络的形式–深度Q网络算法。由于DQN算法是离线的策略算法,因此可以使用贪婪搜索策略来平衡探索与利用,将收集到的数据存储起来,在后续的训练中使用。
DQN 中还有两个非常重要的模块——经验回放和目标网络,它们能够帮助 DQN 取得稳定、出色的性能。

2.DQN代码

接下来,我们就正式进入 DQN 算法的代码实践环节。我们采用的测试环境是 CartPole-v0,其状态空间相对简单,只有 4 个变量,因此网络结构的设计也相对简单:采用一层 128 个神经元的全连接并以 ReLU 作为激活函数。当遇到更复杂的诸如以图像作为输入的环境时,我们可以考虑采用深度卷积神经网络。
从 DQN 算法开始,我们将会用到rl_utils库,它包含一些专门为本书准备的函数,如绘制移动平均曲线、计算优势函数等,不同的算法可以一起使用这些函数。

2.1 导入库

代码如下(示例):

import random
import gym
import numpy as np
import collections
from tqdm import tqdm
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import rl_utils
2.2 定义类
class ReplayBuffer:
    ''' 经验回放池 '''
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)  # 队列,先进先出

    def add(self, state, action, reward, next_state, done):  # 将数据加入buffer
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):  # 从buffer中采样数据,数量为batch_size
        transitions = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done

    def size(self):  # 目前buffer中数据的数量
        return len(self.buffer)


class Qnet(torch.nn.Module):
    ''' 只有一层隐藏层的Q网络 '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(Qnet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))  # 隐藏层使用ReLU激活函数
        return self.fc2(x)


class DQN:
    ''' DQN算法 '''
    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma,
                 epsilon, target_update, device):
        self.action_dim = action_dim
        self.q_net = Qnet(state_dim, hidden_dim,
                          self.action_dim).to(device)  # Q网络
        # 目标网络
        self.target_q_net = Qnet(state_dim, hidden_dim,
                                 self.action_dim).to(device)
        # 使用Adam优化器
        self.optimizer = torch.optim.Adam(self.q_net.parameters(),
                                          lr=learning_rate)
        self.gamma = gamma  # 折扣因子
        self.epsilon = epsilon  # epsilon-贪婪策略
        self.target_update = target_update  # 目标网络更新频率
        self.count = 0  # 计数器,记录更新次数
        self.device = device

    def take_action(self, state):  # epsilon-贪婪策略采取动作
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            state = torch.tensor([state], dtype=torch.float).to(self.device)
            action = self.q_net(state).argmax().item()
        return action

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        q_values = self.q_net(states).gather(1, actions)  # Q值
        # 下个状态的最大Q值
        max_next_q_values = self.target_q_net(next_states).max(1)[0].view(
            -1, 1)
        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones
                                                                )  # TD误差目标
        dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))  # 均方误差损失函数
        self.optimizer.zero_grad()  # PyTorch中默认梯度会累积,这里需要显式将梯度置为0
        dqn_loss.backward()  # 反向传播更新参数
        self.optimizer.step()

        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(
                self.q_net.state_dict())  # 更新目标网络
        self.count += 1


2.3 训练画图
lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
epsilon = 0.01
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
env.seed(0)
torch.manual_seed(0)
replay_buffer = ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,
            target_update, device)

return_list = []
for i in range(10):
    with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
        for i_episode in range(int(num_episodes / 10)):
            episode_return = 0
            state = env.reset()
            done = False
            while not done:
                action = agent.take_action(state)
                next_state, reward, done, _ = env.step(action)
                replay_buffer.add(state, action, reward, next_state, done)
                state = next_state
                episode_return += reward
                # 当buffer数据的数量超过一定值后,才进行Q网络训练
                if replay_buffer.size() > minimal_size:
                    b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                    transition_dict = {
                        'states': b_s,
                        'actions': b_a,
                        'next_states': b_ns,
                        'rewards': b_r,
                        'dones': b_d
                    }
                    agent.update(transition_dict)
            return_list.append(episode_return)
            if (i_episode + 1) % 10 == 0:
                pbar.set_postfix({
                    'episode':
                    '%d' % (num_episodes / 10 * i + i_episode + 1),
                    'return':
                    '%.3f' % np.mean(return_list[-10:])
                })
            pbar.update(1)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('DQN on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('DQN on {}'.format(env_name))
plt.show()

在这里插入图片描述


总结

提示:这里对文章进行总结:

本章讲解了 DQN 算法,其主要思想是用一个神经网络来表示最优策略的函数,然后利用 Q-learning 的思想进行参数更新。为了保证训练的稳定性和高效性,DQN 算法引入了经验回放和目标网络两大模块,使得算法在实际应用时能够取得更好的效果。在 2013 年的 NIPS 深度学习研讨会上,DeepMind 公司的研究团队发表了 DQN 论文,首次展示了这一直接通过卷积神经网络接受像素输入来玩转各种雅达利(Atari)游戏的强化学习算法,由此拉开了深度强化学习的序幕。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【强化学习】 的相关文章

随机推荐

  • JasperSoft Studio的使用(1)——软件介绍及创建空白模板

    最近工作中需要用到报表打印 像pdf中多个table的展示 一个List在新的一页中显示列头等 JasperSoft 正好可以满足这些需求 所以记录一下 用以分享 软件介绍 JasperSoft Studio是一个面向 JasperRepo
  • 华为Java社招面试经历详解【已拿到offer】

    这篇文章主要介绍了华为Java社招面试经历 详细记录了华为java面试的流程 相关面试题与参考答案 需要的朋友可以参考下 看看自己能答对多少 如果能回答70 的题目 就大胆去阿里以及各互联网公司试试身手吧 本篇建议大家收藏 备用 华为Jav
  • Qt Creator release版本进行调试

    一 背景 我们在进行性Qt Creator 进行开发时 想要调试代码 通常是只需要编译 debug 版本的程序 但对于很多大型项目 引用外部第三方库中难免只存在release版本的动态库 所以 当我们的程序进行debug 调试时 往往会编译
  • window零基础部署langchain-ChatGLM

    一 介绍 从0开始安装运行langchain ChatGLM 6b int4模型 主要是版本要配套不然特别容易报错 我的机器配置CPU是Intel Core i7 7700HQ CPU 2 80GHz 2 80 GHz GPU8G 二 相关
  • 腾讯 13 年,我所总结的Code Review终极大法

    关注并星标腾讯云开发者 每周1 鹅厂工程师带你审判技术 第3期 林强 Code Review 我都 CR 些什么 谚语曰 Talk Is Cheap Show Me The Code 知易行难 知行合一难 嘴里要讲出来总是轻松 把别人讲过的
  • 蓝桥杯C/C++省赛:颠倒的价牌

    目录 题目描述 思路分析 AC代码 题目描述 小李的店里专卖其它店中下架的样品电视机 可称为 样品电视专卖店 其标价都是4位数字 即千元不等 小李为了标价清晰 方便 使用了预制的类似数码管的标价签 只要用颜色笔涂数字就可以了 这种价牌有个特
  • Java之美[从菜鸟到高手演变]之设计模式四

    在阅读过程中有任何问题 请及时联系 egg 邮箱 xtfggef gmail com 微博 http weibo com xtfggef 转载请说明出处 http blog csdn net zhangerqing 其实每个设计模式都是很重
  • html表格标签使用与注意事项

    表格的基本标签 场景 在网页中以行 列的单元格的方式整齐展示和数据 如 学生成绩表 基本标签 标签名 作用 table 表格的整体 用于包含多个tr tr 表格的每行 用于包含多个td td 表格单元格 用于包含内容 注意事项 嵌套关系为
  • 微信小程序分包加载,分包加载的优势

    微信小程序分包加载 有时候我们的小程序太大 首次打开小程序的时候会比较慢 可以进行分包处理 按照功能的划分 拆分成几个分包 让用户在操作小程序的时候按需下载资源 用户在进入某些页面的时候才去下载相应的资源 加载这个功能对应的分包 使用分包可
  • springboot设置logback-spring.xml的加载路径

    springboot将应用程序打包成jar以后 默认是将logback spring xml放在jar包里面根路径下 图 如果我们需要springboot加载jar包外部的logback spring xml应该怎么做了 例如我们想加载与x
  • shuffle机制详解

    将map输出作为输入传递给reducer的过程称为shuffle Shuffle过程包含在Map和Reduce两端 map阶段大致过程为 写数据 分区 排序 将属于同一分区的输出合并一起写在磁盘上 每个map任务都有一个环形内存缓冲区用于存
  • 服务里面找不到MySQL

    今天在连接数据库时发现自己的数据库出现了问题 在命令窗口输入 net start mysql 命令 还是启动不了 发现在服务里面竟然没有mysql服务了 1 以管理员身份运行cmd 切换到mysql安装目录的bin路径下 2 运行命令 my
  • C++构造函数简单实现电梯控制程序

    对于电梯 属性之一就是位置 所以要实现这一程序 要设置电梯的初始位置和按下电梯按钮改变的电梯的位置 代码如下 include
  • 【Linux】利用云服务器搭建云盘替代百度网盘、OneDrive等,docker安装seafile服务端,实现网页端上传下载,本地Linux、Windows安装客户端实时同步

    写在前面 博主使用OneDrive比较多 教育版有1t的大小 但是由于OneDrive在Linux系统中通过API不能连接学校的教育版 因此迫切需要一个云盘来替代OneDrive 由于之前也使用过Seafile 因此考虑使用Seafile搭
  • 编辑器正则替换px为rem

    正则部分 d d px 被替换部分 calc 1rem 100 注 此方法只能替换原css文件内无calc 运算的
  • 关于Unicode,UTF-8,GB编码详解

    内容来自网络 有部分修正 一 首先我们需要明白关于字符 character 字符集 character set 字符编码方式 character encoding 的概念 字符 字符是抽象的最小文本单位 它没有固定的形状 可能是一个字形 而
  • [901]sqlite数据库的导出与导入

    文章目录 SQLite 获取所有表名 通过 sqlite3 test db 命令进入sqlite数据库的shell 操作 python 脚本 help 直接导出csv文件 SQLite 仅仅支持 ALTER TABLE 语句的一部分功能 我
  • ansible常用模块使用方法

    ansible playbook执行方法 这个是你选择的主机 hosts webservers 这个是变量 vars http port 80 max clients 200 远端的执行权限 remote user root tasks 如
  • 实战技术产品经理

    文章转自 人人都是产品经理 并不代表企业实战 工具使用 办公工具的使用比如AXURE OFFICE 云笔记 PS等 决定办公效率 系统熟练 对后端数据及前端设计规范的了解程度 决定验收能力和设计合理度 沟通表达 对开发跟进及资源争取方面的推
  • 【强化学习】

    强化学习DQN 提示 写完文章后 目录可以自动生成 如何生成可参考右边的帮助文档 文章目录 强化学习DQN DQN算法的简介 一 环境的介绍 二 DQN算法 1 DQN算法的关键技术 2 DQN代码 2 1 导入库 2 2 定义类 2 3