从 TensorFlow 图中删除 dropout 操作

2023-11-23

我有一个经过训练的冻结图,我正在尝试在 ARM 设备上运行它。基本上,我使用 contrib/pi_examples/label_image,但使用我的网络而不是 Inception。我的网络接受了 dropout 训练,这现在给我带来了麻烦:

Invalid argument: No OpKernel was registered to support Op 'Switch' with these attrs.  Registered kernels:
  device='CPU'; T in [DT_FLOAT]
  device='CPU'; T in [DT_INT32]
  device='GPU'; T in [DT_STRING]
  device='GPU'; T in [DT_BOOL]
  device='GPU'; T in [DT_INT32]
  device='GPU'; T in [DT_FLOAT]

 [[Node: l_fc1_dropout/cond/Switch = Switch[T=DT_BOOL](is_training_pl, is_training_pl)]]

我看到的一种解决方案是构建这样的 TF 静态库,其中包含相应的操作。从另一方面来说,从网络中消除 dropout 操作可能是一个更好的主意,以使其更简单、更快。有没有办法做到这一点?

Thanks.


#!/usr/bin/env python2

import argparse

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2

def print_graph(input_graph):
    for node in input_graph.node:
        print "{0} : {1} ( {2} )".format(node.name, node.op, node.input)

def strip(input_graph, drop_scope, input_before, output_after, pl_name):
    input_nodes = input_graph.node
    nodes_after_strip = []
    for node in input_nodes:
        print "{0} : {1} ( {2} )".format(node.name, node.op, node.input)

        if node.name.startswith(drop_scope + '/'):
            continue

        if node.name == pl_name:
            continue

        new_node = node_def_pb2.NodeDef()
        new_node.CopyFrom(node)
        if new_node.name == output_after:
            new_input = []
            for node_name in new_node.input:
                if node_name == drop_scope + '/cond/Merge':
                    new_input.append(input_before)
                else:
                    new_input.append(node_name)
            del new_node.input[:]
            new_node.input.extend(new_input)
        nodes_after_strip.append(new_node)

    output_graph = graph_pb2.GraphDef()
    output_graph.node.extend(nodes_after_strip)
    return output_graph

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-graph', action='store', dest='input_graph')
    parser.add_argument('--input-binary', action='store_true', default=True, dest='input_binary')
    parser.add_argument('--output-graph', action='store', dest='output_graph')
    parser.add_argument('--output-binary', action='store_true', dest='output_binary', default=True)

    args = parser.parse_args()

    input_graph = args.input_graph
    input_binary = args.input_binary
    output_graph = args.output_graph
    output_binary = args.output_binary

    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read().decode("utf-8"), input_graph_def)

    print "Before:"
    print_graph(input_graph_def)
    output_graph_def = strip(input_graph_def, u'l_fc1_dropout', u'l_fc1/Relu', u'prediction/MatMul', u'is_training_pl')
    print "After:"
    print_graph(output_graph_def)

    if output_binary:
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    else:
        with tf.gfile.GFile(output_graph, "w") as f:
            f.write(text_format.MessageToString(output_graph_def))
    print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == "__main__":
    main()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

从 TensorFlow 图中删除 dropout 操作 的相关文章

随机推荐

  • 通信链路故障 最后发送到服务器的数据包是在 1 毫秒前。

    我尝试连接到mysql database但我失败了并且显示了这个错误 Communications link failure Last packet sent to the server was 1 ms ago 这是我的代码 任何人都可以
  • 如何将 xsi:type 定义为 XML 模式中的属性?

    我有一个 XML 我想为其编写架构定义 问题是我不知道如何将 xsi type 定义为属性 这是 XML 元素
  • 如何在.net中查找当前线程的最大堆栈大小?

    如何找到当前线程的最大堆栈大小 我在从 MMC UI 而不是从 Powershell 命令行 控制台 执行函数时遇到堆栈溢出异常 所以我猜测它与 UI 线程中分配的默认堆栈大小与 Powershell 命令行 控制台 中分配的默认堆栈大小有
  • gridview rowCommand 中的行索引

    只是想将值从变量转移到另一个变量 protected void gvVariableDetail RowCommand object sender GridViewCommandEventArgs e if e CommandName Ed
  • XNA 和 GUI 控件(例如 xaml 和 xna)

    有没有办法在 xna 中获取支持边距等的文本框 标签和其他 wpf 控件 并根据窗口大小进行伸缩 你可能会给CeGui a shot 如果您的游戏需要高级 GUI 功能 CeGui 可能正好适合您 撇开市场因素不谈 这是一个非常好的 GUI
  • AlarmManager 在 Android 4.4.2 中停止工作(使用 SetExact())

    我在代码中设置了一个在特定时间响起的闹钟 警报机制在 SDK 这是我设置闹钟的代码 public void SetAlarm Context context Long executionTime AlarmManager am AlarmM
  • 让 Git 使用代理服务器 - 失败并显示“请求超时”

    如何让 Git 使用代理服务器 我需要从 Git 服务器检查代码 但每次都显示 请求超时 我该如何解决这个问题 或者 如何设置代理服务器 使用的命令 git config global http proxy http proxyuser e
  • 有多少用户连接到我的 Shiny 应用程序?

    我正在开发一个闪亮的应用程序shinydashboard在应用程序的某个地方 我想显示一条通知 告诉用户有多少其他用户同时连接到该应用程序 我想出了第一段似乎有效的代码 library shiny ui fluidPage uiOutput
  • 单元测试插入/更新/删除

    我用谷歌搜索了一下 并没有真正找到我需要的答案 我正在为客户使用 C SQL Server 和 LINQ 开发一个网页 我希望用户能够互相发送消息 所以我所做的是使用实际进入数据库的数据对其进行单元测试 问题是我现在依赖于至少有 2 个我知
  • PowerShell v5 - 如何将模块安装到没有互联网连接的计算机上?

    我有一台机器 v3 互联网 无管理员访问权限 我用它下载 WMF 5 0 并设置另一台机器 v5 无互联网 管理员访问权限 现在 我想在运行 v5 但没有互联网连接的计算机上使用 PowerShellGet 中的一些模块 我需要一个选项来下
  • 如何修复 iOS Firestore Increment() 上的错误“‘增量’的使用不明确”

    我在尝试使用 firebase 时收到编译器错误FieldValue increment 1 在 iOS 中使用 swift 该错误仅表示 增量 的使用不明确 我已将所有 pod 更新为所使用的所有 firebase pod 的当前版本 更
  • FFMPEG:使用绘图文本以及自动换行和填充创建视频

    我正在努力使用绘图文本过滤器从文本创建视频 输出视频我可以看到文本溢出而不是换行 有什么方法可以存档自动换行并将内部填充设置为视频 下面是我用来从文本生成视频的片段 ffmpeg exe f lavfi i color c white s
  • 在 numpy 数组中相乘

    我试图将二维数组中的每个项乘以一维数组中的相应项 如果我想将每一列乘以一维数组 这非常容易 如下所示numpy 乘法功能 但我想做相反的事情 将行中的每一项相乘 换句话说 我想乘以 1 2 3 0 4 5 6 1 7 8 9 2 and g
  • 如何在 Rails 6 中执行自定义 JavaScript 函数

    随着 Webpacker 引入 Ruby On Rails 我找不到使用 JavaScript 函数的方法 我有一个名为app globals js具有要测试的功能 function alerts alert TEST 然后我想在我的观点之
  • Codeigniter:ORDER BY CASE 查询出错

    这是我的查询代码点火器 this gt db gt select p u firstname u lastname s title AS industry pt type name al length value FALSE this gt
  • Firemonkey 中的 Cleartype 字体/文本渲染

    下面是一个仅包含 TEdit 控件的示例 VCL 应用程序 如果您编译类似的 Firemonkey FMX 应用程序 您会注意到这一点 小L字母疯狂地跳来跳去 根据我的研究我发现thisG 帖子的结果是 如您所见 结果更好 跳跃消失了 然而
  • 如何在一定时间后删除MySQL记录

    我想在 7 天后从 MySQL 数据库中删除一些消息 我的消息表行具有以下格式 编号 留言 日期 日期是正常格式的时间戳 2012 12 29 17 14 53 我认为 MySQL 事件将是替代 cron 作业的方法 对于经验丰富的 SQL
  • 在 C++ 中将宽字符字符串转换为小写

    如何在 C 中将 wchar t 字符串从大写转换为小写 该字符串包含日语 中文 德语和希腊字符的混合体 我想过用塔罗 http msdn microsoft com en us library 8h19t214 28VS 80 29 as
  • Android 棒棒糖工具栏在打开/关闭抽屉和后退按钮之间切换

    我有标准导航抽屉 但现在我正在尝试使用工具栏修改它 早些时候我的代码看起来像 MainActivity java Override protected void onCreate Bundle savedInstanceState supe
  • 从 TensorFlow 图中删除 dropout 操作

    我有一个经过训练的冻结图 我正在尝试在 ARM 设备上运行它 基本上 我使用 contrib pi examples label image 但使用我的网络而不是 Inception 我的网络接受了 dropout 训练 这现在给我带来了麻