睿智的seq2seq模型1——利用seq2seq模型对数字进行排列

2023-11-17

学习前言

快乐学习新知识,seq2seq还是很重要的!
在这里插入图片描述

seq2seq简要介绍

seq2seq属于encoder-decoder结构的一种。

seq2seq的encoder常用循环神经网络,可以使用LSTM或者RNN,当输入一个字符串的时候,可以对其进行特征提取,获得语义编码C。

而decoder则将encoder得到的编码C作为输入,输入到decoder的RNN中,得到输出序列。

语义编码C是输入内容的特征集合体,decoder可以讲这个特征集合体解码成输出序列。
在这里插入图片描述

利用seq2seq实现数组排序

实现的目标如下:
在这里插入图片描述
当输入一列无序的数列后,可以对其进行排序。

实现方式

一、对输入格式输出格式进行定义

将输入的1、2、3的格式转化为如下格式,第几位为1代表这个数字是几:
由于数组的序列号是从0开始的,所以:
1转化为[0,1,0,0];
2转化为[0,0,1,0];
3转化为[0,0,0,1]。
这样的格式转换更有利于代码处理。
在这里插入图片描述

二、建立神经网络

1、神经网络的输入

假设输入的数字组合是2、1、3。
输入的顺序就是循环神经网络的时间顺序STEP,也就是按时间序列输入2、1、3。
本例子中输入的维度为4。

此时输入的数组为:

[[0,0,1,0],
 [0,1,0,0],
 [0,0,0,1]]

通过如下代码就可以完成输入编码神经网络的构建。

inputs = Input([STEP_SIZE,INPUT_SIZE])
x = LSTM(CELL_SIZE, input_shape=(STEP_SIZE, INPUT_SIZE))(inputs)
x = Dropout(0.25)(x)

示意图如下:
在这里插入图片描述

2、语义编码c的处理

上一步获得的LSTM的输出为一个CELL_SIZE维的输出向量,是一个浓缩的特征。

由于我们的decoder也是一个LSTM模型,输入就是这个浓缩的特征,而且我们需要对STEP_SIZE个输出进行排序。所以decoder的输入序列的时间SIZE也是STEP_SIZE大小。

我们可以利用RepeatVector函数讲encoder的输出的语义编码c重复STEP_SIZE次。
实现CELL_SIZE维向量 -> (STEP_SIZE, CELL_SIZE)维向量的转换
实现代码为:

# CELL_SIZE-> STEP_SIZE,CELL_SIZE
# 分为STEP_SIZE个时间序列
x = RepeatVector(STEP_SIZE)(x)

示意图如下:
在这里插入图片描述

3、输出神经网络

输出神经网络的输入是上一步获得的(STEP_SIZE, CELL_SIZE)维向量。

输出也是(STEP_SIZE, CELL_SIZE)维向量。

因此输出神经网络由两部分组成,第一部分是普通的LSTM网络,其会输出(STEP_SIZE, CELL_SIZE)维度的向量。

第二部分是TimeDistributed+Dense,TimeDistributed+Dense意味着对LSTM网络的每一个STEP的输出进行全连接,会输出(STEP_SIZE, INPUT_SIZE )维度的向量。
实现代码、执行思路如下:

# STEP_SIZE, CELL_SIZE -> STEP_SIZE, NERVE_NUM
# 当return_sequences=True时,会输出时间序列
# STEP_SIZE代表时间序列,CELL_SIZE代表每一个时间序列的输出
x = LSTM(CELL_SIZE, return_sequences=True)(x)

# 对每一个STEP进行全连接
x = TimeDistributed(Dense(INPUT_SIZE))(x)
x = Dropout(0.5)(x)
x = Activation('softmax')(x)

示意图如下:
在这里插入图片描述

4、网络构建部分全部代码

全部示意图:
在这里插入图片描述

inputs = Input([STEP_SIZE,INPUT_SIZE])
x = LSTM(CELL_SIZE, input_shape=(STEP_SIZE, INPUT_SIZE))(inputs)
x = Dropout(0.25)(x)
# CELL_SIZE-> STEP_SIZE,CELL_SIZE
# 分为STEP_SIZE个时间序列
x = RepeatVector(STEP_SIZE)(x)

# STEP_SIZE, CELL_SIZE -> STEP_SIZE, NERVE_NUM
# 当return_sequences=True时,会输出时间序列
# STEP_SIZE代表时间序列,CELL_SIZE代表每一个时间序列的输出
x = LSTM(CELL_SIZE, return_sequences=True)(x)

# 对每一个STEP进行全连接
x = TimeDistributed(Dense(INPUT_SIZE))(x)
x = Dropout(0.5)(x)
x = Activation('softmax')(x)

model = Model(inputs,x)

全部实现代码

实现代码为:

from keras.models import Sequential
from keras.layers.core import Activation, RepeatVector, Dropout, Dense
from keras.layers import TimeDistributed,Input
from keras.models import Model
from keras.layers.recurrent import LSTM
import numpy as np


def encode(X, seq_len, vocab_size):
    x = np.zeros((len(X), seq_len, vocab_size), dtype=np.float32)
    for ind, batch in enumerate(X):
        for j, elem in enumerate(batch):
            x[ind, j, elem] = 1
    return x


def batch_gen(batch_size=32, seq_len=10, max_no=100):
    while True:
        x = np.zeros((batch_size, seq_len, max_no), dtype=np.float32)
        y = np.zeros((batch_size, seq_len, max_no), dtype=np.float32)

        X = np.random.randint(max_no, size=(batch_size, seq_len))
        Y = np.sort(X, axis=1)

        for ind, batch in enumerate(X):
            for j, elem in enumerate(batch):
                x[ind, j, elem] = 1

        for ind, batch in enumerate(Y):
            for j, elem in enumerate(batch):
                y[ind, j, elem] = 1
        yield x, y


batch_size = 64
STEP_SIZE = 10
INPUT_SIZE = 75
CELL_SIZE = 100
inputs = Input([STEP_SIZE,INPUT_SIZE])
x = LSTM(CELL_SIZE, input_shape=(STEP_SIZE, INPUT_SIZE))(inputs)
x = Dropout(0.25)(x)
# CELL_SIZE-> STEP_SIZE,CELL_SIZE
# 分为STEP_SIZE个时间序列
x = RepeatVector(STEP_SIZE)(x)

# STEP_SIZE, CELL_SIZE -> STEP_SIZE, NERVE_NUM
# 当return_sequences=True时,会输出时间序列
# STEP_SIZE代表时间序列,CELL_SIZE代表每一个时间序列的输出
x = LSTM(CELL_SIZE, return_sequences=True)(x)

# 对每一个STEP进行全连接
x = TimeDistributed(Dense(INPUT_SIZE))(x)
x = Dropout(0.5)(x)
x = Activation('softmax')(x)

model = Model(inputs,x)
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics=['accuracy'])

for ind, (X, Y) in enumerate(batch_gen(batch_size, STEP_SIZE, INPUT_SIZE)):
    loss, acc = model.train_on_batch(X, Y)
    if ind % 250 == 0:
        print("ind:",ind)
        testX = np.random.randint(INPUT_SIZE, size=(1, STEP_SIZE))
        test = encode(testX, STEP_SIZE, INPUT_SIZE)
        print("before is")
        print(testX)
        y = model.predict(test, batch_size=1)
        print("actual sorted output is")
        print(np.sort(testX))
        print("sorting done by LSTM is")
        print(np.argmax(y, axis=2))
        print("\n")

效果为:

ind: 0
before is
[[68 27 33 21 14  8 26 36 45 18]]
actual sorted output is
[[ 8 14 18 21 26 27 33 36 45 68]]
sorting done by LSTM is
[[55 55 55 55 55 55 55 52 52 52]]


ind: 250
before is
[[17 63 41 63 49 64 64 44 16 25]]
actual sorted output is
[[16 17 25 41 44 49 63 63 64 64]]
sorting done by LSTM is
[[ 7 16 32 44 49 49 49 63 64 69]]


ind: 500
before is
[[ 4  7 18 12  9  5 25 48 49 46]]
actual sorted output is
[[ 4  5  7  9 12 18 25 46 48 49]]
sorting done by LSTM is
[[ 0  2  7  7 18 18 25 48 48 58]]

……

ind: 3500
before is
[[36 19 72 45 34 67 39 38 60  3]]
actual sorted output is
[[ 3 19 34 36 38 39 45 60 67 72]]
sorting done by LSTM is
[[ 3 19 34 36 38 39 45 60 67 72]]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

睿智的seq2seq模型1——利用seq2seq模型对数字进行排列 的相关文章

  • Python3, 多种方法实现文件/目录的监听,只想说一个字:泰裤辣。

    多种方法实现文件 目录监听 1 引言 2 代码实战 2 1 os模块 2 2 watchdog库 2 2 1 安装 2 2 2 示例 2 3 inotify 2 3 1 安装 2 3 2 示例 3 总结 1 引言 小屌丝 鱼哥 帮我看下这段
  • 说透 Nacos 一致性协议

    1 Nacos 致性协议 1 1 为什么 Nacos 需要 致性协议 Nacos尽可能减少用户部署以及运维成本 做到用户只需要 个程序包 就快速单机模式启动 Nacos 或集群模式启动 Nacos 而 Nacos 是 个需要存储数据的组件
  • java基础—HashMap实现原理,如何保证HashMap的线程安全

    在多线程条件下 容易导致死循环 具体表现为CPU使用率100 因此多线程环境下保证 HashMap 的线程安全性 主要有如下几种方法 1 替换成Hashtable Hashtable通过对整个表上锁实现线程安全 因此效率比较低 2 使用Co
  • 台式计算机的配置怎么看,台式电脑配置怎么看

    电脑的性能 价格决定于电脑的配置 很多人电脑新手在购买电脑的时候对电脑配置的相关情况不太了解 导致新买的电脑频频出问题 所以了解自己电脑配置是很重要的 这里我们就简单的来说说台式电脑配置怎么看 电脑配置一般CPU 显卡 主板 内存 硬盘 显

随机推荐

  • lambda表达式二之Stream流

    Stream流 是数据渠道 用于操作数据源 集合 数组等 所生成的元素序列 集合讲的是数据 流讲的是计算 Stream自己不会存储元素 Stream不会改变源对象 会返回一个持有结果的新Stream Stream操作是延迟执行的 意味着会等
  • LeetCode312. 戳气球 (分治,记忆化搜索,动态规划)

    LeetCode312 戳气球 解题思路 记忆化搜索 动态规划 解题思路 官方题解 参考题解 核心思想 由于戳气球的操作会导致两个气球从不相邻变成相邻 使得后续操作难以处理 于是我们倒过来看这些操作 将全过程看成每次添加一个气球 solve
  • CMake入门实践(一) 什么是cmake

    一 CMake简介 CMake是一个跨平台的安装 编译 工具 可以用简单的语句来描述所有平台的安装 编译过程 他能够输出各种各样的makefile或者project文件 能测试编译器所支持的C 特性 类似UNIX下的automake 只是
  • mac AE 快捷键

    项目窗口 新项目 Ctrl Alt N 打开项目 Ctrl O 打开项目时只打开项目窗口 按住Shift键 打开上次打开的项目 Ctrl Alt Shift P 保存项目 Ctrl S 选择上一子项 上箭头 选择下一子项 下箭头 打开选择的
  • Flink + Hudi 实现多流拼接(大宽表)

    1 背景 经典场景 Flink 侧实现 业务侧通常会基于实时计算引擎在流上做多个数据源的 JOIN 产出这个宽表 但这种解决方案在实践中面临较多挑战 主要可分为以下两种情况 维表 JOIN 场景挑战 指标数据与维度数据进行关联 其中维度数据
  • .net 配置网关(使用Ocelot)

    本文演示一个最简单的demo 来模拟如何通过网关来访问服务 而不是直接访问服务 创建三个asp net core web api项目 一个作为网关 两个作为服务 分别配置项目的访问路径 网关的项目使用https localhost 5001
  • MQTT-java使用说明

    MQTT java使用说明 本文的资料下载 链接 https pan baidu com s 1OCfsQ NqcehKy86kYkA wg pwd 1234 提取码 1234 MQTT基本介绍 MQTT是一个客户端服务端架构的发布 订阅模
  • DNS在架构设计中的巧用

    DNS在架构设计中的巧用 一 缘起 一个http请求从客户端到服务端 整个执行流程是怎么样的呢 一个典型流程如上 1 客户端通过域名daojia com请求dns server 2 dns server返回域名对应的外网ip 1 2 3 4
  • python拟合二次函数_Python 最小二乘法 拟合 二次曲线

    最小二乘 Python 二次拟合 随机生成数据 并且加上噪声干扰 构造需要拟合的函数形式 使用最小二乘法进行拟合 输出拟合后的参数 将拟合后的函数与原始数据绘图后进行对比 import numpy as np import matplotl
  • 讯飞aiui的webapi+python使用记录

    1 demo一直不能出语义理解 我以为是我的问题 直到 当前页面配置修改仅在测试环境生效 设备端体验需要SDK传参时在情景模式后加 box 或 更新发布 至生产环境体验 这不坑爹吗 记得在情景模式后加 box
  • BFS的常见算法题-二叉树的最小深度

    背景 对某个二叉树 我们除了用肉眼可以看出其深度 还可以用算法来计算出它的深度 比如 下面的二叉树 一共有三层 它的深度就是3 如果某个分支的叶子结点没有左右子节点 就是它深度中较小的一个 leetcode中 有一题求最小深度 如下图 最小
  • 各种日志关系

    slf4j是日志的门面 也是会说是日志框架
  • 【Unity开发】Unity获取设备屏幕分辨率

    using UnityEngine using System Collections public class ExampleClass MonoBehaviour void Start Resolution resolutions Scr
  • Vscode ssh远程连接失败解决办法

    问题描述 Vscode 通过remote ssh连接远程ubuntu时出现 192 168 x x has fingerprint SHA256 如下图所示 按照提示选择 continue 然后输入正确密码却显示Permission Den
  • java md5 解密_“实用”的JAVA开发工具类库

    简介 Hutool是一个小而全的Java工具类库 通过静态方法封装 降低相关API的学习成本 提高工作效率 使Java拥有函数式语言般的优雅 让Java语言也可以 甜甜的 Hutool中的工具方法来自于每个用户的精雕细琢 它涵盖了Java开
  • 免费的 AI 代码辅助工具-codeium

    不是标题党 是真免费 几天之前 GitHub 发布了 GitHub Copilot X 这是一款基于 OpenAI 的 GPT 4 模型开发的 AI 代码辅助工具 看介绍应该是和 Microsoft 365 Copilot 很像的产物 属于
  • ChatGLM-6B部署笔记

    前言 本笔记基于ChatGLM 6B开源网站 https github com THUDM ChatGLM 6B 完成ChatGLM的本地部署 首先电脑已经安装python3 10 anaconda pycharm2022 3 如若使用本地
  • Application.targetFrameRate安卓apk上设置帧率问题

    一般游戏为了更好的适配各种机型 会对游戏进行锁帧 就会使用Application targetFrameRate这个方法设置帧率 pc上测试是没问题的 但是安卓机上面测试就会发现 设置的帧率只能在30和60帧两个数值来回跳动 参考了unit
  • 21-angular.merge

    通过从src对象 s 复制自己的可枚举属性到dst 深度扩展了目标对象的dst 您可以指定多个src对象 如果您想保留原始对象 那么可以通过将空对象作为目标来实现 var object angular merge object1 objec
  • 睿智的seq2seq模型1——利用seq2seq模型对数字进行排列

    睿智的seq2seq模型1 利用seq2seq模型对数字进行排列 学习前言 seq2seq简要介绍 利用seq2seq实现数组排序 实现方式 一 对输入格式输出格式进行定义 二 建立神经网络 1 神经网络的输入 2 语义编码c的处理 3 输