将 CNN Pytorch 中的预训练权重传递到 Tensorflow 中的 CNN

2024-04-05

我在 Pytorch 中针对 224x224 大小的图像和 4 个类别训练了这个网络。

class CustomConvNet(nn.Module):
    def __init__(self, num_classes):
        super(CustomConvNet, self).__init__()

        self.layer1 = self.conv_module(3, 64)
        self.layer2 = self.conv_module(64, 128)
        self.layer3 = self.conv_module(128, 256)
        self.layer4 = self.conv_module(256, 256)
        self.layer5 = self.conv_module(256, 512)
        self.gap = self.global_avg_pool(512, num_classes)
        #self.linear = nn.Linear(512, num_classes)
        #self.relu = nn.ReLU()
        #self.softmax = nn.Softmax()

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.gap(out)
        out = out.view(-1, 4)
        #out = self.linear(out)

        return out

    def conv_module(self, in_num, out_num):
        return nn.Sequential(
            nn.Conv2d(in_num, out_num, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=None))

    def global_avg_pool(self, in_num, out_num):
        return nn.Sequential(
            nn.Conv2d(in_num, out_num, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(out_num),
            #nn.LeakyReLU(),

            nn.ReLU(),
            nn.Softmax(),
            nn.AdaptiveAvgPool2d((1, 1)))

我从第一个 Conv2D 获得了权重及其大小torch.Size([64, 3, 3, 3])

我已将其另存为:

weightsCNN = net.layer1[0].weight.data
np.save('CNNweights.npy', weightsCNN)

这是我在 Tensorflow 中构建的模型。我想将从 Pytorch 模型中保存的权重传递到这个 Tensorflow CNN 中。

    model = models.Sequential()
    model.add(layers.Conv2D(64, (3, 3), activation='relu', input_shape=(224, 224, 3)))
    model.add(layers.MaxPooling2D((2, 2)))

    model.add(layers.Conv2D(128, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))

    model.add(layers.Conv2D(256, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))

    model.add(layers.Conv2D(256, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))

    model.add(layers.Conv2D(512, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))

    model.add(layers.Conv2D(512, (3, 3), activation='relu'))

    model.add(layers.GlobalAveragePooling2D())
    model.add(layers.Dense(4, activation='softmax'))
    print(model.summary())


    adam = optimizers.Adam(learning_rate=0.0001, amsgrad=False)
    model.compile(loss='categorical_crossentropy',
                  optimizer=adam,
                  metrics=['accuracy'])


    nb_train_samples = 6596
    nb_validation_samples = 1290
    epochs = 10
    batch_size = 256


    history = model.fit_generator(
        train_generator,
        steps_per_epoch=np.ceil(nb_train_samples/batch_size),
        epochs=epochs,
        validation_data=validation_generator,
        validation_steps=np.ceil(nb_validation_samples / batch_size)
        )

我实际上应该怎么做? Tensorflow 需要什么形状的权重?谢谢!


您可以检查所有重量的形状keras层非常简单:

for layer in model.layers:
    print([tensor.shape for tensor in layer.get_weights()])

这将为您提供所有重量(包括偏差)的形状,因此您可以准备加载numpy相应地权重。

要设置它们,请执行类似的操作:

for torch_weight, layer in zip(model.layers, torch_weights):
    layer.set_weights(torch_weight)

where torch_weights应该是一个包含列表的列表np.array你必须加载它。

通常每个元素torch_weights将包含一个np.array一个用于权重,一个用于偏置。

请记住,从打印中收到的形状必须与您输入的形状完全相同set_weights.

See 文档 https://keras.io/models/about-keras-models/了解更多信息。

顺便提一句。确切的形状取决于模型执行的层和操作,有时您可能需要转置一些数组才能“适应它们”。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

将 CNN Pytorch 中的预训练权重传递到 Tensorflow 中的 CNN 的相关文章

  • pexpect 和 ssh:如何在 su - root -c 之后格式化一串命令

    我正在尝试迭代服务器和密码列表来更改一组服务器上的 sshd 配置 以便我可以使用无密码 SSH 密钥通过 root 登录 运行命令 我可以在 bash 中轻松完成此操作 但我正在尝试学习 Python 并且 显然 希望放弃手动输入密码 这
  • 确定非空列表条目是否“连续”的 Pythonic 方法

    我正在寻找一种方法来轻松确定列表中所有非 None 项目是否出现在单个连续切片中 我将使用整数作为非 None 项目的示例 例如 列表 None None 1 2 3 None None 满足我对连续整数条目的要求 相比之下 1 2 Non
  • 导入错误:无法导入名称 md5

    真的不知道这里发生了什么 我需要在弹性beanstalk上部署我的flask应用程序 但不知何故改变了路径并且无法再运行python application py dotnet info NET Core SDK reflecting an
  • 非常大的数据集的余弦相似度

    我在计算大量 100 维向量之间的余弦相似度时遇到问题 当我使用from sklearn metrics pairwise import cosine similarity I get MemoryError在我的 16 GB 机器上 每个
  • 用于列出用户和组的 Python 脚本

    我正在尝试编写一个脚本 在自己的行上输出每个用户及其组 如下所示 user1 group1 user2 group1 user3 group2 user10 group6 etc 我正在为此用 python 编写一个脚本 但想知道如何做到这
  • 在Python中解析制表符分隔的文件

    我正在尝试在 Python 中解析一个制表符分隔的文件 其中与行开头分开的 k 个制表符的数字应该放入第 k 个数组中 除了逐行读取并执行简单解决方案将执行的所有明显处理之外 是否有内置函数可以执行此操作 或者有更好的方法 您可以使用the
  • python 使用 shapefile 掩码 netcdf 数据

    我正在使用以下软件包 import pandas as pd import numpy as np import xarray as xr import geopandas as gpd 我有以下存储数据的对象 print precip d
  • 如何访问 pytest 夹具中的所有标记?

    我正在使用 pytest 我想用标记来标记我的测试 这些标记将指定固定装置要在驱动程序中加载哪个页面 这可以轻松地与行为上下文对象一起使用 但我找不到如何使用 pytest 来做到这一点 以这段代码为例 import pytest pyte
  • 将四边形(四边形)拟合到斑点

    应用不同的过滤和分割技术后 我最终得到如下图像 我可以访问一些轮廓检测函数 这些函数返回该对象边缘上的点列表 或者返回一个拟合的多边形 尽管有很多边 远多于 4 个 我想要一种将四边形适合该形状的方法 因为我知道它是应该是四边形的鞋盒的正面
  • Pyside QPushButton 和 matplotlib 的连接

    我正在尝试使用 matplotlib 开发一个非常简单的 pyside Qt 程序 我希望按下按钮时绘制图表 到目前为止 我可以在构造函数上绘制一些东西 但无法将 Pyside 事件与 matplotlib 连接起来 有没有办法做到这一点
  • 为什么 Keras 的 train_on_batch 在第二个 epoch 产生零损失和准确率?

    我正在使用一个大数据集 所以我尝试使用 train on batch 或适合 epoch 1 model Sequential model add LSTM size input shape input shape return seque
  • 使用 numpy 数组时出现内存错误 Python

    我原来的list 函数有超过 200 万行代码 当我运行计算 的代码时出现内存错误 有什么办法可以绕过它吗 这list 下面是实际 numpy 数组的一部分 熊猫数据 import pandas as pd import math impo
  • 在 Mac 上安装 python igraph

    我执行了brew install homebrew science igraph当我执行时sudo pip3 install python igraph 我收到以下错误 Cannot find the C core of igraph on
  • 屏幕截图中低分辨率文本的 OCR

    我正在编写一个 OCR 应用程序来从屏幕截图图像中读取字符 目前 我只关注数字 我的方法部分基于这篇博文 http blog damiles com 2008 11 basic ocr in opencv http blog damiles
  • 如何在 Pandas 中叠加“一天”内的数据进行绘图

    我有一个数据框 里面有一些 更有意义 数据格式如下 In 67 df Out 67 latency timestamp 2016 09 15 00 00 00 000000 0 042731 2016 09 15 00 16 24 3769
  • 无法将项目追加到多处理共享列表

    我正在使用多重处理来为我的应用程序创建子流程 我还在进程和子进程之间共享一个字典 我的代码示例 主要流程 from multiprocessing import Process Manager manager Manager shared
  • 安装轮子后安装后脚本

    Using from setuptools command install import install 如果我运行 我可以轻松运行自定义安装后脚本python setup py install 这是相当微不足道 https stackov
  • 使用 python mechanize 库登录 https 站点

    我有以下代码 import requests import sys import urllib2 import re import mechanize import cookielib import json import imp prin
  • Python pycrypto 模块:为什么 simplejson 无法转储加密字符串?

    表明统一码错误 utf8 codec can t decode byte 0x82 in position 0 unexpected code byte 这是代码 from Crypto Cipher import AES import s
  • python - lxml:强制执行属性的特定顺序

    我有一个 XML 编写脚本 可以为特定的第 3 方工具输出 XML 我使用原始 XML 作为模板来确保构建所有正确的元素 但最终的 XML 看起来与原始的不同 我以相同的顺序编写属性 但 lxml 按自己的顺序编写它们 我不确定 但我怀疑第

随机推荐