Pytorch:了解 nn.Module 类内部如何工作

2024-03-27

一般来说,一个nn.Module可以由子类继承,如下所示。

def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform(m.weight)  # 

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.fc1 = nn.Linear(20, 1)
        self.apply(init_weights)

    def forward(self, x):
        x = self.fc1(x)
        return x

我的第一个问题是,为什么我可以简单地运行下面的代码,甚至我的__init__没有任何正论training_signals看起来像那样training_signals被传递给forward()方法。它是如何工作的?

model = LinearRegression()
training_signals = torch.rand(1000,20)
model(training_signals)

第二个问题是如何self.apply(init_weights)内部工作?是否在调用之前执行forward method?


Q1:为什么我可以简单地运行下面的代码,甚至我的__init__没有任何位置参数training_signals看起来像那样training_signals被传递给forward()方法。它是如何工作的?

首先,__init__当您运行此行时调用:

model = LinearRegression()

正如您所看到的,您没有传递任何参数,也不应该传递任何参数。您的签名__init__与基类之一相同(运行时调用super(LinearRegression, self).__init__())。如你看到的here https://github.com/pytorch/pytorch/blob/be757957bace28100e571ec7914765020be4a069/torch/nn/modules/module.py#L69, nn.Module的 init 签名很简单def __init__(self)(就像你的一样)。

Second, model现在是一个对象。当您运行以下行时:

model(training_signals)

你实际上是在调用__call__方法和传递training_signals作为位置参数。如你看到的here https://github.com/pytorch/pytorch/blob/be757957bace28100e571ec7914765020be4a069/torch/nn/modules/module.py#L522-L550,除其他事项外,__call__方法调用forward method:

result = self.forward(*input, **kwargs)

传递所有参数(位置和命名)__call__ to the forward.

Q2:怎么办?self.apply(init_weights)内部工作?是在调用forward方法之前执行的吗?

PyTorch 是开源的,因此您只需转到源代码并检查它即可。如你看到的here https://github.com/pytorch/pytorch/blob/be757957bace28100e571ec7914765020be4a069/torch/nn/modules/module.py#L248-L288,实现非常简单:

def apply(self, fn):
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self

引用该函数的文档:it“applies fn递归到每个子模块(由.children()) 也self》。基于实现,你还可以了解到需求:

  • fn必须是可调用的;
  • fn仅接收一个输入作为Module object;
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch:了解 nn.Module 类内部如何工作 的相关文章

随机推荐

  • Mono mkbundle 工具无法创建二进制文件,并抱怨输出文件不可用

    根据来自的建议这个线程 https stackoverflow com questions 551554 can you compile c without using the net framework在运行没有 NET 的 C 应用程序
  • postgresql 存储过程开始提交结束

    实际上 在执行 postgresql 存储过程时我很困惑 我从某处学到了以下内容 create or replace procedure update dba trades language plpgsql as begin CODE BL
  • 完成部分网格并使其不漏水

    我正在从 RealSense 相机捕获点云 并使用 Trimesh 库将它们转换为网格 问题是我只能从中得到一个不防水的网格 如何 完成 网格并使其防水 I tried trimesh repair broken faces mesh co
  • jquery切换 - 在切换功能之间切换?

    大家好 我喜欢 jquery 的切换功能 然而目前我面临一个小问题 我不知道如何以最好的方式解决它 我有一个名为 searchbox 的 div 它取决于用户设置是隐藏还是可见 如果我单击按钮 则触发的切换功能应该是 slideDown s
  • 创建Python包并导入模块

    我正在尝试编写我的第一个 Python 包 几乎所有模块都需要使用 NumPy 我应该写吗import numpy在每个模块中或者包中是否有某个地方我可以将其导入一次 以便每个模块都可以使用它 最好的方法是什么 是的 只需将其导入到需要的地
  • 在 iTunes Connect 中提交应用程序时附加屏幕截图的顺序

    我目前正在提交我的应用程序以供审核 并且我已经上传了主屏幕截图 但不确定如何让我的其他屏幕截图以正确的顺序显示 您必须在上传之前将它们全部选择 并且没有任何指示它们的顺序 有人可以告诉我您是否需要按正确顺序或相反顺序选择屏幕截图吗 以相反的
  • Cloud Dataflow 中的作业失败:启用 Dataflow API

    我目前正在尝试将 Dataflow 与 Pub Sub 结合使用 但收到此错误 工作流程失败 原因 6e74e8516c0638ca 刷新您的凭据时出现问题 请检查 1 为您的项目启用Dataflow API 2 您的项目有一个机器人服务帐
  • 在 Visual Studio 2010 项目中包含外部库

    我是视觉工作室的新手 似乎无法在任何地方找到这个问题的答案 我正在使用 VS2010 进行 VC 项目 我有另一个项目构建到 lib 文件中并设置为参考 但无法弄清楚如何实际包含标头 事实证明谷歌毫无用处 请帮忙 通常 这是通过将包含文件所
  • qt/c++ 动态命名变量

    我正在为我的一项大学作业在 Qt 中开发一个 html 编辑器 并且我在某些变量的命名方面遇到了问题 问题是这样的 当用户决定加载他们的 项目 时 程序会迭代该文件夹并查找其中有多少个 html 文件 然后它会创建要显示的选项卡 我有一个自
  • Symfony2 Assetic 路由和资源错误

    我有一个模板 例如index html php 我在其中使用 php assetic 加载器 如下所示 如果我对模板文件进行任何更改 我会得到路线 assetic 2b431f4 不存在 如果我改变 assetic use controll
  • C 中逐个字符读取文件

    我正在用 C 语言编写 BF 解释器 但在读取文件时遇到了问题 我以前用过scanf为了读取第一个字符串 但是你的 BF 代码中不能有空格或注释 现在这就是我所拥有的 char readFile char fileName FILE fil
  • EditText 随选择缩放

    我有一个EditText我想缩放它并滚动setScaleX setScaleY它工作正常 文本正在正确的位置进行编辑 但是当我尝试选择文本时 它会将选择手柄绘制到位置 就像文本未缩放时一样 我们都知道bug https code googl
  • 对公司名称的 DataFrame 进行非规范化 [第 1 部分]

    我有一个公司名称的 Pandas DataFrame 其结构如下 import numpy as np import pandas as pd df pd DataFrame name Nitron Pulset Rotaxi postal
  • 我如何知道创建项目时使用的是哪个版本的 Delphi

    如果我有 Delphi 项目的完整源代码 我如何知道使用哪个版本 即 Delphi 5 Delphi 7 Delphi 2010 等 来创建它 而无需在 Delphi 中打开它 我有许多可以追溯到 Delphi 6 时代的项目 我想对它们进
  • OpenId Connect 与 wso2 仅返回子声明

    当我询问用户 WSO2 的信息时 响应仅包含他的子信息 Request GET https srv wso2 domain com 9443 oauth2 userinfo schema openid Request headers Acc
  • AngularJS - 涉及异步数据的依赖注入

    我想让当前登录的用户 ID 和用户名可供我的 Angular 指令使用 我创建了一个 API 端点来检索此信息 以及一些其他信息 问题是 API 调用是异步的 var url baseUrl api sessions http get ur
  • 没有指定 dataType 的自定义 ajaxTransport 函数不会触发(根本!)

    我一直在尝试设置jQuery 的自定义 ajaxTransports http api jquery com extending ajax Transports在我们的产品的某些场景下缩短某些工作流程 然而 我在让这些运输受到尊重方面取得了
  • 将 CSV 文件转换为 Java - 向后复制

    我之前问过一个关于在java中将CSV文件转换为二维数组的问题 我完全重写了我的代码 几乎要重新编写了 我现在遇到的唯一问题是它正在向后打印 换句话说 列打印在行应该打印的位置 反之亦然 这是我的代码 int board new int 2
  • D 中是否有相当于 C++ 的 Future/Promise ?

    D 世界中是否存在 C 世界中的未来 承诺等价物 当然有标准并行度 http dlang org phobos std parallelism html但它并不完全具有承诺 未来组合的功能 没有相当于获取未来或设置结果或异常的功能 您也不能
  • Pytorch:了解 nn.Module 类内部如何工作

    一般来说 一个nn Module可以由子类继承 如下所示 def init weights m if type m nn Linear torch nn init xavier uniform m weight class LinearRe