深度强化学习-DQN算法

2023-05-16

论文地址:https://arxiv.org/abs/1312.5602

        先讲下在线,离线,同策略和异策略

        同策略(on-policy)和异策略(off-policy)的根本区别在于生成样本的策略和参数更新时的策略是否相同。

        对于同策略,行为策略和要优化的策略是同一策略,更新了策略后,就用该策略的最新版本对数据进行采样;对于异策略,其使用任意行为策略来对数据进行采样,并利用其更新目标策略。例如, Q 学习在计算下一状态的预期奖励时使用了最大化操作,直接选择最优动作,而当前策略并不一定能选择到最优的动作,因此这里生成样本的策略和学习时的策略不同,所以 Q 学习算法是异策略算法;相对应的 Sarsa 算法则是基于当前的策略直接执行一次动作选择,然后用动作和对应的状态更新当前的策略,因此生成样本的策略和学习时的策略相同,所以 Sarsa 算法为同策略算法。

深度 Q 网络和 Q 学习异同点

        整体来说,两者的目标价值以及价值的更新方式基本相同。但有如下不同点:

1)深度 Q 网络将 Q 学习与深度学习结合,用深度网络来近似动作价值函数,而 Q 学习则是采用表格进行存储。

2)深度 Q 网络采用了经验回放的技巧,从历史数据中随机采样,而 Q 学习直接采用下一个状态的数据进行学习。

深度 Q 网络中的两个技巧——目标网络和经验回放

(1)在深度 Q 网络中某个动作价值函数的更新依赖于其他动作价值函数。如果我们一直更新价值网络的参数,会导致更新目标不断变化,也就是我们在追逐一个不断变化的目标,这样势必会不太稳定。为了解决基于时序差分的网络中,优化目标 Qπ (st, at) = rt + Qπ (st+1, π (st+1)) 左右两侧会同时变化使得训练过程不稳定,从而增大回归难度,目标网络选择将优化目标的右边即 rt + Qπ (st+1, π (st+1)) 固定,通过改变优化目标左边的网络参数进行回归。

(2)对于经验回放,其会构建一个回放缓冲区,用来保存数据,每一个数据的内容包括:状态 st、采取的动作 at、得到的奖励 rt、下一个状态 st+1。我们使用 π 与环境交互多次,把收集到的数据都放到回放缓冲区中。防止占用过多的内存,当回放缓冲区装满后,就会自动删去最早进入缓冲区的数据。在训练时,对于每一轮迭代都有相对应的批量(采样得到),然后用这个批量中的数据去更新 Q 函数。即 Q 函数在采样和训练的时候会用到过去的经验数据,也可以消除样本之间的相关性。

算法流程

        算法伪代码

代码实现

DQN

class CNNDQN(nn.Module):
    def __init__(self, input_shape, num_actions):
        super(CNNDQN, self).__init__()
        self._input_shape = input_shape
        self._num_actions = num_actions

        self.features = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        self.fc = nn.Sequential(
            nn.Linear(self.feature_size, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def forward(self, x):
        x = self.features(x).view(x.size()[0], -1)
        return self.fc(x)

    @property
    def feature_size(self):
        x = self.features(torch.zeros(1, *self._input_shape))
        return x.view(1, -1).size(1)

    def act(self, state, epsilon, device):
        if random() > epsilon:
            state = torch.FloatTensor(np.float32(state)) \
                .unsqueeze(0).to(device)
            q_value = self.forward(state)
            action = q_value.max(1)[1].item()
        else:
            action = randrange(self._num_actions)
        return action

        其中输入的shape为(4,84,84)

初始化网络,由于用的是cpu训练,所以加载模型时映射到cpu上

def load_model(environment, model, target_model):
    model_name = join('pretrained_models', '%s.pth' % environment)
    model.load_state_dict(torch.load(model_name,map_location='cpu'))
    target_model.load_state_dict(model.state_dict())
    return model, target_model


def initialize_models(environment, env, device, transfer):
    model = CNNDQN(env.observation_space.shape,
                   env.action_space.n).to(device)
    target_model = CNNDQN(env.observation_space.shape,
                          env.action_space.n).to(device)
    if transfer:
        model, target_model = load_model(environment, model, target_model)
    return model, target_model

计算loss

def compute_td_loss(model, target_net, replay_buffer, gamma, device,
                    batch_size, beta):
    batch = replay_buffer.sample(batch_size, beta)
    state, action, reward, next_state, done, indices, weights = batch

    state = Variable(FloatTensor(np.float32(state))).to(device)
    next_state = Variable(FloatTensor(np.float32(next_state))).to(device)
    action = Variable(LongTensor(action)).to(device)
    reward = Variable(FloatTensor(reward)).to(device)
    done = Variable(FloatTensor(done)).to(device)
    weights = Variable(FloatTensor(weights)).to(device)

    q_values = model(state)
    next_q_values = target_net(next_state)

    q_value = q_values.gather(1, action.unsqueeze(-1)).squeeze(-1)
    next_q_value = next_q_values.max(1)[0]
    expected_q_value = reward + gamma * next_q_value * (1 - done)

    loss = (q_value - expected_q_value.detach()).pow(2) * weights
    prios = loss + 1e-5
    loss = loss.mean()
    loss.backward()
    replay_buffer.update_priorities(indices, prios.data.cpu().numpy())

算法图解

​​​​​​​

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度强化学习-DQN算法 的相关文章

随机推荐

  • VNC使用介绍

    VNC在内部网络中经常被大家用到 xff0c 该工具同时具备远程操作和传输文件的双重功能 xff0c 而且速度也是很快的 xff0c xff08 低版本不具备文件传输功能 xff09 深受大众喜爱 xff0c 今天就简单写下在使用VNC的过
  • 清除chrome浏览器缓存

    之前有写过设置缓存 本文解决清除html缓存 如何才能清除缓存呢 xff1f 一下是几个清除浏览器缓存的方法 xff1a 方法1 chrome浏览器地址 xff1a chrome settings clearBrowserData xff1
  • Iterator 接口

    具有原生的Iterator 接口的数据结构有 Array Map Set String TypedArray arguments对象 NodeList对象 面我们来实现将class 和 object 也变成迭代的对象 实现的关键就是 Sym
  • 容器和LXC简单命令

    容器和LXC简单命令 文章目录 容器和LXC简单命令一 CGroup xff08 控制组 xff09 的功能1 cgroup xff08 容器控制组 xff09 1 1 功能 xff1a 1 2 具体功能 xff1a 1 3 控制组可以限制
  • Podman设置容器开机自启

    Podman设置容器开机自启 1 podman管理员容器开机自启动 span class token number 1 span span class token operator span span class token operato
  • Linux中tty、pty、pts的概念区别

    http blog sina com cn s blog 638ac15c01012e0v html 基本概念 xff1a 1 gt tty 终端设备的统称 tty一词源于Teletypes xff0c 或teletypewriters x
  • Linux下vnc的安装、使用以及设置开机启动

    安装和使用VNC resbian系统自带realvnc vnc server 启动vnc服务 vncserver 1 xff08 1类似与端口号 xff0c 也可以理解为桌面序号 xff09 关闭vnc服务 vncserver kill 1
  • 单例模式与双重锁

    设计模式中 xff0c 最为基础与常见的就是单例模式 这也是经常在面试过程中被要求手写的设计模式 下面就先写一个简单的单例 xff1a public class Singleton private static Singleton sing
  • tensorflow安装时成功,但引用时提示:Could not load dynamic library ‘cudart64_101.dll‘…… if you do not have a GPU

    问题 xff1a 前几天tensorflow已经安装成功 xff0c 并顺利引用 但是这几天安装了与之冲突的包 xff1b 在重新调整各个包的版本后 xff0c 引用tensorflow提示出错 xff1a gt gt gt import
  • 【Linux】线程实例 | 简单线程池

    今天来写一个简单版本的线程池 1 啥是线程池 池塘 xff0c 顾名思义 xff0c 线程池就是一个有很多线程的容器 我们只需要把任务交到这个线程的池子里面 xff0c 其就能帮我们多线程执行任务 xff0c 计算出结果 与阻塞队列不同的是
  • pandas数据读取与清洗视频05-批量读取excel文件并合并

    本系列课程适用人群 xff1a python零基础数据分析的朋友 xff1b 在校学生 xff1b 职场中经常要处理各种数据表格 xff0c 或大量数据 xff08 十万级以上 xff09 的朋友 xff1b 喜欢图表可视化的朋友 xff1
  • 解决Xp提示未激活状态

    今天不知是什么原因电脑突然桌面背景变为黑色 xff0c 右下角提示 You may be a victim of software counterfeiting xff0c 如下图 所示 解决方法 xff1a xff08 亲测可以解决 xf
  • 微软软件运行库下载 (DirectX,.NET Framework,VC++库..)

    运行库是程序在运行时所需要的库文件 xff0c 运行库中一般包括编程时常用的函数 xff0c 如字符串操作 文件操作 界面等内容 不同的语言所支持的函数通常是不同的 xff0c 所以使用的库也是完全不同的 xff0c 这就是为什么有VB运行
  • 解决笔记本win7系统玩游戏不能全屏办法

    我们在使用笔记本win7系统玩游戏时 xff0c 经常会发现屏幕居中两边有黑条 而有一些台式机的宽屏显示器也经常出现下玩游戏不能全屏的问题 下面系统之家给大家介绍游戏不能全屏问题通用解决方法 1 修改注册表中的显示器的参数设置 Win键 4
  • MouseWithoutBorders无界鼠标安装配置教程

    第一步 xff1a 怎样修改系统计算机全名 xff08 链接教程 xff09 win7如何修改计算机的名字 百度经验 所有虚拟机必须改成不一样的名字 xff08 至关重要 xff09 第二步 xff1a 必须防火墙为开启的状态 xff08
  • 更换 PVE7 软件仓库源和 CT模板(LXC)源为国内源

    PVE7 安装后默认配置的 apt 软件源和 CT LXC 容器模板源均是官方默认的 xff0c 国内使用性能不佳 xff0c 建议替换为 清华 Tuna 提供的国内镜像源 xff0c 速度将有一个较大的提升 如果 pve 官网 iso 镜
  • Proxmox 7.3 换国内源安装

    Proxmox 7 2 默认来自官方的源 xff0c 国内慢的一逼高峰期只有个几KB的速度 xff0c 所以换源 Debian系统源 阿里云源 和中科大proxmox源 一 更换阿里云的源 vi etc apt sources list 替
  • 在x86平台制作龙芯版debian 10系统(mips64el)

    OS ubuntu 18 04 使用debootstrap制作根文件系统会分成两个阶段 第一阶段是 xff0c 使用debootstrap命令来下载软件包 第二阶段是安装软件包 安装debootstap 等相关工具 sudo apt ins
  • Mac安装homebrew报错curl: (7) Failed to connect to raw.githubusercontent.com port 443: Operation的解决办法

    在mac上安装homebrew的时候一般都是在终端输入以下的命令安装的 xff1a bin bash c 34 curl fsSL https raw githubusercontent com Homebrew install maste
  • 深度强化学习-DQN算法

    论文地址 xff1a https arxiv org abs 1312 5602 先讲下在线 xff0c 离线 xff0c 同策略和异策略 同策略 xff08 on policy xff09 和异策略 xff08 off policy xf