Deep Spatio-Temporal Residual Networks(深度时空残差神经网络)

2023-11-13

目录:

  1. 业务场景
  2. 环境搭建
  3. 数据及目录结构
  4. 模型
  5. 代码(建模、训练)
  6. 预测及结果呈现


文章只是对模型的学习与实践做简要记录,以免日后给忘了,并没有对模型优劣、应用的场景等理论方面有过多分析。适合快速动手搭建,成功运行、分析代码,并学习怎样用keras实现模型的同学。为更好的阅读下文,需提前下载模型代码(https://github.com/lucktroy/DeepST,不包含预测代码和分析参数的代码)和数据集(北京出租车:http://pan.baidu.com/s/1qYq7ja8,BikeNYC:http://pan.baidu.com/s/1mhIPrRE)。


业务场景: 

模型主体是基于残差神经网络,同时加上在时间维度的数据采样,主要的适用场景应是具有空间关系、并具有时间规律的数据。例如预测某区域人群流动、密度,交通流动等,其数据特点:在空间维度上,会受到周围区域的影响,在时间维度上,每天、每周甚至月、年会出现一定规律的变化。模型、数据以及研究成果来源于微软研究院 郑宇教授,相关学习资料https://github.com/lucktroy/DeepST,http://www.weiot.net/article-95817.html,https://www.microsoft.com/en-us/research/publication/deep-spatio-temporal-residual-networks-for-citywide-crowd-flows-prediction/。


环境搭建:

Windows、Linux都可搭建运行环境,建议linux。


(1)linux

Anaconda2-5.0.1-Linux-x86_64.sh对应python 2.7、keras 2.0.6、TensorFlow1.2.1(或theano 0.9.0)


(2)Windows

Anaconda 5.0.1 python 3.6、keras 2.0.8、theano 0.9.0(对应TensorFlow版本不太好找)

注:1. 源代码是python2.7版本的,python3.6需要修改部分代码

2. keras配置文件.keras.json可修改backend后台支持(theano /TensorFlow),具体可百度

3. .keras.json中修改image_data_format为channels_first

4. 添加数据的路径到环境变量(诡异的很):建立DATAPATH,对应的值为数据路径(如:E:\project\peopleFlowPredict\data)


(3)deepst包安装

除上面基础包安装外,还需安装模型自带的包deepst,当前目录下python setup.py install 和python setup.py develop


数据及目录结构:

目录:

data:数据目录,deepst程序依赖包目录,scripts主体程序入口,script中分为北京出租车、纽约自行车、**省人群流动预测:exptTaxiBJ.py模型建模训练程序,HB_analyse.py测试集分析呈现,HB_prediction.py模型预测函数,HB_weightShow.py模型参数分析呈现,MODEL训练生成的模型,RET模型训练过程记录,preprocessing.pkl保存归一化对象,数据集data:分别是北京出租车、纽约自行车、**省人群流动数据集,CATCH为第一次运行程序读取原数据集后,自动生成模型的输入数据集。





数据

数据是以图的形式,逻辑上为矩阵,如北京出租车数据,首先将北京地区栅格化为32*32的逻辑矩阵,每个样本为某个时间点的出租车在整个北京区域栅格的流动分布。数据分为原数据集和缓存数据集,程序第一次运行会读取原始数据,程序会根据设置的时间采样参数读取位置数据,结合天气、节假日等特征生成模型所需要的输入数据。数据集分为训练集和测试集,例如北京出租车测试集选取后四月的数据。


原始数据集中:分为位置数据集、气象数据集和节假日数据集。位置数据集为shape(n,2,32,32),n为样本个数,2为输入输出两维度,32*32map流动值。external feature维度28维,其中星期7维+1维是否为工作日+1维是否为节假日+19维气象,温度、风速归一化,其他为0-1编码。


CACHE:数据集中含有11对键值对,训练数据和测试数据各5对+1对external feature的维度。5对训练数据中包含近(shape(n,6,32,32))、中(shape(n,2,32,32))、远(shape(n,2,32,32))、external 特征(shape(n,28))、对应的时间序列(shape(n,)),其中6和2代表设定近中远的取样参数为3和1,再乘以输入输出两维度。测试集类似。


模型输入数据结构:训练集和测试集结构类似,只是样本个数不同。模型所需要的输入数据的维度为shape(4,n,m,32,32),其,4代表近中远+external ,n为样本个数,m对应的近中远参数值*2,4拆开分别对应是近(shape(n,6,32,32))、中(shape(n,2,32,32))、远(shape(n,2,32,32))、external 特征(shape(n,28))。


模型

三个角度出发,空间、时间、额外因素建模,以残差神经网络(DNN改进,可让网络层数变得更深)为模型基础。时间影响:时间轴上分段采样,近(时刻)、中(天)、远(周)。空间影响:在近中远三段数据上分别采用多层残差神经网络。额外因素:利用全连接神经网络建模。近、中、远三个模型输出经过加权融合后生成图中Xrex,并和Xext激活后生成X't。结合下图理解。


上面近中远三段数据分布残差神经网络建模,其模型结构类似。模型先是通过一层卷积,然后通过n层的残差(每个残差单元含有两个卷积层),最后再通过一层卷积。模型adam训练,学习率0.0002,近、中、远为3、1、1,每天时刻划分T为48(每个24/T生成一个流动数据图),代价函数和检验标准为rmse,模型训练停止方法有两种:一个是early-stopping 一个是设置最大迭代次数。卷积核为3*3,个数为64。


代码(建模及训练)

模型的代码实现是基于keras,keras的banckend可以选择tensorflow或者theano

代码执行:比如北京出租车,在目录scripts\papers\AAAI17\TaxiBJ中执行 python exptTaxiBJ.py 2,其中2代表残差神经网络层数。


exptTaxiBJ.py中,主要函数:def build_model(建模)def cache(生成数据缓存)def read_cache(读取缓存)def main(程序入口函数)。

mmn = pickle.load(open('preprocessing.pkl', 'rb'))获取归一化对象

model = stresnet(c_conf=c_conf...)具体建模函数

TaxiBJ.load_data(...)按照时间段采样原始数据

early_stopping = EarlyStopping最早停止条件定义

model_checkpoint = ModelCheckpoint模型训练条件

history = model.fit模型训练

model.save_weights保存模型参数

model.load_weights载入模型参数

score = model.evaluate模型评估

注:存在两次模型训练,两次的截止条件不一样


STResNet.py建模文件,def stresnet建模主函数,def _residual_unit残差单元构造,其中具体语法参照keras管网


TaxiBJ.py:北京出租车的载入数据函数文件,def load_holiday(读取节假日)def load_meteorol(读取气象)def load_data(读取位置数据)


STMatrix.py:按照设置的长中短时间参数采样,create_dataset主要的采样原数据函数,其中i代表当前要预测的数据,while循环读取每条i的长中短的数据(X)及当前i的数据(Y)


预测及结果呈现

HB_prediction.py(预测函数):模型训练会生产模型参数,保存在modle文件夹中。预测函数首先按照定义的模型建模,并加载训练好的参数,输入测试数据便可得到预测值y,此函数可预测未来一段时间段的y'。其中,predictNext()函数是用预测值作为输入预测下一刻的值。predicVal()函数是获取真实值和真实值作为输入的预测值。


HB_weightShow.py(参数分析函数),此函数是读取keras生成的参数文件,并将远中近的融合时的参数呈现,可以分别看出在远中近三个时间段不同栅格位置的权重影响。



本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Deep Spatio-Temporal Residual Networks(深度时空残差神经网络) 的相关文章

  • P2P原理以及如何实现(整理)

    前言 这几天看了p2p的原理以及实现的demo 整理一下 一共分为三部分 第一是概念原理 第二是demo实现 第三是p2p协议相关以及分类 一 概念原理 比较全面的理解 https zhuanlan zhihu com p 30351943
  • 地球坐标系 (WGS-84) 到火星坐标系 (GCJ-02) 的转换算法

    原文 WGS 84 到 GCJ 02 的转换 即 GPS 加偏 算法是一个普通青年轻易无法接触到的 公开 的秘密 这个算法的代码在互联网上是公开的 详情请使用 Google 搜索 wgtochina lb 整理后的算法代码请参考 https

随机推荐

  • 什么是外包公司?要不要去外包公司?

    01 什么是软件外包 软件外包分为 人力外包和项目外包两个方向 1 1 劳务派遣 指的是把员工外派到对应的用工企业打 短工 比如很多工程师虽然签约了中软国际 东软 文思海辉 软通动力 润和等软件公司 但实际工作地点是在华为 接受华为员相关负
  • c语言还有用吗?

    c语言还有用吗 这个问题有很多人在问 c语言真的没用吗 答案是有用的 用处还很大呢 这门语言虽然是很早以前发明的 新兴语言如c vb 功能十分强大 但每一个能代替C语言 原因 C
  • 【MySQL-约束篇】

    目录 1 空值 Null 2 默认值 3 主键 4 自增 5 唯一键 6 外键 1 空值 Null 先看一个表结构 Field Type Null Key Default Extra id int 11 YES NULL name
  • Java Long类型的查询结果与前端TypeScript显示不一致,后端传值与前端对不上,出现精度损失

    自己折腾了一个项目 使用的技术是SpringBoot MP Vben admin MySql 今天瞎搞的时候发现了一个让我很懵逼的问题 如下图所示 上方是浏览器打印出来的log 下方是数据库实际存在的数据 或者也可以说是后台接口断点调试的数
  • 电赛猜题?我觉得没用,还不如做好这些!

    01 前言 大家好 我是张巧龙 转眼又到22年电赛 这个公众号上有很多同学可能都参加过电赛 有毕业的已经工作的 也有没毕业的今年要参加 我第一次接触电赛是在大一暑期 从参加电赛到指导学生参加电赛 转眼快十年了 20年省赛有6个省一等奖 21
  • 2021-04-09

    jar 自动装配 springboot帮我们配置了什么 xxxxAutoConfiguration 向容器中自动装配组件 xxxxProperties 自动装配类 装配配置文件中自定义的一些内容 要 导入静态资源 首页 jsp 装配扩展Sp
  • ubuntu安装llvm教程

    安装必要工具 sudo apt get install build essential sudo apt get install cmake sudo apt get install python3 8 安装llvm wget https
  • RabbitMQ--基础--7.4--工作模式--路由模式(Direct)

    RabbitMQ 基础 7 4 工作模式 路由模式 Direct 代码位置 https gitee com DanShenGuiZu learnDemo tree master rabbitMq learn rabbitMq 03 1 介绍
  • GB/T28181-2022图像抓拍规范解读及技术实现

    规范解读 GB28181 2022相对2016 增加了设备软件升级 图像抓拍信令流程和协议接口 我们先回顾下规范说明 图像抓拍基本要求 源设备向目标设备发送图像抓拍配置命令 携带传输路径 会话ID等信息 目标设备完成图像传输后 发送图像抓拍
  • java基础-java的发展进程和特性

    1 JAVA的发展历程 1 1上世纪90年代 由于单片机出现引起了自动控制领域的关注 单片机可以大幅度提升电子消费产品的智能化程度 比如电视机顶盒 烤箱 移动电话等 Sun公司成立了Green的项目小组 专攻计算机在家电产品上的嵌入式开发
  • 总结45 SpringMVC框架的基本应用(替代Servlet)

    流程图 面试用 概念 当要实现在Spring框架下的web服务时 那么servlet将无法兼容 因为Spring无法依赖注入到Servlet 因此将会通过SpringMVC来替代servlet 从而提供WEB服务 也就是说 在今后的实际开发
  • 错误 1452:无法添加或更新子行:外键约束失败

    1452 Cannot add or update a child row a foreign key constraint fails goaread views CONSTRAINT views ibfk 1 FOREIGN KEY s
  • [Java多线程 八]---JUC包下的锁和工具类

    原文链接 http www cnblogs com skywang12345 p 3496098 html 概述 根据锁的添加到Java中的时间 Java中的锁 可以分为 同步锁 和 JUC包中的锁 同步锁 实现方式 即通过synchron
  • 解决Xshell显示中文乱码的问题

    执行echo LANG命令输出的是当前的编码方式 执行locale命令得到系统中所有可用的编码方式 要让Xshell不显示乱码 则要将编码方式改为UTF 8 在Xshell中 file gt open gt 在打开的session中选择连接
  • Python习题练习1--变量赋值交换

    题目 已知a的值是1 b的值是2 如何交换a b的值 打印a的值为2 b的值为1 这时候我们就可以思考了 是不是可以直接交换呢 在python中特有这种写法 可以看下下面解法 a 1 定义a的值为1 b 2 定义b的值为2 a b b a
  • SVM算法的参数

    1 c float参数 默认值为1 0 错误项的惩罚系数 c越大 即对分错样本的惩罚程度越大 因此在训练样本中准确率越高 但是泛化能力降低 也就是对测试数据的分类准确率降低 相反 减小c的话 允许训练样本中有一些误分类错误样本 泛化能力强
  • Linux嵌入式学习---C语言之赋值

    Linux嵌入式学习 C语言之赋值 一 语句的作用和分类 1 常见的9种控制语句 2 函数调用语句 3 表达式语句 4 空语句 5 复合语句 二 赋值语句 1 赋值运算符 2 复合的赋值运算符 3 变量赋初值 一 语句的作用和分类 1 常见
  • 彩光和灰光模块_Y5T265 【5G光模块】一个基站前传到底是用6个还是12个,或者是24模块...

    Y5T264 5G光模块 模式分配噪声 接着昨天继续分析5G前传光模块白皮书 问题1 在白皮书的第三页 前传光模块的需求 有不同的传输距离 这个我们能理解 有需要灰光或者彩光的 我们也能理解 但我们不能理解的是为什么前传模块的速率有10G
  • signature今日头条php实现,今日头条_signature 分析

    某天群友问了一句头条翻页算法 然后随手把算法摘出来 现在分享出来 window TAC console log userInfo id a t navigator window navigator userAgent function as
  • Deep Spatio-Temporal Residual Networks(深度时空残差神经网络)

    目录 业务场景 环境搭建 数据及目录结构 模型 代码 建模 训练 预测及结果呈现 文章只是对模型的学习与实践做简要记录 以免日后给忘了 并没有对模型优劣 应用的场景等理论方面有过多分析 适合快速动手搭建 成功运行 分析代码 并学习怎样用ke