如何在 Pytorch 中可视化网络?

2024-01-30

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate =0.0002
epoch = 50

resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)

我想要形象化resnet来自 pytorch 模型。我该怎么做?我尝试使用torchviz但它给出了一个错误:

'ResNet' object has no attribute 'grad_fn'

以下是使用不同工具的三种不同的图形可视化。

为了生成示例可视化,我将使用一个简单的 RNN 来执行从在线教程 https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/1%20-%20Simple%20Sentiment%20Analysis.ipynb:

class RNN(nn.Module):

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):

        super().__init__()
        self.embedding  = nn.Embedding(input_dim, embedding_dim)
        self.rnn        = nn.RNN(embedding_dim, hidden_dim)
        self.fc         = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):

        embedding       = self.embedding(text)
        output, hidden  = self.rnn(embedding)

        return self.fc(hidden.squeeze(0))

这是输出,如果您print()该模型。

RNN(
  (embedding): Embedding(25002, 100)
  (rnn): RNN(100, 256)
  (fc): Linear(in_features=256, out_features=1, bias=True)
)

以下是三种不同可视化工具的结果。

对于所有这些,您需要有可以通过模型的虚拟输入forward()方法。获取此输入的一个简单方法是从 Dataloader 中检索批次,如下所示:

batch = next(iter(dataloader_train))
yhat = model(batch.text) # Give dummy batch to forward().

Torchviz

https://github.com/szagoruyko/pytorchviz https://github.com/szagoruyko/pytorchviz

我相信这个工具使用向后传递生成其图形,因此所有盒子都使用 PyTorch 组件进行反向传播。

from torchviz import make_dot

make_dot(yhat, params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png")

该工具生成以下输出文件:

这是唯一清楚提到我的模型中的三层的输出,embedding, rnn, and fc。运算符名称取自向后传递,因此其中一些难以理解。

隐藏层

https://github.com/waleedka/hiddenlayer https://github.com/waleedka/hiddenlayer

我相信这个工具使用前向传播。

import hiddenlayer as hl

transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.

graph = hl.build_graph(model, batch.text, transforms=transforms)
graph.theme = hl.graph.THEMES['blue'].copy()
graph.save('rnn_hiddenlayer', format='png')

这是输出。我喜欢蓝色的阴影。

我发现输出有太多细节并且混淆了我的架构。例如,为什么是unsqueeze提到了这么多次?

Netron

https://github.com/lutzroeder/netron https://github.com/lutzroeder/netron

该工具是适用于 Mac、Windows 和 Linux 的桌面应用程序。它依赖于首先导出到的模型ONNX 格式 https://onnx.ai/。然后,应用程序读取 ONNX 文件并渲染它。然后可以选择将模型导出到图像文件。

input_names = ['Sentence']
output_names = ['yhat']
torch.onnx.export(model, batch.text, 'rnn.onnx', input_names=input_names, output_names=output_names)

这是模型在应用程序中的样子。我认为这个工具非常灵活:您可以缩放和平移,并且可以深入了解图层和运算符。我发现的唯一缺点是它只支持垂直布局。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 Pytorch 中可视化网络? 的相关文章

随机推荐

  • 为什么不是 scanf("%*[^\n]\n");和 scanf("%*[^\n]%*c");清除悬挂的换行符?

    拨打电话后scanf d variable 我们留下了一个换行符挂在stdin 应在调用之前清除fgets 或者我们最终给它提供一个换行符并使其过早返回 我找到了建议使用的答案scanf n c 第一次致电后scanf放弃换行符和其他建议使
  • oracle查询比较表中具有相同id的所有行

    需要一个 sql 查询来生成具有相同 id 的状态为完整的记录 例如 mytable是包含各种记录的表名 我们需要找到同一 ID 的所有状态为完整的 ID id status 12 complete 12 required 12 activ
  • Google 网站管理员工具 API:通过 OAUTH2 下载查询

    我正在尝试使用 Google 的网站管理员工具 API 下载最近搜索查询的 CSV 文件 我知道如何使用他们的 Python 示例来做到这一点http googlewebmastercentral blogspot com 2011 12
  • 如何向 .DecimalPad iOS 键盘添加减号?

    如何向 DecimalPad 类型 iOS 键盘添加减号 就像下面链接中的应用程序一样 如果我错了 请纠正我 但这对我来说似乎不是一个自定义键盘 它看起来像是苹果公司的默认十进制键盘 带减号的十进制键盘 https i stack imgu
  • Apache 基准 HTTPS 失败

    我在 Ubuntu 虚拟机中使用 Apache 2 4 2 我用它来加载测试 将请求发送到某个 HTTPS url 失败的请求数为零 但我的请求都无法真正得到处理 已经在数据库中查找 使用相同的url 通过浏览器调用它就可以了 数据库已更新
  • Python 中按年月分组并删除所有 NaN 的列

    基于来自的输出数据帧这个链接 https stackoverflow com questions 69937232 groupby year month and find top n smallest values columns in p
  • 原子属性的 setter 和 getter

    对于以下属性值 自动生成的 getter 和 setter 是什么样的 in h interface MyClass NSObject private NSString value property retain NSString valu
  • Qt/C++ 如何迭代给定类对象的 QMetaObject 属性/数据类型?

    在 C Java 中 我使用反射来读取类的属性 我尝试过使用 Qt 但不知道是否能正确解决我的问题 一个简单的 Person 类头 注意 3 个属性 id fname lname ifndef PERSON H define PERSON
  • 我可以创建私有枚举构造函数吗?

    在 Haskell 中我可以做这样的事情 示例改编自学习 Haskell http learnyouahaskell com making our own types and typeclasses algebraic data types
  • Angular - 顺序进行多个 HTTP 调用

    我需要创建一个函数来顺序进行 HTTP 调用 以便使用一个调用的响应到另一个调用 例如从第一次调用中获取用户的 IP 地址 并使用该 IP 在第二次调用中注册用户 演示代码 registerUser user User this utili
  • 如何配置 Sublime Text 在保存时始终转换为 Unix 行结尾?

    我希望我在 Sublime Text 中保存的所有文件都采用 Unix 行结束格式 即使我打开最初以不同格式保存但后来在 Sublime Text 中编辑的文件也是如此 简单设定 default line ending unix 还不够 因
  • Django INSTALLED_APPS 'polls' 与 'polls.apps.PollsConfig'

    在每个 YouTube 教程中 我都看到人们只是将 app name 添加到 INSTALLED APPS 列表中 昨天我开始了官方 Django 教程 他们建议使用 app name apps App nameConf 符号 我猜官方方法
  • 理解java的同步集合

    我正在阅读java官方doc https docs oracle com javase tutorial collections implementations wrapper html关于包装器实现 它们是静态方法收藏用于获取同步集合 例
  • 数字签名时间戳在 XP/Vista 上“不可用”,导致验证失败

    背景 我有一个 WiX Burn 安装包 其中包括安装 ReportViewer 2012 Runtime 在 Windows 7 或更高版本的计算机上运行时 它工作正常 在 XP SP3 或 Vista SP1 上它会失败 现在 检查Re
  • 使用mysqldump将表数据导出到csv文件

    我想使用 mysqldump 将表数据导出到 csv 文件中 我想做一些类似的东西 mysqldump compact no create info tab testing fields enclosed by fields termina
  • JBoss 4.2.2 Web服务soap:地址

    我在 JBoss 4 2 2 中部署了一个 EJB3 bean 作为 Web 服务 在生产中 服务器位于 Apache 服务器后面 该服务器将请求重定向到 Jboss 服务器 这使得 WSDL 具有错误的soap address 位置 我能
  • 将按钮添加到 PreferenceScreen

    我不知道如何在PreferenceScreen 向上按钮会在应用程序图标旁边的操作栏中显示插入符号 使您可以导航应用程序的层次结构 更多信息here http developer android com training implement
  • 查找 CSV 文件中的重复项总数

    我正在解析 CSV 文件 需要您的帮助 我的 CSV 文件中有重复项 我想告诉Python向我提供重复地址的总数和唯一地址的总数 然后列出它们 我已经成功到达地址显示它是唯一还是重复的部分 但现在我想告诉 Python 也为我提供受尊重的数
  • 查看和设置 Safari/Chrome 的 HTTP 标头

    我正在测试一个 API 我想用 safari 来访问它并查看返回的原始 json API 要求每个请求都发送特定的 HTTP 标头 Safari 或 Chrome 中有没有办法在访问 URL 时设置我的 http 标头 有几个 Google
  • 如何在 Pytorch 中可视化网络?

    import torch import torch nn as nn import torch optim as optim import torch utils data as data import torchvision models