DoubleDQN的理论基础及其代码实现【Pytorch + Pendulum-v0】

2023-11-06

Double DQN 理论基础

普通的 DQN 算法通常会导致对值的过高估计(overestimation)。传统 DQN 优化的 TD 误差目标为
r + γ max ⁡ a ′ Q ω − ( s ′ , a ′ ) r+\gamma \max _{a^{\prime}} Q_{\omega^{-}}\left(s^{\prime}, a^{\prime}\right) r+γmaxaQω(s,a)
其中 max ⁡ a ′ Q ω − ( s ′ , a ′ ) \max _{a^{\prime}} Q_{\omega^{-}}\left(s^{\prime}, a^{\prime}\right) maxaQω(s,a)由目标网络(参数为 w − w^- w)计算得出,我们还可以将其写成如下形式:
Q ω − ( s ′ , arg ⁡ max ⁡ a ′ Q ω − ( s ′ , a ′ ) ) Q_{\omega^{-}}\left(s^{\prime}, \arg \max _{a^{\prime}} Q_{\omega^{-}}\left(s^{\prime}, a^{\prime}\right)\right) Qω(s,argmaxaQω(s,a))

换句话说, max ⁡ \max max操作实际可以被拆解为两部分:

  • 首先选取状态 s ′ s' s下的最优动作 a ∗ = arg ⁡ max ⁡ a ′ Q ω − ( s ′ , a ′ ) a^{*}=\arg \max _{a^{\prime}} Q_{\omega^{-}}\left(s^{\prime}, a^{\prime}\right) a=argmaxaQω(s,a)
  • 接着计算该动作对应的价值 Q ω − ( s ′ , a ∗ ) Q_{\omega^{-}}\left(s^{\prime}, a^*\right) Qω(s,a)

当这两部分采用同一套Q网络进行计算时,每次得到的都是神经网络当前估算的所有动作价值中的最大值。考虑到通过神经网络估算的Q值本身在某些时候会产生正向或负向的误差,在 DQN 的更新方式下神经网络会将正向误差累积
例如,我们考虑一个特殊情形:在状态 s ′ s' s下所有动作的值均为 0,即 Q ( s ′ , a i ) = 0 , ∀ i Q\left(s^{\prime}, a_{i}\right)=0, \forall i Q(s,ai)=0,i,此时正确的更新目标应为 r + 0 = r r+0=r r+0=r,但是由于神经网络拟合的误差通常会出现某些动作的估算有正误差的情况,即存在某个动作 a ′ a' a Q ( s ′ , a ′ ) > 0 Q\left(s^{\prime}, a^{\prime}\right)>0 Q(s,a)>0,此时我们的更新目标出现了过高估计, r + γ max ⁡ Q > r + 0 r+\gamma \max Q>r+0 r+γmaxQ>r+0。因此,当我们用 DQN 的更新公式进行更新时, Q ( s , a ) Q(s,a) Q(s,a)也就会被过高估计了。同理,我们拿这个 Q ( s , a ) Q(s,a) Q(s,a)来作为更新目标来更新上一步的Q值时,同样会过高估计,这样的误差将会逐步累积。对于动作空间较大的任务,DQN 中的过高估计问题会非常严重,造成 DQN 无法有效工作的后果。

为了解决这一问题,Double DQN 算法提出利用两个独立训练的神经网络估算 max ⁡ a ′ Q ∗ ( s ′ , a ′ ) \max _{a^{\prime}} Q_{*}\left(s^{\prime}, a^{\prime}\right) maxaQ(s,a)。具体做法是将原有的 max ⁡ a ′ Q ω − ( s ′ , a ′ ) \max _{a^{\prime}} Q_{\omega^{-}}\left(s^{\prime}, a^{\prime}\right) maxaQω(s,a)更改为 Q ω − ( s ′ , arg ⁡ max ⁡ a ′ Q ω ( s ′ , a ′ ) ) Q_{\omega^{-}}\left(s^{\prime}, \arg \max _{a^{\prime}} Q_{\omega}\left(s^{\prime}, a^{\prime}\right)\right) Qω(s,argmaxaQω(s,a)),即利用一套神经网络 Q w Q_w Qw的输出选取价值最大的动作,但在使用该动作的价值时,用另一套神经网络 Q w − Q_w^- Qw计算该动作的价值。这样,即使其中一套神经网络的某个动作存在比较严重的过高估计问题,由于另一套神经网络的存在,这个动作最终使用的Q值不会存在很大的过高估计问题。

传统的 DQN 算法中,本来就存在两套Q函数的神经网络——目标网络和训练网络,只不过 max ⁡ a ′ Q ω − ( s ′ , a ′ ) \max _{a^{\prime}} Q_{\omega^{-}}\left(s^{\prime}, a^{\prime}\right) maxaQω(s,a)的计算只用到了其中的目标网络,那么我们恰好可以直接将训练网络作为 Double DQN 算法中的第一套神经网络来选取动作,将目标网络作为第二套神经网络计算值,这便是 Double DQN 的主要思想。由于在 DQN 算法中将训练网络的参数记为 w w w,将目标网络的参数记为 w − w^- w,因此,我们可以直接写出如下 Double DQN 的优化目标:
r + γ Q ω − ( s ′ , arg ⁡ max ⁡ a ′ Q ω ( s ′ , a ′ ) ) r+\gamma Q_{\omega^{-}}\left(s^{\prime}, \underset{a^{\prime}}{\arg \max } Q_{\omega}\left(s^{\prime}, a^{\prime}\right)\right) r+γQω(s,aargmaxQω(s,a))

Double DQN 代码实现

总的来说,DQN 与 Double DQN 的差别只是在于计算状态 s ′ s' s下Q值时如何选取动作:

  • DQN 的优化目标可以写为 r + γ Q ω − ( s ′ , arg ⁡ max ⁡ a ′ Q ω − ( s ′ , a ′ ) ) r+\gamma Q_{\omega^{-}}\left(s^{\prime}, \arg \max _{a^{\prime}} Q_{\omega^{-}}\left(s^{\prime}, a^{\prime}\right)\right) r+γQω(s,argmaxaQω(s,a)),动作的选取依靠目标网络 Q w − Q_{w^-} Qw
  • Double DQN 的优化目标为 r + γ Q ω − ( s ′ , arg ⁡ max ⁡ a ′ Q ω ( s ′ , a ′ ) ) r+\gamma Q_{\omega^{-}}\left(s^{\prime}, \arg \max _{a^{\prime}} Q_{\omega}\left(s^{\prime}, a^{\prime}\right)\right) r+γQω(s,argmaxaQω(s,a)),动作的选取依靠训练网络 Q w Q_w Qw

所以 Double DQN 的代码实现可以直接在 DQN 的基础上进行,无须做过多修改。

Pendulum环境介绍

本次使用的环境是倒立摆(Inverted Pendulum),该环境下有一个处于随机位置的倒立摆。环境的状态包括倒立摆角度的正弦值 sin ⁡ θ \sin \theta sinθ,余弦值 cos ⁡ θ \cos \theta cosθ,角速度 θ ˙ \dot{\theta} θ˙;动作为对倒立摆施加的力矩。每一步都会根据当前倒立摆的状态的好坏给予智能体不同的奖励,该环境的奖励函数为 − ( θ 2 + 0.1 θ ˙ 2 + 0.001 a 2 ) -\left(\theta^{2}+0.1 \dot{\theta}^{2}+0.001 a^{2}\right) (θ2+0.1θ˙2+0.001a2),倒立摆向上保持直立不动时奖励为 0,倒立摆在其他位置时奖励为负数。环境本身没有终止状态,运行 200 步后游戏自动结束。

Pendulum环境的状态空间

标号 名称 最小值 最大值
0 cos ⁡ θ \cos\theta cosθ -1.0 1.0
1 sin ⁡ θ \sin\theta sinθ -1.0 1.0
2 θ ˙ \dot{\theta} θ˙ -8.0 8.0

Pendulum环境的动作空间

标号 动作 最小值 最大值
0 力矩 -2.0 2.0

力矩大小是在范围内的连续值。由于 DQN 只能处理离散动作环境,因此我们无法直接用 DQN 来处理倒立摆环境,但倒立摆环境可以比较方便地验证 DQN 对Q值的过高估计****:倒立摆环境下值的最大估计应为 0(倒立摆向上保持直立时能选取的最大值),值出现大于 0 的情况则说明出现了过高估计。为了能够应用 DQN,我们采用离散化动作的技巧。例如,下面的代码将连续的动作空间离散为 11 个动作。动作分别代表 [ 0 , 1 , 2 , … , 9 , 10 ] [0,1,2, \ldots, 9,10] [0,1,2,,9,10],力矩为 [ − 2 , − 1.6 , − 1.2 , … , 1.2 , 1.6 , 2 ] [-2,-1.6,-1.2, \ldots, 1.2,1.6,2] [2,1.6,1.2,,1.2,1.6,2]

Double DQN 代码实现

在 DQN 代码的基础上稍做修改就可实现 Double DQN。

class DQN:
    def __init__(self, args):
        self.args = args
        self.hidden_dim = 128
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.gamma = args.gamma  # 折扣因子
        self.epsilon = args.epsilon  # epsilon-贪婪策略
        self.target_update = args.target_update  # 目标网络更新频率
        self.count = 0  # 计数器,记录更新次数
        self.num_episodes = args.num_episodes
        self.minimal_size = args.minimal_size
        self.dqn_type = args.dqn_type

        self.env = gym.make(args.env_name)

        random.seed(args.seed)
        np.random.seed(args.seed)
        self.env.seed(args.seed)
        torch.manual_seed(args.seed)

        self.replay_buffer = ReplayBuffer(args.buffer_size)

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = 11  # 将连续动作分成11个离散动作

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.q_net = Qnet(self.state_dim, self.hidden_dim, self.action_dim).to(self.device)
        self.target_q_net = Qnet(self.state_dim, self.hidden_dim, self.action_dim).to(self.device)

        self.optimizer = Adam(self.q_net.parameters(), lr=self.lr)

    def select_action(self, state):  # epsilon-贪婪策略采取动作
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            state = torch.tensor([state], dtype=torch.float).to(self.device)
            action = self.q_net(state).argmax().item()
        return action

    def max_q_value(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        return self.q_net(state).argmax().item()

    def update(self, transition):
        states = torch.tensor(transition["states"], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition["actions"]).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition["rewards"], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition["next_states"], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition["dones"], dtype=torch.float).view(-1, 1).to(self.device)

        q_values = self.q_net(states).gather(1, actions)  # Q value

        # 下个状态的最大Q值
        ##################################################################
        if self.dqn_type == 'DoubleDQN':
            max_action = self.q_net(next_states).max(1)[1].view(-1, 1)
            max_next_q_values = self.target_q_net(next_states).gather(1, max_action)
        else:  # DQN
            max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)
        ##################################################################

        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)  # TD error

        loss = torch.mean(F.mse_loss(q_values, q_targets))  # 均方误差损失函数
        self.optimizer.zero_grad()  # PyTorch中默认梯度会累积,这里需要显式将梯度置为0
        loss.backward()  # 反向传播更新参数
        self.optimizer.step()

        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())  # 更新目标网络

        self.count += 1

只在update()函数里面有所更改,注意看Double DQN的实现方式。另外,max_q_value()函数是为了后面验证过高估计使用的。

def dis_to_con(discrete_action, env, action_dim):
    """离散动作转回连续的函数"""
    action_lowbound = env.action_space.low[0]  # 连续动作的最小值
    action_upbound = env.action_space.high[0]  # 连续动作的最大值
    return action_lowbound + (discrete_action / (action_dim - 1)) * (action_upbound - action_lowbound)

DQN与Double DQN的训练结果对比

接下来我们对比一下 DQN 和 Double DQN 的训练情况,为了便于后续多次调用,我们进一步将 DQN 算法的训练过程定义成一个函数。训练过程会记录下每个状态的最大Q值,在训练完成后我们可以将结果可视化,观测这些Q值存在的过高估计的情况,以此来对比 DQN 和 Double DQN 的不同。

def train_DQN(self):
    return_list = []
    max_q_value_list = []
    max_q_value = 0
    for i in range(10):
        with tqdm(total=int(self.num_episodes / 10), desc=f'Iteration {i}') as pbar:
            for episode in range(self.num_episodes // 10):
                episode_return = 0
                state = self.env.reset()
                while True:
                    action = self.select_action(state)
                    max_q_value = self.max_q_value(state) * 0.005 + max_q_value * 0.995  # 平滑处理
                    max_q_value_list.append(max_q_value)  # 保存每个状态的最大Q值

                    action_continuous = dis_to_con(action, self.env, self.action_dim)
                    next_state, reward, done, _ = self.env.step([action_continuous])

                    self.replay_buffer.add(state, action, reward, next_state, done)

                    if self.replay_buffer.size() > self.minimal_size:
                        s, a, r, s_, d = self.replay_buffer.sample(self.batch_size)
                        transitions = {"states": s, "actions": a, "rewards": r, "next_states": s_, "dones": d}
                        self.update(transitions)

                    state = next_state
                    episode_return += reward

                    if done: break

                return_list.append(episode_return)
                if (episode + 1) % 10 == 0:
                    pbar.set_postfix(
                        {
                            "episode": f"{self.num_episodes / 10 * i + episode + 1}",
                            "return": f"{np.mean(return_list[-10:]):3f}"
                        }
                    )
                pbar.update(1)
    return return_list, max_q_value_list

首先训练 DQN 并打印出其学习过程中最大Q值的情况。

args = define_args()
model = DQN(args)
return_list, max_q_value_list = model.train_DQN()

episodes_list = list(range(len(return_list)))
mv_return = utils.moving_average(return_list, 5)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('DQN on {}'.format(args.env_name))
plt.show()

frames_list = list(range(len(max_q_value_list)))
plt.plot(frames_list, max_q_value_list)
plt.axhline(0, c='orange', ls='--')
plt.axhline(10, c='red', ls='--')
plt.xlabel('Frames')
plt.ylabel('Q value')
plt.title('DQN on {}'.format(args.env_name))
plt.show()

image.png
image.png
根据代码运行结果我们可以发现,DQN 算法在倒立摆环境中能取得不错的回报,最后的期望回报在-200 左右,但是不少Q值超过了 0,有一些还超过了 10,该现象便是 DQN 算法中的Q值过高估计。

现在我们来看一下 Double DQN 是否能对此问题进行改善。

args.dqn_type = "DoubleDQN"
agent = DQN(args)
return_list, max_q_value_list = agent.train_DQN()

episodes_list = list(range(len(return_list)))
mv_return = utils.moving_average(return_list, 5)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Double DQN on {}'.format(args.env_name))
plt.show()

frames_list = list(range(len(max_q_value_list)))
plt.plot(frames_list, max_q_value_list)
plt.axhline(0, c='orange', ls='--')
plt.axhline(10, c='red', ls='--')
plt.xlabel('Frames')
plt.ylabel('Q value')
plt.title('Double DQN on {}'.format(args.env_name))
plt.show()

image.png
image.png
可以发现,与普通的 DQN 相比,Double DQN 比较少出现值Q大于 0 的情况,说明Q值过高估计的问题得到了很大缓解。

另外对于解决Q值过估计问题,还有一些其他的方法,比如DQL、EBQL等方法,后续咱慢慢实现。附带这两篇论文,感兴趣的可以先去看看:

  • Peer O, Tessler C, Merlis N, et al. Ensemble bootstrapping for Q-Learning[C]//International Conference on Machine Learning. PMLR, 2021: 8454-8463.
  • Hasselt H. Double Q-learning[J]. Advances in neural information processing systems, 2010, 23.

\quad
\quad

参考

  • 《动手学强化学习》
    \quad
    \quad

持续更新~有错误的话敬请指正!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

DoubleDQN的理论基础及其代码实现【Pytorch + Pendulum-v0】 的相关文章

随机推荐

  • 进制压缩加密_token参数

    进制压缩加密 token参数 网址 https sh meituan com meishi c17 进入抓包 查看要获取的数据 可以在请求地址 找到 getPoiList 的请求链接 请求参数有多个 但是多次请求对比发现只有 token 参
  • LeetCode:118(Python)—— 杨辉三角(简单)

    杨辉三角 概述 给定一个非负整数 numRows 生成 杨辉三角 的前 numRows 行 在 杨辉三角 中 每个数是它左上方和右上方的数的和 输入 numRows 5 输出 1 1 1 1 2 1 1 3 3 1 1 4 6 4 1 输入
  • 超经典!分割任务数据集介绍。

    文章目录 前言 一 IRSTD 1k 二 Pascal VOC2012 1 数据简介 2 分割任务数据集介绍 三 iSAID 总结 前言 在探索网络的过程中 比较基础和重要的工作是了解数据 今天来总结下我目前使用过的分割任务数据集 本博文将
  • Linux进阶_DNS服务和BIND之实战案例篇

    成功不易 加倍努力 1 实战案例 实现DNS正向主服务器 2 实战案例 实现DNS从服务器 3 实战案例 实现DNS forward 缓存 服务器 4 实战案例 利用view实现智能DNS 5 实战案例 综合案例 实现Internet 的D
  • 【linux多线程(四)】——线程池的详细解析(含代码)

    目录 什么是线程池 线程池的应用场景 线程池的实现 线程池的代码 C linux线程 壹 初识线程 区分线程和进程 线程创建的基本操作 线程 二 互斥量的详细解析 线程 三 条件变量的详细解析 什么是线程池 线程池是一种线程使用模式 它是将
  • java 栅栏_Java并发基础-栅栏(CountDownLatch)与闭锁(CyclicBarrier)

    1 闭锁CountDownLatch 闭锁CountDownLatch用于线程间的同步 它可以使得一个或者多个线程等待其它线程中的某些操作完成 它有一个int类型的属性count 当某个线程调用CountDownLatch对象的await方
  • android获取各种系统路径的方法

    android获取各种系统路径的方法 整理了一些安卓开发中可能会用到的各种路径的获取方法 欢迎评论 通过Environment获取的Environment getDataDirectory getPath 获得根目录 data 内部存储路径
  • Spring Boot + 阿里OSS实现图片上传,返回预览的地址,实现图片预览

    阿里OSS实现图片上传 返回预览地址 注册阿里OSS 首先进入阿里云的官网 https www aliyun com 紧接着点击首页上的立即开通 点击这个创建一个bucket 其余的默认就可以 可以根据自己的实际需求去写 使用代码操作阿里O
  • Redis AOF和RDB

    Redis AOF和RDB Redis是内存型数据库 为了保证数据在断电后不会丢失 需要将内存中的数据持久化到硬盘上 RDB持久化 将某个时间点的所有数据都存放到硬盘上 可以将快照复制到其他服务器从而创建具有相同数据的服务器副本 如果系统发
  • vue不是内部或外部命令,也不是可运行的程序

    使用vue脚手架初始化vue项目时 总是报 vue不是内部或外部命令 也不是可运行的程序 这样的错误 检查基础环境是否具备 1 node v查看版本 已经安装 2 npm v查看版本 已经安装 3 node 系统环境变量已经设置 于是乎 查
  • Error: Cannot fit requested classes in a single dex file (# methods: 65948 > 65536) 解决方法

    Error Cannot fit requested classes in a single dex file methods 65948 gt 65536 解决方法 最近写项目 写着写着运行时突然就报错了 运行不起来了 报错如下 Erro
  • 【django】admin后台管理的坑

    自定义的主键 必须要在fields或者fieldsets里 但是默认添加的或者自主添加的autofield字段可以不在admin页面里添加 保存时会自动添加
  • A股投资日历

    A股投资日历 12月2日 国11月非农就业报告 21 30 中证AAA综合债指数系列 8条 发布 2022中国 博鳌 国际黄金市场年度大会举办 影响 宏观 债券 黄金 12月2 3日 第四届大宗商品金融服务创新锋会 影响 大宗商品 12月2
  • Linux下嵌入式程序仿真调试(GDB)(二)

    目录 目录 前言 Ubuntu下Qt的GDB环境搭建未成功 Qt5的设置 命令行调试问题记录 总结 链接地址 前言 Linux下嵌入式程序仿真调试 GDB 一 主要介绍了GDB交叉调试环境的搭建过程 本想把交叉编译好的gdb程序放置到Qt中
  • SpringBean的自动装配运行原理

    SpringBean的自动装配运行原理 引言 在现代的软件开发领域中 快速且灵活地处理依赖关系是至关重要的 Spring框架以其强大的依赖注入功能 使得开发者能够轻松管理各种对象之间的依赖关系 其中 自动装配是Spring框架中一项重要的功
  • (Oracle功能篇) Oracle 数据库连接池

    使用 proxool 0 9 1 zip http ncu dl sourceforge net project proxool proxool 0 9 1 proxool 0 9 1 zip 相关代码 package yerasel im
  • SpringBoot+Mybatis 整合 xml配置使用+免xml使用

    SpringBoot作为现在非常流行的微服务框架 Mybatis作为现在非常流行的ORM框架 他们整合在一起是不是会产生火花呢 今天就搭建一个SpringBoot Mybatis的微服务开发环境 IEDA JDK1 8首先我们先创建个mav
  • h3c 生成树协议及stp配置命令

    STP 作用 1 通过阻断冗余链路来消除桥接网络中可能存在的路径回环 2 当前路径发生故障时 激活冗余备份链路 恢复网络连通性 STP Spanning Tree Protocol 生成树协议 是用于在局域网中消除数据链路层物理环路的协议
  • hive配置优化

    错误描述 执行 hive 任务报错 highlighting text 版本 Hive 2 2 0 Hadoop 2 7 6 Exit code is 143 Container exited with a non zero exit co
  • DoubleDQN的理论基础及其代码实现【Pytorch + Pendulum-v0】

    Double DQN 理论基础 普通的 DQN 算法通常会导致对值的过高估计 overestimation 传统 DQN 优化的 TD 误差目标为 r max