使用DQN训练Grid_word任务

2023-11-17

“Tensorflow实战”一书中,强化学习一章里讲到了DQN网络,很有感触,在这里和大家分享一下。

DQN网络也是Q-learning的升级版,在原有的Q-learning中加入了卷积层。

由于深度学习需要大量的样本数据,DQN也就引入了Experience Replay,主要思想就是存储Agent的Experience,也就是样本,每次训练都会随机抽取一些样本。为了避免网络短视只学习到最新样本,Experience中的缓存Buffer会重复利用以前的样本,只有当容量满了以后,才会用新样本代替旧样本。

DQN还会使用Double模型来进行训练。一个网络制造学习目标,另一个进行实际的训练,每次更新模型都会让学习目标发生改变,如果更新频繁,幅度太大,就会导致训练失控。因此Target DQN会辅助主DQN,以一个非常低的速率学习主DQN的参数。同时让主DQN通过预测选择最大Q值的Action,再去获取Target DQN上的这个Action所对应的Q值,避免出现被高估的次优Action,总是超过最有Action,无法找到最优解。

此外,还使用Dueling DQN。将Q值进行拆分,一部分是静态环境本身具备的价值Value,另一部分是动态选择某个Action所带来的额外价值Advantage。每次计算Q值都是使用V+A的和,而动态A是由很多个Action组成,最后比较静态V加上哪个动态A的值最大。让DQN对环境估计能力更强。

接下来的任务是一个导航类的游戏,包含一个hero,两个fire,四个goal,目标控制hero移动,每次向上下左右方向移动一步,尽可能多的接触goal(reward为1),避开fire(reward为-1),在限定的步数拿到最高分。

import numpy as np
import random
import itertools
import tensorflow as tf
import scipy.misc
import matplotlib.pyplot as plt
import os

#建立一个任务环境
class gameOb():
    def __init__(self, coordinates, size, intensity, channel, reward, name):
        self.x = coordinates[0]
        self.y = coordinates[1]
        self.size = size
        self.intensity = intensity
        self.channel = channel
        self.reward = reward
        self.name = name

class gameEnv():
    def __init__(self, size):
        self.sizeX = size
        self.sizeY = size
        self.actions = 4
        self.objects = []
        a = self.reset()
        plt.imshow(a, interpolation="nearest")



    def reset(self):
        self.objects = []
        hero = gameOb(self.newPosition(), 1, 1, 2, None, 'hero')
        self.objects.append(hero)
        goal = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')
        self.objects.append(goal)
        hole = gameOb(self.newPosition(), 1, 1, 0, -1, 'fire')
        self.objects.append(hole)
        goal2 = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')
        self.objects.append(goal)
        hole2 = gameOb(self.newPosition(), 1, 1, 0, -1, 'fire')
        self.objects.append(hole2)
        goal3 = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')
        self.objects.append(goal3)
        goal4 = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')
        self.objects.append(goal4)

        state = self.renderEnv()
        self.state = state
        return state


    def moveChar(self, direction):
        hero = self.objects[0]
        heroX = hero.x
        heroY = hero.y
        if direction == 0 and hero.y >= 1:
            hero.y -= 1
        if direction == 1 and hero.y <= self.sizeY-2:
            hero.y += 1
        if direction == 2 and hero.x >= 1:
            hero.x -= 1
        if direction == 3 and hero.x <= self.sizeX-2:
            hero.x += 1
        self.objects[0] = hero


    def newPosition(self):
        iterables = [range(self.sizeX), range(self.sizeY)]
        points = []
        for t in itertools.product(*iterables):
            points.append(t)
        currentPositions = []
        for objectA in self.objects:
            if (objectA.x, objectA.y) not in currentPositions:
                currentPositions.append((objectA.x, objectA.y))
        for pos in currentPositions:
            points.remove(pos)
        location = np.random.choice(range(len(points)), replace=False)
        return points[location]


    def checkGoal(self):
        others= []
        for obj in self.objects:
            if obj.name == 'hero':
                hero = obj
            else:
                others.append(obj)
        for other in others:
            if hero.x == other.x and hero.y == other.y:
                self.objects.remove(others)
                if other.reward == 1:
                    self.objects.append(gameOb(self.newPosition(), 1, 1, 1, 1, 'goal1'))
                else:
                    self.objects.append(gameOb(self.newPosition(), 1, 1, 0, -1, 'fire'))
                return other.reward, False
        return 0.0, False


    def renderEnv(self):
        a = np.ones([self.sizeY+2, self.sizeX+2, 3])
        a[1:-1, 1:-1, :] = 0
        hero = None
        for item in self.objects:
            a[item.y+1:item.y+item.size+1, item.x+1:item.x+item.size+1, item.channel] = item.intensity
        b = scipy.misc.imresize(a[:, :, 0], [84 ,84, 1], interp='nearest')
        c = scipy.misc.imresize(a[:, :, 1], [84, 84, 1], interp='nearest')
        d = scipy.misc.imresize(a[:, :, 2], [84, 84, 1], interp='nearest')
        a = np.stack([b, c, d], axis=2)
        return a


    def step(self, action):
        self.moveChar(action)
        reward, done = self.checkGoal()
        state = self.renderEnv()
        return state, reward, done




env =gameEnv(size=5)

#设计DQN网络
class Qnetwork():
    def __init__(self, h_size):
        self.scalarInput = tf.placeholder(shape=[None, 21168], dtype=tf.float32)
        self.imageIn = tf.reshape(self.scalarInput, shape=[-1,84,84,3])
        self.conv1 = tf.contrib.layers.convolution2d(input=self.imageIn,
                                  kernel_size=[8,8],
                                  strides=[4,4],
                                  padding='VALID',
                                  num_outputs=32,
                                  biases_initializer=None)
        self.conv2 = tf.contrib.layers.convolution2d(input=self.conv1,
                                  kernel_size=[4,4],
                                  strides=[2,2],
                                  padding='VALID',
                                  num_outputs=64,
                                  biases_initializer=None)
        self.conv3 = tf.contrib.layers.convolution2d(input=self.conv2,
                                  kernel_size=[3,3],
                                  strides=[1,1],
                                  padding='VALID',
                                  num_outputs=32,
                                  biases_initializer=None)
        self.conv4 = tf.contrib.layers.convolution2d(input=self.conv3,
                                  kernel_size=[7,7],
                                  strides=[1,1],
                                  padding='VALID',
                                  num_outputs=512,
                                  biases_initializer=None)


#将conv4进行平均切分,分为静态A以及动态的V。
        self.streamAC, self.streamVC = tf.split(self.conv4,2,3)
        self.streamA = tf.contrib.layers.flatten(self.streamAC)
        self.streamV = tf.contrib.layers.flatten(self.streamVC)
        self.AW = tf.Variable(tf.random_normal([h_size//2, env.actions]))
        self.VW = tf.Variable(tf.random_normal([h_size//2,1]))
        self.Advantage = tf.matmul(self.streamA, self.AW)
        self.Value = tf.matmul(self.streamV, self.VW)

        self.Qout = self.Value + tf.subtract(self.Advantage, tf.reduce_mean(self.Advantage, reduction_indices=1,
                                                                            keep_dims=True))
        self.predict = tf.argmax(self.Qout, 1)
#主DQN预测Q值,是用主DQN的输出,乘以action转换为onehot编码。                                                         
self.targetQ = tf.placeholder(shape=[None],dtype=tf.float32)
self.actions = tf.placeholder(shape=[None], dtype=tf.int32)
self.actions_onehot = tf.one_hot(self.actions, env.actions, dtype=tf.float32)
self.Q = tf.reduce_sum(tf.multiply(self.Qout, self.actions_onehot), reduction_indices=1)
self.td_erroe = tf.square(self.targetQ - self.Q)
self.loss = tf.reduce_mean(self.td_error)
self.trainer = tf.train.AdamOptimizer(learning_rate=0.0001)
self.updateModel =  self.trainer.minimizer(self.loss)                       
#定义Experience Replay。
class experience_buffer():
    def __init__(self, buffer_size=50000):
        self.buffer = []
        self.buffer_size = buffer_size

    def add(self, experience):
        if len(self.buffer) + len(experience) >= self.buffer_size:
            self.buffer[0:(len(experience) + len(self.buffer)) - self.buffer_size] = []
        self.buffer.extend(experience)

    def sample(self, size):
        return np.reshape(np.array(random.sample(self.buffer,size)), [size,5])


def processState(states):
    return np.reshape(states, [21168])

#target DQN以tau的速率向主DQN学习。
def updateTargetGraph(tfVars, tau):
    total_vars = len(tfVars)
    op_holder = []
    for idx, var, in enumerate(tfVars[0:total_vars//2]):
        op_holder.append(tfVars[idx+total_vars//2].assign
                         ((var.value() * tau) + ((1-tau) * tfVars[idx+total_vars//2].value())))
    return op_holder

def updateTarget(op_holder, sess):
    for op in op_holder:
        sess.run(op)


batch_size =32
update_freq = 4
y = 0.99
startE = 1
endE = 0.1
anneling_steps = 10000
num_episodes = 1000
pre_train_steps = 10000
max_epLength = 50
load_model = False
path = "./dqn"
h_size = 512
tau = 0.001

mainQN = Qnetwork(h_size)
targetQN = Qnetwork(h_size)
init = tf.global_variables_initializer()
trainables = tf.trainable_variables()
targetOps = updateTargetGraph(trainables, tau=tau)


myBuffer= experience_buffer()

e = startE
stepDrop = (startE - endE)/anneling_steps
rList = []
total_steps = 0
saver = tf.train.Saver()
if not os.path.exists(path):
    os.makedirs(path)


with tf.Session() as sess:
    if load_model == True:
        print('Loading Model...')
        ckpt = tf.train.get_checkpoint_state(path)
        saver.restore(sess, ckpt.model_checkpoint_path)
    sess.run(init)
    updateTarget(targetOps, sess)
    for i in range(num_episodes+1):
        episodeBuffer = experience_buffer()
        s = env.reset()
        s = processState(s)
        d = False
        rAll = 0
        j = 0


        while j < max_epLength:
            j += 1
            if np.random.rand(1) < e or total_steps < pre_train_steps:
                a = np.random.randint(0,4)
            else:
                a = sess.run(mainQN.predict,
                             feed_dict={mainQN.scalarInput:[s]})[0]
            s1, r, d = env.step(a)
            s1 = processState(s1)
            total_steps += 1
            episodeBuffer.add(np.reshape(np.array([s,a,r,s1,d]), [1,5]))


            if total_steps > pre_train_steps:
                if e > endE:
                    e -= stepDrop
                if total_steps % (update_freq) == 0:
                    trainBatch = myBuffer.sample(batch_size)
                    A = sess.run(mainQN.predict,
                                 feed_dict={mainQN.scalarInput:np.vstack(trainBatch[:,3])})
                    Q = sess.run(targetQN.Qout,
                                 feed_dict={targetQN.scalarInput:np.vstack(trainBatch[:,3])})
                    doubleQ = Q[range(batch_size), A]
                    targetQ = trainBatch[:,2] + y*doubleQ
                    _ = sess.run(mainQN.updateModel,
                                 feed_dict={mainQN.scalarInput:np.vstack(trainBatch[:,0]),
                                            mainQN.targetQ:targetQ,
                                            mainQN.actions:trainBatch[:,1]})

                    updateTarget(targetOps, sess)

            rAll += r
            s = s1


            if d == True:
                break




            myBuffer.add(episodeBuffer.buffer)
            rList.append(rAll)
            if i>0 and i % 25 == 0:
                print('episode', i, ',average reward of last 25 episode', np.mean(rList[-25:]))
            if i > 0 and i %1000 ==0:
                saver.save(sess, path+'/model-' + str(i) + '.cptk')
                print("Saved Model1")
        saver.save(sess, path + '/model-' + str(i) + '.cptk')

最后强化学习虽然是属于深度学习的分支,但是内容很是庞大,最近也在学习诺威格博士的“人工智能”一书,里面对于Agent的研究让人大开眼界。有幸会再分享。

祝近安


本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用DQN训练Grid_word任务 的相关文章

随机推荐

  • MSP430 F5529的按钮控制led灯亮灭程序代码——按一下亮一下,再按一下暗

    2019 6 27 MP430F5529 电子工艺实习实验1 作业1 按下按键 LED亮 再按一次 LED灭 设置P8 1输出灯 P1 2输入按钮 P1 2下降沿 1 0 中断 中断标识为0 给按钮设置上拉电阻让其的高电位更加稳定 设置这两
  • 详解Java基础中注释添加的位置以及原则

    一 添加注释的位置 1 类 接口 这一部分注释是必须的 在这里 我们需要使用javadoc注释 需要标明 创建者 创建时间 版本 以及该类的作用 2 方法 在方法中 我们需要对入参 出参 以及返回值 均要标明 3 常量 对常量 我们需要使用
  • error LNK2005: _DllMain@12 already defined in MSVCRTD.lib

    本文主要分析和解决编译链接时产生的 LNK2005 错误 错误信息 mfcs90ud lib dllmodul obj error LNK2005 DllMain 12 already defined in MSVCRTD lib dllm
  • System.currentTimeMillis()

    System currentTimeMillis 计算方式与时间的单位转换 一 时间的单位转换 1秒 1000毫秒 ms 1毫秒 1 1 000秒 s 1秒 1 000 000 微秒 s 1微秒 1 1 000 000秒 s 1秒 1 00
  • Nginx 解决跨域

    项目准备 前端网站地址 http localhost 8080 服务端网址 http localhost 8081 确认服务端是没有处理跨域的 先用postman测试服务端接口是正常的 当前端网站8080去访问服务端接口时 就产生了跨域问题
  • 华硕笔记本开机自动进入bios,进不了windows系统的解决方法

    亲测有效解决办法 1 开机的时候长按F2键进入BIOS界面 通过方向键进 Secure 菜单 通过方向键选择 Secure Boot Control 选项 将其设定为 Disabled 2 通过方向键进入 Boot 菜单 通过方向键选择 L
  • ROS2执行source setup.bash命令报错及解决办法

    1 错误类型 在对ros2包编译通过后 在终端执行 source path to your workspace install setup bash 时报错 not found path to your workspace install
  • 快手直播怎么引流?快手直播效果怎么样?每个人对时尚的定义不同

    快手直播怎么引流 快手直播效果怎么样 每个人对时尚的定义不同 快手直播效果怎么样 每个人对时尚的定义不同 对于普通人来说 都会有对美的追求 比如找到适合自己的穿搭 适合自己的美妆 几乎每一种时尚风格在快手平台都能有被老铁认可的机会和其存在的
  • mysql常用的hint(原创)

    转自 http linux chinaunix net techdoc database 2008 07 29 1021449 shtml 对于经常使用Oracle的朋友可能知道 oracle的hint功能种类很多 对于优化sql语句提供了
  • 网络部署运维实验(pat 端口映射含命令)

    作者 小刘在这里 每天分享云计算网络运维课堂笔记 疫情之下 你我素未谋面 但你一定要平平安安 一 起努力 共赴美好人生 夕阳下 是最美的 绽放 愿所有的美好 再疫情结束后如约而至 目录 一 实验简介 二 图纸 三 实验命令 一 实验简介 本
  • 区块链开发团队,公链开发才是主战场

    在区块链技术开发公司不断完善的当下 很多企业都想加入进来 有远见的人永远能嗅到区块链未来市场的发展趋向 以区块链技术开发实体企业应用 在空白的市场里拥有无限开发潜力 而创业者要做的就是快人一步 才能夺得市场先机 我们团队作为一家专业的区块链
  • python统计字符串中,字母的个数、数字的个数、其它字符个数。

    str input 请输入 letter 0 num 0 other 0 for i in str if i isdigit num 1 elif i isalnum letter 1 else other 1 print letter n
  • axios post传递对象_POST 方法的content-type类型

    content type是http请求的响应头和请求头的字段 当作为响应头时 告诉客户端实际返回的内容的内容类型 作为请求头时 post或者put 客户端告诉服务器实际发送的数据类型 在前端开发过程中 通常需要跟后端工程师对接接口的数据格式
  • React 条件渲染最佳实践(7 种方法)

    在 React 中 条件渲染可以通过多种方式 不同的使用方式场景取决于不同的上下文 在本文中 我们将讨论所有可用于为 React 中的条件渲染编写更好的代码的方法 条件渲染在每种编程语言 包括 javascript 中都是的常见功能 在 j
  • 线性dp的题目汇总

    恩 挺多 慢慢看 衔接在此
  • ccrypt 在 Windows上的使用教程

    ccrypt是个加密解密工具包 一般情况下在Linux上使用 这是个windows版的使用教程 请注意 ccrypt是一个 命令行 程序 它只能从DOS提示符或shell中运行 它不是那种双击就能运行的程序 step1 到官网下载对应的安装
  • gvim for verilog简易配置

    目录 前言 一 gvim的主题和字体资源 二 gvim编辑器基本配置 三 gvim针对verilog配置 总结 前言 分别介绍了gvim的主题和字体资源推荐 gvim编辑器基本配置和针对verilog的配置 以下为正文 一 gvim的主题和
  • 递归算法(demo:斐波那契数列的实现,树的遍历,快速排序)

    递归算法 执行代码 并没执行完全的时候调用自己本身 然后等待条件不满足递归的时候 完全执行代码 执行完全后返回上一层 执行未完成的部分 递归算法与for where循环可以相互转换 通过一定的方案达到一样的效果 比如for循环可以通过栈 实
  • MacOS IDEA配置scala

    1 前提 scala已经在mac本地安装 2 新建项目后 点击项目 右键 点击Add Framwork Support 3 找到scala 刚开始时 scala dk 2 11 8 是没有的 需要自己找 所以点击create 再点击Brow
  • 使用DQN训练Grid_word任务

    Tensorflow实战 一书中 强化学习一章里讲到了DQN网络 很有感触 在这里和大家分享一下 DQN网络也是Q learning的升级版 在原有的Q learning中加入了卷积层 由于深度学习需要大量的样本数据 DQN也就引入了Exp