神经网络 Embedding层理解; Embedding层中使用预训练词向量

2023-11-10

1、Embedding层理解

高维稀疏特征向量到低维稠密特征向量的转换;嵌入层将正整数(下标)转换为具有固定大小的向量;把一个one hot向量变为一个稠密向量

参考:https://zhuanlan.zhihu.com/p/52787964

Embedding 字面理解是 “嵌入”,实质是一种映射,从语义空间到向量空间的映射,同时尽可能在向量空间保持原样本在语义空间的关系,如语义接近的两个词汇在向量空间中的位置也比较接近。
在这里插入图片描述

应用:

在深度学习推荐系统中,Embedding主要的三个应用方向:

1、在深度学习网络中作为Embedding层,完成从高维稀疏特征向量到低维稠密特征向量的转换;

2、作为预训练的Embedding特征向量,与其他特征向量连接后一同输入深度学习网络进行训练;

3、通过计算用户和物品的Embedding相似度,Embedding可以直接作为推荐系统或计算广告系统的召回层或者召回方法之一。

代码简单说明

keras Embedding接口:
参考:https://keras.io/zh/layers/embeddings/
input_dim 一般大于等于词表数
在这里插入图片描述
Embedding(4, 3, input_length=5) :
1)4是词表去重后的总数量,例如这里np.array([[0,1,2,1,1],[0,1,2,1,3]])共0、1、2、3四个数字,所以词表数为4
2)3是输出单个词向量维度
3)5是input_length输入的长度np.array([[0,1,2,1,1],[0,1,2,1,3]])中[0,1,2,1,1]长度为5

from keras.models import Sequential
from keras.layers import Flatten, Dense, Embedding 
import numpy as np

model = Sequential()
model.add(Embedding(4, 3, input_length=5))  

model.compile('rmsprop', 'mse')
data = np.array([[0,1,2,1,1],[0,1,2,1,3]])

res1 = model.predict(data)
res1
print(model.input_shape)
print(model.output_shape)
'''
(None, 5) #其中 None的取值是batch_size
(None, 5, 3)

input_shape:函数输入,尺寸为(batch_size, 5)的2D张量(矩阵的意思)
output_shape:函数输出,尺寸为(batch_size, 5,3)的3D张量
'''

特别说明:

1)Embedding(4, 3, input_length=5) 会随机初始化一个词表大小的4*3维矩阵
2)data = np.array([[0,1,2,1,1],[0,1,2,1,3]])里的0、1、2、3、4获取上面词表里对应取数,比如0会去取上面标记数字0行,其他类似;如果有小数点会直接取整2.3、2.9都是取2然后去取上面标记数字2行
3)这些随机初始化的Embedding作为神经网络输入会随着网络的训练而变化

在这里插入图片描述

再深入理解

参考:https://spaces.ac.cn/archives/4122
字向量就是one hot的全连接层的参数;下2图红线的参数就是对应的参数权重
在这里插入图片描述
在这里插入图片描述

2、 Embedding层中使用 预训练词向量

weights=[embeddings_matrix], # 重点:预训练的词向量系数
trainable=False # 是否在 训练的过程中 更新词向量

from keras.layers import Embedding

EMBEDDING_DIM = 100 #词向量维度

embedding_layer = Embedding(input_dim = len(embeddings_matrix), # 字典长度
                            EMBEDDING_DIM, # 词向量 长度(100)
                            weights=[embeddings_matrix], # 重点:预训练的词向量系数
                            input_length=MAX_SEQUENCE_LENGTH, # 每句话的 最大长度(必须padding) 
                            trainable=False # 是否在 训练的过程中 更新词向量
                           
                            )

其实就是给赋值下one hot全连接边的权重就是字向量的embedding;后续还是对应查表取值,后续训练也可以自定义权重是否被训练更新

在这里插入图片描述

pytorch Embedding 预训练向量加载

import torch

weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
embedding = torch.nn.Embedding.from_pretrained(weight)
 # Get embeddings for index 1
input = torch.LongTensor([1])  ##表示取表第二个,即结果[4, 5.1, 6.3]
embedding(input)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

神经网络 Embedding层理解; Embedding层中使用预训练词向量 的相关文章

  • Python(1)--Python安装

    本篇作为学习Python笔记 来记录学习过程 安装环境 windows10 官方下载地址 https www python org 有很多的版本 我这里选择了3 7 2 executable表示可执行版 需要安装后使用 embeddable
  • Python基础 NumPy数组相关概念及操作

    NumPy是Python的一种开源的数值计算扩展库 提供 数组支持以及相应的高效处理函数 它包含很多功能 如创建n维数组 矩阵 对数组进行函数运算 数值积分 线性代数计算 傅里叶变换和随机数产生等 Why NumPy 标准的Python用L
  • CentOS8基础篇2:文件系统

    一 文件系统概述 1 文件系统的基本概念 操作系统中负责管理和存储文件信息的软件机构称为文件管理系统 简称文件系统 它规定了文件的存储方式及文件索引方式等信息 文件系统主要由三部分组成 分别是与文件管理相关的软件 被管理的文件和实施文件管理
  • 神经网络中的神经元和激活函数详解

    在上一节 我们通过两个浅显易懂的例子表明 人工智能的根本目标就是在不同的数据集中找到他们的边界 依靠这条边界线 当有新的数据点到来时 只要判断这个点与边界线的相互位置就可以判断新数据点的归属 上一节我们举得例子中 数据集可以使用一条直线区分
  • jdk13快来了,jdk8的这几点应该看看!

    说明 jdk8虽然出现很久了 但是可能我们还是有很多人并不太熟悉 本文主要就是介绍说明一些jdk8相关的内容 主要会讲解 lambda表达式 方法引用 默认方法 Stream 用Optional取代null 新的日志和时间 Completa
  • 自定义view

    自定义View 有这一篇就够了 简书 jianshu com
  • STM32cubeProgrammer连接设置说明

    芯片型号 STM32F427 连接 connect Frequency设置为200 点击connection REG模块 随后device选STM32F427 peripheral选择GPIOD

随机推荐

  • android应用安装成功之后删除apk文件

    摘要 题目 正在运用开辟中碰到须要如许的需供 正在用户下载我们的运用装置以后删除装置包 办理 android会正在每一个中界操纵APK的举措以后收回体系级其余播送 过滤器称号 问题 在应用开发中遇到需要这样的需求 在用户下载我们的应用安装之
  • C语言学前班

    C 语言学前班 10分钟入门 10天练习 哪有那么难 根本用不着科班通过上课学几个月 程序 数据结构 算法 数据结构 容器来存储要进行各种操作的数据 算法 对各种数据进行各种操作 加减乘除 增删改查 判 判断 排 排序 复 复位 输出结果来
  • Some Android licenses not accepted. To resolve this, run: flutter doctor --android-licenses 解决方法

    mopondys iMac zyc flutter doctor Doctor summary to see all details run flutter doctor v Flutter Channel dev v1 16 2 on M
  • NVisionXR for ARCore内测版开放申请

    NVisionXR for ARCore引擎能够帮助开发者快速开发原生ARCore应用 只要你懂基本的Android开发 直接使用Android Studio 即可实现动画模型渲染 粒子特效 音视频播放 灯光渲染等功能 NVisionXR引
  • java线程池的使用

    线程池概述 线程池 Thread Pool 是一种基于池化思想管理线程的工具 使用线程池可以带来诸多好处 降低资源消耗 通过池化技术复用已创建的线程 减少线程创建和销毁的损耗 提高响应速度 任务到达时 特定情况下无需再创建线程 便于管理 j
  • hangfire+bootstrap ace 模板实现后台任务管理平台

    前言 前端时间刚开始接触Hangfire就翻译了一篇官方的教程 翻译 山寨 Hangfire Highlighter Tutorial 后来在工作中需要实现一个异步和定时执行的任务管理平台 就结合bootstrap ace模板和hangfi
  • echarts中多y轴图像(柱,折)

    先看看效果吧 var myChart echarts init document getElementById demo echarts zyyh 放入的id var colors e6bcff a3ffcd fefefe option c
  • C++之explicit的作用介绍

    1 C 中的关键字explicit主要是用来修饰类的构造函数 被修饰的构造函数的类 不能发生相应的隐式类型转换 只能以显示的方式进行类型转换 类构造函数默认情况下声明为隐式的即implicit 隐式转换即是可以由单个实参来调用的构造函数定义
  • 147. 精读《@types react 值得注意的 TS 技巧》

    1 引言 从 types react 源码中挖掘一些 Typescript 使用技巧吧 2 精读 泛型 extends 泛型可以指代可能的参数类型 但指代任意类型范围太模糊 当我们需要对参数类型加以限制 或者确定只处理某种类型参数时 就可以
  • 2022年江西省中职组“网络空间安全”赛项模块B-Web渗透测试

    2022年中职组山西省 网络空间安全 赛项 B 8 Web渗透测试任务书 B 8 Web渗透测试解析 不懂可以私信博主 一 竞赛时间 420分钟 共计7小时 吃饭一小时 二 竞赛阶段 竞赛阶段 任务阶段 竞赛任务 竞赛时间 分值 第 阶段
  • 【MySQL】数据库基本操作:创建删除数据库(Create/Drop),表增删改查

    数据库基本操作 1 启动服务 DOS命令 net start mysql 回车 2 登录MySQL数据库 mysql uroot proot 回车 3 查看MySQL中数据库 show databases 4 创建数据库 create da
  • 2023备战金三银四,Python自动化软件测试面试宝典合集(八)

    马上就又到了程序员们躁动不安 蠢蠢欲动的季节 这不 金三银四已然到了家门口 元宵节一过后台就有不少人问我 现在外边大厂面试都问啥 想去大厂又怕面试挂 面试应该怎么准备 测试开发前景如何 面试 一个程序员成长之路永恒绕不过的话题 每每到这个时
  • GAN生成MNIST数据-PyTorch

    摘抄别处 供自己学习用 直接上代码 代码如下 coding utf 8 import torch autograd import torch nn as nn from torch autograd import Variable from
  • ssm框架整合的配置笔记

    ssm框架整合的配置笔记 打开idea工具新建maven maven环境配置 项目后面下一步下一步就行了 整个的项目Java代码我就不发了 主要是帮助大家快速的搭下配置文件基本信息方便快速复制使用 2 1 pop xml中导入依赖 juni
  • Android通过webservice连接SQLServer 详细教程(数据库+服务器+客户端)

    本文为原创 如果转载请注明出处 http blog csdn net zhyl8157121 article details 8169172 其实之前发过一篇这样的博文http blog csdn net zhyl8157121 artic
  • MongoDB 或者 redis 可以替代 memcached 吗?

    mongodb和memcached不是一个范畴内的东西 mongodb是文档型的非关系型数据库 其优势在于查询功能比较强大 能存储海量数据 mongodb和memcached不存在谁替换谁的问题 和memcached更为接近的是redis
  • 计算机网络思维导图

    转载 原文 http www jingyile cn 496 2 复习计算机网络时画的一些思维导图 希望可以加深自己的理解 教材 计算机网络 第7版 谢希仁编著 第一章 概述 P0 计算机网络 lt 思维导图 gt 第二章 物理层 P1 计
  • TortoiseGit保存git的账号密码

    TortoiseGit保存git的账号密码 问题 电脑安装了git和TortoiseGit 但是每次commit pull push时都需要输入密码 而且是无限弹密码框 输入密码之后 还继续弹框 之前看了许多解决方案都不太行 例如 1 修改
  • pyecharts与matplotlib在使用时的区别和优缺点对比

    目录 简介 pyecharts 以下简介来源于官网 Matplotlib 个人总结 认知 举例 简介 pyecharts 以下简介来源于官网 echarts是一个由百度开源的数据可视化 凭借着良好的交互性 精巧的图表设计 得到了众多开发者的
  • 神经网络 Embedding层理解; Embedding层中使用预训练词向量

    1 Embedding层理解 高维稀疏特征向量到低维稠密特征向量的转换 嵌入层将正整数 下标 转换为具有固定大小的向量 把一个one hot向量变为一个稠密向量 参考 https zhuanlan zhihu com p 52787964