torch.nn.Embedding是否有梯度,是否会被训练

2023-11-06

结论:

会被训练

测试代码:

import torch
from torch.nn import Embedding
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.emb = Embedding(5, 10)
    def forward(self,vec):
        input = torch.tensor([0, 1, 2, 3, 4])
        emb_vec1 = self.emb(input)
        print(emb_vec1)  ### 输出对同一组词汇的编码
        output = torch.einsum('ik, kj -> ij', emb_vec1, vec)
        return output
def simple_train():
    model = Model()
    vec = torch.randn((10, 1))
    label = torch.Tensor(5, 1).fill_(3)
    loss_fun = torch.nn.MSELoss()
    opt = torch.optim.SGD(model.parameters(), lr=0.015)
    for iter_num in range(100):
        output = model(vec)
        loss = loss_fun(output, label)
        print('iter:%d loss:%.2f' % (iter_num, loss))
        opt.zero_grad()
        loss.backward(retain_graph=True)
        opt.step()
if __name__ == '__main__':
    simple_train()

代码作用:

model中的参数就是一个embedding,前向传播总是编码同一个词汇表,然后乘上输入的向量。

经过训练之后embedding乘上向量可以得到全为3的向量。

优化器中的参数仅有embedding,所以embedding是会被训练的。

具体训练策略,不太清楚,估计就是对这个参与运算的embedding(tensor)进行梯度更新。

 

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

torch.nn.Embedding是否有梯度,是否会被训练 的相关文章

  • 尝试理解 Pytorch 的 LSTM 实现

    我有一个包含 1000 个示例的数据集 其中每个示例都有5特征 a b c d e 我想喂7LSTM 的示例 以便它预测第 8 天的特征 a 阅读 nn LSTM 的 Pytorchs 文档 我得出以下结论 input size 5 hid
  • 下载变压器模型以供离线使用

    我有一个训练有素的 Transformer NER 模型 我想在未连接到互联网的机器上使用它 加载此类模型时 当前会将缓存文件下载到 cache 文件夹 要离线加载并运行模型 需要将 cache 文件夹中的文件复制到离线机器上 然而 这些文
  • 为什么 pytorch matmul 在 cpu 和 gpu 上执行时得到不同的结果?

    我试图找出 numpy pytorch gpu cpu float16 float32 数字之间的舍入差异 而我发现的内容让我感到困惑 基本版本是 a torch rand 3 4 dtype torch float32 b torch r
  • Blenderbot 微调

    我一直在尝试微调 HuggingFace 的对话模型 Blendebot 我已经尝试过官方拥抱脸网站上给出的传统方法 该方法要求我们使用 trainer train 方法来完成此操作 我使用 compile 方法尝试了它 我尝试过使用 Py
  • pytorch 的 IDE 自动完成

    我正在使用 Visual Studio 代码 最近尝试了风筝 这两者似乎都没有 pytorch 的自动完成功能 这些工具可以吗 如果没有 有人可以推荐一个可以的编辑器吗 谢谢你 使用Pycharmhttps www jetbrains co
  • Pytorch“展开”等价于 Tensorflow [重复]

    这个问题在这里已经有答案了 假设我有大小为 50 50 的灰度图像 在本例中批量大小为 2 并且我使用 Pytorch Unfold 函数 如下所示 import numpy as np from torch import nn from
  • PyTorch 中的交叉熵

    交叉熵公式 但为什么下面给出loss 0 7437代替loss 0 since 1 log 1 0 import torch import torch nn as nn from torch autograd import Variable
  • 保存具有自定义前向功能的 Bert 模型并将其置于 Huggingface 上

    我创建了自己的 BertClassifier 模型 从预训练开始 然后添加由不同层组成的我自己的分类头 微调后 我想使用 model save pretrained 保存模型 但是当我打印它并从预训练上传时 我看不到我的分类器头 代码如下
  • 如何使用 pytorch 同时迭代两个数据加载器?

    我正在尝试实现一个接收两张图像的暹罗网络 我加载这些图像并创建两个单独的数据加载器 在我的循环中 我想同时遍历两个数据加载器 以便我可以在两个图像上训练网络 for i data in enumerate zip dataloaders1
  • 在Pytorch中计算欧几里得范数..理解和实现上的麻烦

    我见过另一个 StackOverflow 线程讨论计算欧几里德范数的各种实现 但我很难理解特定实现的原因 如何工作 该代码可以在 MMD 指标的实现中找到 https github com josipd torch two sample b
  • Pytorch 与 joblib 的 autograd 问题

    将 pytorch 的 autograd 与 joblib 混合似乎存在问题 我需要并行获取大量样本的梯度 Joblib 与 pytorch 的其他方面配合良好 但是 与 autograd 混合时会出现错误 我做了一个非常小的例子 显示串行
  • ValueError:使用火炬张量时需要解压的值太多

    对于神经网络项目 我使用 Pytorch 并使用 EMNIST 数据集 已经给出的代码加载到数据集中 train dataset dsets MNIST root data train True transform transforms T
  • 如何将 35 类城市景观数据集转换为 19 类?

    以下是我的代码的一小段 使用它 我可以在城市景观数据集上训练名为 lolnet 的模型 但数据集包含 35 个类别 标签 0 34 imports trainloader torch utils data DataLoader datase
  • 如何以干净高效的方式在 pytorch 中获得小批量?

    我试图做一件简单的事情 即使用火炬通过随机梯度下降 SGD 训练线性模型 import numpy as np import torch from torch autograd import Variable import pdb def
  • 当前向包含多个自动分级节点时,PyTorch 关于使用非完整后向挂钩的警告

    最近升级后 当运行 PyTorch 循环时 我现在收到警告 当前向包含多个自动分级节点时使用非完整后向钩子 训练仍在运行并完成 但我不确定应该将其放置在哪里register full backward hook功能 我尝试将它添加到神经网络
  • 如何解决错误:PyTorch 中预期输入批量大小与目标批量大小不匹配?

    我尝试通过 PyTorch 在 CIFAR10 数据集上创建逻辑模型 但是我收到错误 ValueError 预期输入batch size 900 与目标batch size 300 匹配 我认为正在发生的事情是 3 100 是 300 所以
  • PyTorch 如何计算二阶雅可比行列式?

    我有一个正在计算向量的神经网络u 我想计算关于输入的一阶和二阶雅可比矩阵x 单个元素 有人知道如何在 PyTorch 中做到这一点吗 下面是我项目中的代码片段 import torch import torch nn as nn class
  • 无法在jupyter笔记本中导入torch

    系统 macOS 10 13 6 蟒蛇 3 7 蟒蛇3 我遇到麻烦时import torch在 jupyter 笔记本中 ModuleNotFoundError No module named torch 这是我安装 pytorch 的方法
  • 设置 torch.gather(...) 调用的结果

    我有一个形状为 n x m 的 2D pytorch 张量 我想使用索引列表来索引第二个维度 可以使用 torch gather 完成 然后然后还设置新值到索引的结果 Example data torch tensor 0 1 2 3 4
  • 从 torch.autograd.gradcheck 导入 zero_gradients

    我想复制代码here https github com LTS4 DeepFool blob master Python deepfool py 并且我在 Google Colab 中运行时收到以下错误 ImportError 无法导入名称

随机推荐

  • Ubuntu 使用笔记

    更新时间 2020 3 24 文章目录 一 安装 二 快捷键 三 常用命令 3 1 软件安装 3 2 程序编写 四 软件使用 1 终端使用 2 Vim编辑器 3 Linuxqq 4 基于wine 的软件下载 5 CAJViewer使用 6
  • BEVDet:High-Performance Multi-Camera 3D Object Detection in Bird-Eye-View 论文笔记

    原文链接 https arxiv org pdf 2112 11790 pdf 1 引言 如下图所示 本文提出的BEVDet包含4个部分 即图像编码器 提取图像特征 视图转换器 将图像视图转化为BEV BEV编码器 进一步提取BEV特征 和
  • MATLAB算法实战应用案例精讲-【人工智能】语义分割(最终篇)(附实战应用案例及代码实现)

    目录 前言 语义分割面临的问题及解决方法 几个高频面试题目 什么是图像中的语义信息
  • MySQL8.0 :grant all privileges on *.* to 报错问题

    报了这个错误 ERROR 1064 42000 You have an error in your SQL syntax check the manual that corresponds to your MySQL server vers
  • 阿里是如何做Code Review的?

    作为卓越工程文化的一部分 Code Review其实一直在进行中 只是各团队根据自身情况张驰有度 松紧可能也不一 这里简单梳理一下CR的方法和团队实践 一 为什么要CR 提前发现缺陷 在CodeReview阶段发现的逻辑错误 业务理解偏差
  • 通俗易懂的java设计模式(7)-原型模式

    1 什么是原型模式 原型模式提供了一种创建对象的模式 它是指用原型实例创建对象的种类 并且通过拷贝这些原型 创建新的对象 用一个很生动形象的例子 孙悟空拔出一根猴毛 变出其他和自己一模一样的小孙悟空 在这里 原型实例就是孙悟空 拔出的猴毛通
  • PCL_BoundaryEstimation边界提取

    pcl BoundaryEstimation用于散乱点云的边界提取 但应该只适用于简单的点云 过于复杂的话效果应该不太好 同时 需要pcl NormalEstimation先计算法线 算起来也挺慢的 https download csdn
  • 军用软件国家标准

    GB T 11457 2006 信息技术 软件工程术语 SJ 20778 2000 软件开发与文档标准 GJB 2786A 2009 军用软件开发通用要求 GJB 438B 2009 军用软件开发文档通用要求 GJB 4072A 2006
  • 公司职员薪水管理系统(List)

    集合初步完成下面的功能需求 做公司职员薪水管理系统 完成以下功能 1 当有新员工时 将加入该管理系统 2 根据员工号 显示该员工信息 3 可以显示所有员工的信息 4 可以修改员工的薪水 5 当员工离职时 从该系统中删除该员工 6 可以将员工
  • Go 性能

    写性能测试在Go语言中是很便捷的 go自带的标准工具链就有完善的支持 下面我们来从Go的内部和系统调用方面来详细剖析一下Benchmark这块 Benchmark Go做Benchmark只要在目录下创建一个 test go后缀的文件 然后
  • Java项目毕业设计:电脑城销售商城网站(java+springboot+vue+mysql)

    运行环境 开发工具 IDEA Eclipse 数据库 MYSQL5 7 应用服务 Tomcat7 Tomcat8 使用框架springboot vue 项目介绍 随着科技的发展 人们对电子产品的依赖越来越严重 尤其是像电脑和手机这些日常生活
  • 前端获取数据常见的几种方法

    1 原生获取ajax
  • requset-使用BeanUtils封装表单提交的数据到javaBean对象中

    request对象请求参数过多 可以将数据封装到对象 使用BeanUtils解决这个问题 设置一个登录页面准备提交表单数据 username password 导入BeanUtils相关jar包 创建Servlet获取请求参数 调用Bean
  • 84.柱状图中最大的矩形

    84 柱状图中最大的矩形 给定 n 个非负整数 用来表示柱状图中各个柱子的高度 每个柱子彼此相邻 且宽度为 1 求在该柱状图中 能够勾勒出来的矩形的最大面积 示例 1 输入 heights 2 1 5 6 2 3 输出 10 解释 最大的矩
  • vim 删除一整块,vim 删除一整行

    陈永鹏的微博 陈永鹏的csdn博客地址 http blog csdn net chenyoper陈永鹏的博客园地址 http www cnblogs com Yoperchen dd 删除游标所在的一整行 常用 ndd n为数字 删除光标所
  • JavaScript判断数组是否为空、 判断数据类型

    数组 let arr 在进行if 判断数组时 在new Array 一个空数组时 是一个Object对象 所以if arr 时是true 在进行数组直接与true和false的布尔类型比较时 默认是将数组和布尔类型都转化为了Number类型
  • Vtk多个actor绑定选中事件

    Vtk多个actor绑定选中事件 1 交互只有 放大 移动 沿着z轴旋转 2 增加选中回调 3 增加部分模型隐藏 效果 项目地址 在官方案例基础上改的 案例 https kitware github io vtk examples site
  • 后端系统开发之工作和面试中的gdb

    gdb是C C 程序员必备的专业技能 工作中gdb最常用的场景有两个 一个是分析core文件 另一个是调试程序 分析core文件的方法如下 1 gdb 程序名 core文件名 2 bt或where命令查看堆栈信息 3 进入某个栈 f N f
  • Ubuntu + CUDA9.0 + tensorflow-gpu 安装过程

    Ubuntu CUDA9 0 tensorflow gpu 安装过程 简介 tensorflow支持CUDA9 0和cuDNN7 0 因此本教程是在该版本基础上进行安装的 我的电脑CPU是Intel core i7 4710MQ GPU是G
  • torch.nn.Embedding是否有梯度,是否会被训练

    结论 会被训练 测试代码 import torch from torch nn import Embedding class Model torch nn Module def init self super Model self init