【python】CliffWalking悬崖寻路问题

2023-11-01

简介

机器学习:监督学习、非监督学习、强化学习

  • 模仿人类和动物的试错机制进行学习
  • 智能体环境交互,根据当前的环境状态s,按照一定策略采取行动a,获得回报r
  • 目标:获取最大累积期望回报

在这里插入图片描述
脉络介绍:
在这里插入图片描述

gym库-CliffWalking

安装标准化实验环境

pip install gym

CliffWalking:悬崖寻路问题,4*12网格,红色为悬崖,36为起始,47为目标

动作:0-4,上右下左,如果移出除网络则不变
奖励:{-1,100},与悬崖为-100,否则为-1

在这里插入图片描述

SARSA

行动选择策略:ε-greedy,以ε的概率进行探索,以1-ε的概率进行利用
Q值更新:根据下一次实际行动更新,胆小,选择离悬崖远的路线

 td_target += gamma * Q[next_state, next_action]
 Q[state, action] += lr * (td_target - Q[state, action])

Q-learning

行动选择策略:ε-greedy,以ε的概率进行探索,以1-ε的概率进行利用
Q值更新:选取最优的行动更新Q值,胆大,最终选择离悬崖近的路线

 td_target += gamma * max(Q[next_state, :])
 Q[state, action] += lr * (td_target - Q[state, action])

示例

SARSA

import numpy as np
import pandas as pd
import gym
from tqdm import tqdm


def max_index(a):
    # return np.argmax(a)
    candidate = np.where(a == np.max(a))[0]  
    index = np.random.randint(0, len(candidate))  
    return candidate[index]


def eps_greedy(Q, state):
    a = Q[state, :]
    if np.random.random() < 1 - eps:
        return max_index(a)
    return np.random.randint(0, len(a))  # [start,end)


def calc_policy(Q):
    state_number = Q.shape[0]
    policy = np.zeros(shape=state_number, dtype=np.int8)
    for i in range(state_number):
        policy[i] = np.argmax(Q[i, :])
    return policy


# 0123:{上右下左}
def print_optimal_action(pi, row, col):
    print(actions)
    for i in range(row):
        for j in range(col):
            print(actions[pi[i * col + j]], end=' ')
        print()


# 比较a,b之间的差值是否小于阈值
def is_same(a, b, thresold=0.001):
    e = np.abs(a - b) > thresold
    return np.sum(e) == 0


if __name__ == '__main__':
    eps = 0.01
    lr = 0.01
    gamma = 0.99

    np.set_printoptions(suppress=True)
    row, col = 4, 12
    state_number = row * col
    action_number = 4
    actions = list('↑→↓←')  # 上右下左:0123
    Q = np.zeros((state_number, action_number), dtype=np.float64)
    Q_last = [np.ones_like(Q), np.ones_like(Q), np.ones_like(Q), np.ones_like(Q), np.ones_like(Q)]
    env = gym.make('CliffWalking-v0')
    print('状态数量:', env.observation_space.n)
    print('行为数量:', env.action_space.n)

    #  10万个episode的迭代
    for i in tqdm(range(1, 100000)):
        env.reset()
        state = 36  # 初始位置
        done = False
        action = eps_greedy(Q, state)
        while not done:
            next_state, reward, done, info = env.step(action)
            # print('state, action, reward:', state, action, reward)
            next_action = eps_greedy(Q, next_state)
            td_target = reward
            if not done:
                td_target += gamma * Q[next_state, next_action]
            Q[state, action] += lr * (td_target - Q[state, action])
            state = next_state
            action = next_action
        if is_same(Q_last[0], Q):
            print('Q-table迭代完成,提前退出:', i)
            break
        Q_last = Q_last[1:]
        Q_last.append(np.copy(Q))
    pi = calc_policy(Q)
    print('Q Table:\n', Q)
    # np.savetxt('Q_table.txt', Q, fmt='%.5f')
    pd.DataFrame(Q).to_excel('Q_table.xlsx', index=True)
    for s_id in range(state_number):
        print(s_id, s_id // col, s_id % col, Q[s_id, :], pi[s_id], actions[pi[s_id]])
    print('最优策略:', pi)
    print_optimal_action(pi, row, col)

    # 输出最终路径(状态及坐标)
    env.reset()
    state = 36
    done = False
    trace = [{state: (state // col, state % col)}]
    while not done:
        action = pi[state]
        state, _, done, _ = env.step(action)
        trace.append({state: (state // col, state % col)})
        if len(trace) > 48:
            break
    print(len(trace), ':', trace)

Q-learning

import numpy as np
import pandas as pd
import gym
from tqdm import tqdm


def max_index(a):
    candidate = np.where(a == np.max(a))[0]
    index = np.random.randint(0, len(candidate))
    return candidate[index]


def eps_greedy(Q, state):
    a = Q[state, :]
    if np.random.random() < 1-eps:
        return max_index(a)
    return np.random.randint(0, len(a))


def calc_policy(Q):
    state_number = Q.shape[0]
    policy = np.zeros(shape=state_number, dtype=np.int8)
    for i in range(state_number):
        policy[i] = np.argmax(Q[i, :])
    return policy


# 0123:{上右下左}
def print_optimal_action(pi, row, col):
    a = list('↑→↓←')
    print(a)
    for i in range(row):
        for j in range(col):
            print(a[pi[i*col+j]], end=' ')
        print()


if __name__ == '__main__':
    eps = 0.1 # 10%概率探索,90%概率利用
    lr = 0.01
    gamma = 0.99

    np.set_printoptions(suppress=True)

    row = 4
    col = 12
    state_number = row * col
    action_number = 4   # 上下左右
    Q = np.zeros((state_number, action_number), dtype=np.float64)
    env = gym.make('CliffWalking-v0')
    for i in tqdm(range(10000)):   # 10000个episode的训练
        env.reset()
        state = 36  
        done = False
        action = eps_greedy(Q, state)
        while not done:
            next_state, reward, done, info = env.step(action)
            next_action = eps_greedy(Q, next_state)
            td_target = reward
            if not done:
                td_target += gamma * max(Q[next_state, :])  # Q-learning
            Q[state, action] += lr * (td_target - Q[state, action])
            state = next_state
            action = next_action
    pi = calc_policy(Q)
    print('Q Table:\n', Q)
    pd.DataFrame(Q).to_excel('Q_table.xlsx', index=True)
    for s_id in range(state_number):
        print(s_id, s_id//col, s_id % col, Q[s_id, :], pi[s_id])
    print_optimal_action(pi, row, col)

    # 输出最终路径
    env.reset()
    state = 36
    done = False
    trace = [{state: (state // col, state % col)}]
    while not done:
        action = np.argmax(Q[state, :])
        state, _, done, _ = env.step(action)
        trace.append({state: (state//col, state % col)})
        if len(trace) > 48:
            break
    print(len(trace), ':', trace)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【python】CliffWalking悬崖寻路问题 的相关文章

  • python sys.path 故障排除

    python 文档位于http docs python org library sys html http docs python org library sys html比如说sys path is 从环境变量 PYTHONPATH 以及
  • 此 TypeError 消息中提到的“代码对象”是什么?

    在尝试使用Python时exec声明 我收到以下错误 TypeError exec arg 1 must be a string file or code object 我不想传递字符串或文件 但什么是代码对象 如何创建一个 创建代码对象的
  • 使用python查找txt文件中字母出现的次数

    我需要从 txt 文件中读取该字母并打印 txt 文件中出现的次数 到目前为止 我已经能够在一行中打印内容 但计数有问题 有人可以指导吗 infile open grades txt content infile read for char
  • 将非常大的Python列表输出保存到mysql表中

    我想将 python 生成的列表的输出保存在 mysql 数据库的表中 该表如下所示 mysql 中的 myapc8 表 https i stack imgur com 4B4Hz png这是Python代码 在此输入图像描述 https
  • 在python中将数据库表写入文件的最快方法

    我正在尝试从数据库中提取大量数据并将其写入 csv 文件 我正在尝试找出最快的方法来做到这一点 我发现在 fetchall 的结果上运行 writerows 比下面的代码慢 40 with open filename a as f writ
  • 多处理中的动态池大小?

    有没有办法动态调整multiprocessing Pool尺寸 我正在编写一个简单的服务器进程 它会产生工作人员来处理新任务 使用multiprocessing Process对于这种情况可能更适合 因为工作人员的数量不应该是固定的 但我需
  • 将 numpy 数组合并为单个 int

    numpy 数组怎么可以这样 10 22 37 45 转换为单个 int32 数字 如下所示 10223745 这可以工作 gt gt gt int join map str 10 22 37 45 10223745 基本上你使用map s
  • 在Python上获取字典的前x个元素

    我是Python的新手 所以我尝试用Python获取字典的前50个元素 我有一本字典 它按值降序排列 k 0 l 0 for k in len dict d l 1 if l lt 51 print dict 举个小例子 dict d m
  • AttributeError:“模块”对象没有属性[重复]

    这个问题在这里已经有答案了 我有两个 python 模块 a py import b def hello print hello print a py print hello print b hi b py import a def hi
  • 给定一个排序数组,就地删除重复项,使每个元素仅出现一次并返回新长度

    完整的问题 我开始在线学习 python 但对这个标记为简单的问题有疑问 给定一个排序数组 就地删除重复项 使得每个 元素只出现一次并返回新的长度 不分配 另一个数组的额外空间 您必须通过修改输入来完成此操作 数组就地 具有 O 1 额外内
  • 如何查找或安装适用于 Python 的主题 tkinter ttk

    过去 3 个月我一直在制作一个机器人 仅用代码就可以完美运行 现在我的下一个目标是为它制作一个 GUI 但是我发现了一些障碍 主要的一个是能够看起来不像一个 30 年前的程序 我使用的是 Windows 7 我仅使用 Python 3 3
  • 是否需要关闭没有引用它们的文件?

    作为一个完全的编程初学者 我试图理解打开和关闭文件的基本概念 我正在做的一项练习是创建一个脚本 允许我将内容从一个文件复制到另一个文件 in file open from file indata in file read out file
  • 在骨架图像中查找线 OpenCV python

    我有以下图片 我想找到一些线来进行一些计算 平均长度等 我尝试使用HoughLinesP 但它找不到线 我能怎么做 这是我的代码 sk skeleton mask rows cols sk shape imgOut np zeros row
  • 在Raspberry pi上升级skimage版本

    我已经使用 Raspberry Pi 2 上的 synaptic 包管理器安装了 python 包 然而 skimage 模块版本 0 6 是 synaptic 中最新的可用版本 有人可以指导我如何将其升级到0 11 因为旧版本中缺少某些功
  • 使用另一个数据帧在数据帧中创建子列

    我对 python 和 pandas 很陌生 在这里 我有一个以下数据框 did features offset word JAPE feature manual feature 0 200 0 aa 200 200 0 200 11 bf
  • 无法通过 Python 子进程进行 SSH

    我需要通过堡垒 ssh 进入机器 因此 该命令相当长 ssh i
  • AWS Lambda 不读取环境变量

    我正在编写一个 python 脚本来查询 Qualys API 中的漏洞元数据 我在 AWS 中将其作为 lambda 函数执行 我已经在控制台中设置了环境变量 但是当我执行函数时 出现以下错误 module initialization
  • 如何给URL添加变量?

    我正在尝试从网站收集数据 我有一个 Excel 文件 其中包含该网站的所有不同扩展名 F i www example com example2 我有一个脚本可以成功从网站中提取 HTML 但现在我想为所有扩展自动执行此操作 然而 当我说 s
  • 带 Flask 的 RPI dht22:无法将第 4 行设置为输入 - 等待 PulseIn 消息超时

    我正在尝试制作一个 Raspberry Pi 3 REST API 使用 DHT22 提供温度和湿度 整个代码 from flask import Flask jsonify request from sds011 import SDS01
  • IndexError - 具有匀称形状的笛卡尔 PolygonPatch

    我曾经使用 shapely 制作一个圆圈并将其绘制在之前填充的图上 这曾经工作得很好 最近 我收到索引错误 我将代码分解为最简单的操作 但它甚至无法执行最简单的循环 import descartes import shapely geome

随机推荐

  • Spring Security:保护Spring应用程序的最佳实践

    目录 1 Spring Security是什么 它的作用是什么 2 Spring Security如何实现身份验证和授权 3 什么是Spring Security过滤器链 4 Spring Security如何防止跨站点请求伪造 CSRF
  • 单片机使用有线以太网联网的解决方案

    1 有MII RMII接口 且内置MAC 的单片机 如 STM32F407 STM32F107 ESP32 方案 外置PHY 且内部程序要运行TCP IP协议栈 PHY芯片推荐列表 LAN8720 LAN8742 DP83848 2 无MI
  • 【编译原理】机测笔记

    A 小C语言 词法分析程序 lt 参考代码 gt include iostream using namespace std 定义6个关键词 string S 6 main for if else int while Todo 设置displ
  • TypeScript:void, null, undefined的区别

    void Typescript中的void 与C语言中使用void定义一个函数时的意义一样 表示该函数没有返回值 function noReturn void console log This function don t have ret
  • win7 Embedded EWF与HORM特性(实战验证)

    前言 这两天在网上搜了很久 发现描述EWF特性的文章 大部分都是关于xp embedded的 真正运用在win7 embedded的少之又少 特别是中文描述的就更少了 于是 将自己这两天整理的结果供大家参考一下 先决条件 1 目标机 能够安
  • iOS动画—UIView动画以及CoreAnimation动画

    温故知新 一 UIView动画 1 1稍微简单点的动画 1 2稍微复杂的动画 二 CoreAnimation动画 CA动画的特点 只能添加到UIView的CALayer上面 必须需要引入
  • 树的创建、遍历及可视化

    许久不复习数据结构了 对于知识点都有些遗忘了 想着来写一些树的遍历 查找 发现连创建一棵树都快忘记了 不过幸好 还是可以看懂别人的代码 还算是有一些基础的 最终也写出来了 因为觉得这样太过于麻烦了 所以 我就在思考一个问题 如何简化这个过程
  • 自动化测试与自动化测试生命周期

    1 1 自动化测试的定义及概述1 1 1 软件测试的定义与分类 软件测试 2 就是在软件投入运行前 对软件需求分析 设计规格说明和编码的最终复查 是软件质量保证的关键步骤 定义1 软件测试是为了发现错误而在规定的条件下执行程序的过程 定义2
  • python常用库之colorama (python命令行界面打印怎么加颜色)

    文章目录 python常用库之colorama python命令行界面打印怎么加颜色 背景 colorama介绍 colorama使用 colorama打印红色闪烁 打印颜色组合 python常用库之colorama python命令行界面
  • JavaWeb基础5——HTTP,Tomcat&Servlet

    导航 黑马Java笔记 踩坑汇总 JavaSE JavaWeb SSM SpringBoot 瑞吉外卖 SpringCloud SpringCloudAlibaba 黑马旅游 谷粒商城 目录 一 Web概述 1 1 Web和JavaWeb的
  • 实战演习(十)——通过LSTM训练天气污染程度预测模型

    我的公众号为 livandata 近期由于工作用到LSTM模型 借这个机会整理一下思路 在网上找了很多资料 受益匪浅 本文参考 https blog csdn net u012735708 article details 82769711
  • 盘点:大数据处理必备的十大工具

    摘要 随着互联网的愈来愈开放 电子商务平台和社交网络的盛行 导致数据在日益增长 给企业管理大量的数据带来了挑战的同时也带来了一些机遇 随着互联网的愈来愈开放 电子商务平台和社交网络的盛行 导致数据在日益增长 给企业管理大量的数据带来了挑战的
  • JupyterNotebook--基础--02--安装

    JupyterNotebook 基础 02 安装 1 安装 pip3 install jupyter pip3 install ipython 2 生成配置文件 用于后面写入ip 端口号 密码等 jupyter notebook gener
  • angularjs php登录验证,AngularJs用户登录时交互及验证步奏详解

    这次给大家带来AngularJs用户登录时交互及验证步奏详解 AngularJs用户登录时交互及验证的注意事项有哪些 下面就是实战案例 一起来看一下 1 静态页面搭建及ng的form表单验证实现 ng disabled loginForm
  • 为什么提高断路器分闸速度,能减少电弧重燃的可能性和提高灭弧能力?

    为什么提高断路器分闸速度 能减少电弧重燃的可能性和提高灭弧能力 答 提高断路器的分闸速度 即在相同的时间内触头间的距离增加较大 电场强度降低 与相应的灭弧室配合 使之在较短时间内建立强有力的灭弧能力 又能使熄弧后的间隙在较短时间内获得较高的
  • c++智能指针——原理与实现

    转子 https www cnblogs com wxquare p 4759020 html 1 智能指针的作用 C 程序设计中使用堆内存是非常频繁的操作 堆内存的申请和释放都由程序员自己管理 程序员自己管理堆内存可以提高了程序的效率 但
  • Mac系统完美安装PHP7详细教程

    使用第三方包homebrew来安装 非常迅速有效 安装教程 1 启动Apache 首先我们启动系统自带的Apache服务 打开Terminal 输入如下指令 开启Apache服务 sudo apachectl start 查看Apache版
  • DCN和DCNv2(可变性卷积)学习笔记(原理代码实现方式)

    DCN和DCNv2 可变性卷积 网上关于两篇文章的详细描述已经很多了 我这里具体的细节就不多讲了 只说一下其中实现起来比较困惑的点 黑体字会讲解 DCNv1解决的问题就是我们常规的图像增强 仿射变换 线性变换加平移 不能解决的多种形式目标变
  • Java的mkdir()与mkdirs()引发的悲剧---关于java的mkdir()方法无法创建文件目录问题

    昨晚深夜在做项目的文件上传 在上传之前要先判断指定的文件目录是否存在 如果不存在就先创建改目录 因为之前已经做过类似的功能了 所以就把判断文件目录以及创建的代码直接copy过来了 然而很郁闷的是 一模一样的代码 这回却遇到一个特别奇葩的问题
  • 【python】CliffWalking悬崖寻路问题

    强化学习 简介 gym库 CliffWalking SARSA Q learning 示例 SARSA Q learning 简介 机器学习 监督学习 非监督学习 强化学习 模仿人类和动物的试错机制进行学习 智能体与环境交互 根据当前的环境