如何在TF2.0中创建具有自定义渐变的keras层?

2023-12-29

由于在 TensorFlow 2.0 中,他们计划统一 keras 下的所有高级 API(我不太熟悉)并完全删除会话,我想知道:

如何创建具有自定义渐变的自定义 keras 层?

我见过(相当有限)guide https://keras.io/layers/writing-your-own-keras-layers/关于在 keras 中创建自定义图层,但它没有描述如果我们希望我们的操作具有自定义渐变我们应该做什么。


首先,keras 下 API(如您所称)的“统一”并不会阻止您像在 TensorFlow 1.x 中那样进行操作。会话可能会消失,但您仍然可以像任何 python 函数一样定义模型,并在没有 keras 的情况下热切地训练它(即通过tf.渐变带 https://www.tensorflow.org/tutorials/eager/custom_training_walkthrough)

现在,如果你想构建一个 keras 模型自定义层执行一个自定义操作并有一个自定义渐变,您应该执行以下操作:

a) 编写一个执行自定义操作并定义自定义渐变的函数。有关如何执行此操作的更多信息here https://www.tensorflow.org/api_docs/python/tf/custom_gradient.

@tf.custom_gradient
def custom_op(x):
    result = ... # do forward computation
    def custom_grad(dy):
        grad = ... # compute gradient
        return grad
    return result, custom_grad

请注意,在函数中您应该对待x and dy作为张量和notnumpy 数组(即执行张量运算)

b) 创建一个自定义 keras 层来执行您的custom_op。对于这个例子,我假设你的层没有任何可训练的参数或改变其输入的形状,但如果有的话也没有太大区别。为此,您可以参考您发布的检查指南this one https://www.tensorflow.org/beta/tutorials/eager/custom_layers.

class CustomLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(CustomLayer, self).__init__()

    def call(self, x):
        return custom_op(x)  # you don't need to explicitly define the custom gradient
                             # as long as you registered it with the previous method

现在您可以在 keras 模型中使用该层并且它会起作用。例如:

inp = tf.keras.layers.Input(input_shape)
conv = tf.keras.layers.Conv2D(...)(inp)  # add params like the number of filters
cust = CustomLayer()(conv)  # no parameters in custom layer
flat = tf.keras.layers.Flatten()(cust)
fc = tf.keras.layers.Dense(num_classes)(flat)

model = tf.keras.models.Model(inputs=[inp], outputs=[fc])
model.compile(loss=..., optimizer=...)  # add loss function and optimizer
model.fit(...)  # fit the model
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在TF2.0中创建具有自定义渐变的keras层? 的相关文章

随机推荐

  • Flexbox 顺序和选项卡导航

    我想用显示 柔性改变order的 div 与line类 但我想保持这个 TAB 导航顺序 A B C D 正如您在代码片段中看到的 第一个示例工作正常 DOM 序列与 Order 相同 但在第二个示例中 选项卡遵循 DOM 序列 不使用 j
  • 编译一个快速修复程序

    我正在尝试使用 QuickFix 库通过 FIX 协议连接到代理 我刚刚使用他们提供的文档构建了库 并立即使用他们的示例代码 include quickfix FileStore h include quickfix FileLog h i
  • MongoDB 获取聚合查询的executionStats

    我正在寻找一种方法来检索executionStats用于聚合 当使用 find 时 我可以通过使用轻松检索它们explain https docs mongodb com manual reference explain results 输
  • 防止浏览器缓存 AJAX 调用结果

    看起来如果我使用加载动态内容 get 结果缓存在浏览器中 在 QueryString 中添加一些随机字符串似乎可以解决这个问题 我使用new Date toString 但这感觉就像是黑客攻击 还有其他方法可以实现这一目标吗 或者 如果唯一
  • 如何使用Java读取带有部分的配置文件[重复]

    这个问题在这里已经有答案了 给定一个包含以下内容的文件 upper a A b B words 1 one 2 two 如何参考它们的标头访问这些键 值 Java 的 Properties 类仅处理无节文件 使用 ini4j 库 链接教程
  • C 缓冲区溢出 - 为什么有恒定数量的字节会引发段错误? (Mac OS 10.8 64 位,clang)

    我正在试验 C 中的缓冲区溢出 发现一个有趣的怪癖 对于任何给定的数组大小 似乎有一定数量的溢出字节可以在 SIGABRT 崩溃之前写入内存 例如 在下面的代码中 10 字节数组可以溢出到 26 字节 然后在 27 处崩溃 同样 20 字节
  • 按住按钮“重复射击”

    我已经提到了无数关于按住按钮的其他问题 但与 Swift 相关的问题并不多 我有一个使用 touchUpInside 事件连接到按钮的函数 IBAction func singleFire sender AnyObject code 还有另
  • 我的随机化代码无法离线工作

    我是一个 php 菜鸟 我只是根据我在网上找到的其他一些脚本制作了一个小脚本 它从名为 Random 的文件夹中随机选取 3 张图像并显示它们 当我在线运行脚本时它可以工作 但是当我尝试在 xampp 上离线运行它时 我收到此错误 注意 未
  • UITableView 的第一行在顶部栏下被截断

    我有一个UITabBarController有两个UITableViews 全部都是在故事板中创建的 问题是 在第二个表视图中 表的前几行位于顶栏下方 第一个表视图不会发生这种情况 即使我更改视图的顺序 第一个视图将完美工作 第二个视图将呈
  • 派生类构造函数在 python 中如何工作?

    我有以下基类 class NeuralNetworkBase def init self numberOfInputs numberOfHiddenNeurons numberOfOutputs self inputLayer numpy
  • R:使用 spplot 地图中的自定义调色板

    我正在努力使用在多个多边形上引入自定义调色板spplot来自sp包裹 我正在绘制几个字段并希望显示我的评级 其值可以为 0 1 2 4 或 5 我需要为此使用自定义颜色 我尝试的是 spplot Map zcol Rating col re
  • 仅在现有 iOS 应用程序中对某些视图使用 React Native

    是否可以仅对项目中的一个视图使用 React Native 我已经成功为特定的 iOS 应用程序屏幕添加了 React 视图 使用 与现有 iOS 项目集成 文档中的说明 但我不知道如何从该屏幕获取数据并调用其他 objective c 代
  • VB.Net Xml 反序列化为类

    我在尝试将一些 XML 反序列化到我创建的类中时遇到了一些问题 我得到的错误是 There is an error in XML document 1 2 at System Xml Serialization XmlSerializer
  • MongoDB索引/RAM关系

    我即将在一个新项目中采用 MongoDB 我选择它是为了灵活性 而不是可扩展性 因此将在一台机器上运行它 从文档和网络帖子中我一直读到所有索引都在 RAM 中 这对我来说没有意义 因为我的索引很容易大于可用 RAM 的量 谁能分享一些关于索
  • 如何使用java获取xml节点的属性值

    我有一个 xml 如下所示
  • 我怎样才能让 Modelsim 警告我有关“X”信号的信息?

    我正在使用 Modelsim 进行大型设计 我已经了解了 modelsim 模拟的工作方式 我想知道 是否有一种方法可以在 modelsim 在仿真阶段评估信号并发现它是红色信号 即 X 时向我发出警告 要知道 不可能列出设计的所有信号并一
  • Rails 4:SQLException:没有这样的表:

    我在 Rails4 中运行以下命令 bundle exec rake db migrate 201405270646 AddAttachmentImageToPins 迁移 change table pins 耙子中止 StandardEr
  • 通过扬声器的 AVAudioPlayer

    我得到以下代码 id init if self super init UInt32 sessionCategory kAudioSessionCategory MediaPlayback AudioSessionSetProperty kA
  • 选中复选框时动态更改引导程序进度条值

    我正在尝试制作一个带有引导进度条的动态清单 这是我的标记代码 div class progress progress striped active div class progress bar div div div class row t
  • 如何在TF2.0中创建具有自定义渐变的keras层?

    由于在 TensorFlow 2 0 中 他们计划统一 keras 下的所有高级 API 我不太熟悉 并完全删除会话 我想知道 如何创建具有自定义渐变的自定义 keras 层 我见过 相当有限 guide https keras io la