如何从 pytorch 模块获取子模块序列?

2024-01-27

对于火炬module https://pytorch.org/docs/master/generated/torch.nn.Module.html,我想我可以用.named_children, .named_modules等来获取子模块的列表。但是,我想该列表不是按顺序给出的,对吧?一个例子:

In [19]: import transformers

In [20]: model = transformers.DistilBertForSequenceClassification.from_pretrained('distilb
    ...: ert-base-cased')

In [21]: [name for name, _ in model.named_children()]
Out[21]: ['distilbert', 'pre_classifier', 'classifier', 'dropout']

的顺序.named_children()在上面的模型中,给出了 distilbert、pre_classifier、classifier 和 dropout。但是,如果您检查code https://github.com/huggingface/transformers/blob/9931f817b75ecb2c8bb08b6e9d4cbec4b0933935/src/transformers/modeling_distilbert.py#L641,显然dropout发生在之前classifier。那么如何获得这些子模块的顺序呢?


在 Pytorch 中,结果为print(model) or .named_children()等根据声明顺序列出__init__模型的类别,例如

Case 1

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.conv2_drop = nn.Dropout2d()

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.6)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Model()
print(model)
[name for name, _ in model.named_children()]
# output
['conv1', 'conv2', 'fc1', 'fc2', 'conv2_drop']

Case 2

更改了顺序fc1 and fc2构造函数中的层。

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc2 = nn.Linear(50, 10)
        self.fc1 = nn.Linear(320, 50)
        self.conv2_drop = nn.Dropout2d()

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.6)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Model()
print(model)
[name for name, _ in model.named_children()]
# output
['conv1', 'conv2', 'fc2', 'fc1', 'conv2_drop']

这就是为什么classifier之前打印过dropout正如它在构造函数中声明的那样:

class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
        ...
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

尽管如此,您可以使用模型的子模块.modules()等,但它们只会按照声明的顺序列出__init__。如果您只想打印基于的结构forward方法,您可以尝试使用pytorch 摘要 https://github.com/Fangyh09/pytorch-summary.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何从 pytorch 模块获取子模块序列? 的相关文章

随机推荐

  • 通过 id 删除骨干模型?

    可以通过id删除模型吗 文档说您需要传入模型本身才能将其删除 所以我需要先获取模型然后删除它 我不能直接通过id删除它吗 您的意思是从集合中删除模型吗 查看文档 似乎您确实需要传递一个真实的模型 但源代码表明您可以只传递一个模型id或型号c
  • 具有唯一键的javascript和es6过滤器数组

    我有一个变量列表 例如 var name list some list console log name list Array 3 0 Object name Johny 1 Object name Monty 2 Object3 name
  • 统一加速

    我正在尝试在 Unity 中模拟加速和减速 我编写了代码来在 Unity 中生成轨道 并根据时间将对象放置在轨道上的特定位置 结果看起来有点像这样 我目前遇到的问题是样条线的每个部分都有不同的长度 并且立方体以不同但均匀的速度穿过每个部分
  • 指定父级 div 的不透明度,但使其不影响子级 HTML 元素

    我在 div 中有一个段落元素 div 的不透明度为 0 3 段落的不透明度为 1 当我显示元素时 该段落看起来是透明的 就像它的不透明度为 0 3 一样 有没有办法让div内的段落完全不透明 也许我可以为此设置一个 CSS 值 div s
  • 跳过没有装饰器语法的单元测试

    我有一套使用 TestLoader 的 来自单元测试模块 loadTestsFromModule 方法加载的测试 即 suite loader loadTestsFromModule module 这给了我一个非常充足的 运行良好的测试列表
  • 在Android模拟器中添加铃声

    有谁知道如何向 Android 模拟器添加 下载铃声或 mp3 声音 Go to DDMS in Eclipse 点击File Explorer选项卡并导航至mnt sdcard 单击创建新文件夹Plus图标称为ringtones 然后单击
  • 哪里可以找到 Android 示例?

    我检查了谷歌开发者网站上的一些 Android 开发练习和示例 我发现了这个网页 http developer android com tools samples index html http developer android com
  • Haskell - 非法多态类型?

    为什么该类型单独使用可以编译 但放入列表却失败 ft1 Foldable t Num a gt t a gt a ft1 F foldl 0 fTest Foldable t Num a gt t a gt a fTest F foldl
  • Django Cripy-Forms 找不到 CSS

    我正在使用 Django 和 Crispy Forms 我可以正确呈现表单 但不会出现 CSS 格式 我需要做什么 我已经添加了 CRISPY TEMPLATE PACK bootstrap to my settings py file h
  • 如何让 django 在继续完成与请求相关的任务之前给出 HTTP 响应?

    在我的 django 活塞 API 中 我想在调用另一个需要相当长的时间的函数之前向客户端产生 返回一个 http 响应 如何使yield 给出包含所需JSON 的HTTP 响应 而不是与生成器对象创建相关的字符串 我的活塞处理程序方法如下
  • 如何读取属性文件并使用项目 Gradle 脚本中的值?

    我正在开发一个 Gradle 脚本 我需要阅读local properties文件并使用属性文件中的值build gradle 我正在按照以下方式进行操作 我运行了下面的脚本 它现在抛出一个错误 但它也没有执行任何操作 例如创建 删除和复制
  • Django-CKEditor 不会渲染图像

    我已经安装了 Django CKEditor 并对其进行了配置以用于开发目的 现在我可以编辑文本并将其作为文本字段保存到数据库中 但是在插入图像时我遇到了很大的问题 我可以插入图像 它似乎可以正确保存到本地主机 正确的文件夹 但是当将图像渲
  • 如何更改 setInterval 和 setTimeout 函数中“this”的范围

    怎么可能使用this代替setInterval and setTimeout calls 我想这样使用它 function myObj this func function args setTimeout function this fun
  • 如何解决Require.js中的循环依赖?

    基本上 这个想法是 子 模块创建一个对象 并且该对象应该是作为 主 模块的实用程序库的一部分 然而 子 对象depends关于 main 的实用程序 Main module define sub function sub var utils
  • NameError:未初始化的常量 Bundler

    我刚刚将我的网络服务器更改为 Puma 并且必须将我的开发数据库从 sqlite 更改为 postgresql 但现在每次我尝试运行 rake db migrate 时都会收到此错误 rake aborted NameError unini
  • 为 ObjectContext 创建接口

    我正在尝试创建一个抽象层ObjectContext 我理解 OC 是一个工作单元 但我并不完全了解如何为它编写一个好的界面 理想情况下 我希望能够交换实现的 RealDataContext IDataContext对于像 FakeDataC
  • 求解 a^3 + b^4 = c^3 + d^3 最佳运行时间

    注意 这个问题不同于写出 a 3 b 3 c 3 d 3 的所有解 https stackoverflow com questions 14454133 write all solutions for a3 b3 c3 d3因为我需要帮助理
  • SQL 如果 select 语句不返回任何行,则执行替代 select 语句

    基本上 什么语法可以让我实现标题声明 If select statement 1 returns 0 rows THEN select statement 2 else select statement 3 以便 sql 返回语句 2 或
  • 将图库中的所有图像加载到 android 中的应用程序中

    我正在创建一个应用程序 其中我需要图库中的所有图像到我的应用程序中 其中有一个 girdview 我希望所有文件夹中的所有图像都显示在网格视图中 String proj MediaStore Images Media DATA MediaS
  • 如何从 pytorch 模块获取子模块序列?

    对于火炬module https pytorch org docs master generated torch nn Module html 我想我可以用 named children named modules等来获取子模块的列表 但是