Tensorflow:如何在 python 中编写带有梯度的操作?

2023-12-14

我想用 python 编写一个 TensorFlow 操作,但我希望它是可微的(能够计算梯度)。

这个问题询问如何在 python 中编写操作,答案建议使用 py_func (没有梯度):Tensorflow:用 Python 编写操作

TF 文档描述了如何仅从 C++ 代码开始添加操作:https://www.tensorflow.org/versions/r0.10/how_tos/adding_an_op/index.html

就我而言,我正在进行原型设计,因此我不关心它是否在 GPU 上运行,也不关心它是否可以从 TF python API 以外的任何地方使用。


是的,正如 @Yaroslav 的回答中提到的,这是可能的,关键是他引用的链接:here and here。我想通过举一个具体的例子来详细阐述这个答案。

模运算:让我们在tensorflow中实现逐元素求模运算(它已经存在,但它的梯度尚未定义,但对于本示例,我们将从头开始实现它)。

numpy 函数:第一步是定义我们想要对 numpy 数组执行的操作。逐元素求模运算已经在 numpy 中实现,因此很简单:

import numpy as np
def np_mod(x,y):
    return (x % y).astype(np.float32)

原因是.astype(np.float32)是因为默认情况下,tensorflow 采用 float32 类型,如果你给它 float64 (numpy 默认值),它会抱怨。

Gradient Function: Next we need to define the gradient function for our opperation for each input of the opperation as tensorflow function. The function needs to take a very specific form. It need to take the tensorflow representation of the opperation op and the gradient of the output grad and say how to propagate the gradients. In our case, the gradients of the mod opperation are easy, the derivative is 1 with respect to the first argument and enter image description here with respect to the second (almost everywhere, and infinite at a finite number of spots, but let's ignore that, see https://math.stackexchange.com/questions/1849280/derivative-of-remainder-function-wrt-denominator for details). So we have

def modgrad(op, grad):
    x = op.inputs[0] # the first argument (normally you need those to calculate the gradient, like the gradient of x^2 is 2x. )
    y = op.inputs[1] # the second argument

    return grad * 1, grad * tf.neg(tf.floordiv(x, y)) #the propagated gradient with respect to the first and second argument respectively

grad 函数需要返回一个 n 元组,其中 n 是操作的参数数量。请注意,我们需要返回输入的张量流函数。

制作带有梯度的 TF 函数:正如上面提到的来源中所解释的,有一个 hack 可以使用以下方法定义函数的梯度tf.RegisterGradient [doc] and tf.Graph.gradient_override_map [doc].

复制代码来自harpone我们可以修改tf.py_func函数使其同时定义渐变:

import tensorflow as tf

def py_func(func, inp, Tout, stateful=True, name=None, grad=None):

    # Need to generate a unique name to avoid duplicates:
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))

    tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for grad example
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

The stateful选项是告诉tensorflow该函数是否总是为相同的输入提供相同的输出(stateful = False),在这种情况下,tensorflow可以简单地表示张量流图,这是我们的情况,并且在大多数情况下可能都是这种情况。

将它们组合在一起:现在我们已经有了所有的部分,我们可以将它们组合在一起:

from tensorflow.python.framework import ops

def tf_mod(x,y, name=None):

    with ops.op_scope([x,y], name, "mod") as name:
        z = py_func(np_mod,
                        [x,y],
                        [tf.float32],
                        name=name,
                        grad=modgrad)  # <-- here's the call to the gradient
        return z[0]

tf.py_func作用于张量列表(并返回张量列表),这就是为什么我们有[x,y](并返回z[0])。 现在我们完成了。我们可以测试它。

Test:

with tf.Session() as sess:

    x = tf.constant([0.3,0.7,1.2,1.7])
    y = tf.constant([0.2,0.5,1.0,2.9])
    z = tf_mod(x,y)
    gr = tf.gradients(z, [x,y])
    tf.initialize_all_variables().run()

    print(x.eval(), y.eval(),z.eval(), gr[0].eval(), gr[1].eval())

[ 0.30000001 0.69999999 1.20000005 1.70000005] [ 0.2 0.5 1. 2.9000001] [ 0.10000001 0.19999999 0.20000005 1.70000005] [ 1. 1. 1.1.] [-1。 -1。 -1。 0.]

Success!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow:如何在 python 中编写带有梯度的操作? 的相关文章

  • 上传时的 Google Drive API——这些额外的空行从何而来?

    总结一下该程序 我从我的 Google 云端硬盘下载一个文件 然后在本地计算机中打开并读取一个文件 file a txt 然后在我的计算机中打开另一个文件 file b txt 处于附加模式 并且在使用这个新的 file b 更新我的 Go
  • numpy python 中的“AttributeError:'matrix'对象没有属性'strftime'”错误

    我有一个维度为 72000 1 的矩阵 该矩阵涉及时间戳 我想使用 strftime 如下所示 strftime d m y 为了得到像这样的输出 11 03 02 我有这样一个矩阵 M np matrix timestamps 我使用了
  • NumPy linalg.eig

    我有这个烦人的问题 但我还没有弄清楚 我有一个矩阵 我想找到特征向量 所以我写 val vec np linalg eig mymatrix 然后我得到了 vec 我的问题是 当我小组中的其他人对相同的矩阵 mymatrix 做同样的事情时
  • 从sklearn PCA获取特征值和向量

    如何获取 PCA 应用程序的特征值和特征向量 from sklearn decomposition import PCA clf PCA 0 98 whiten True converse 98 variance X train clf f
  • Pyqt-如何因另一个组合框数据而更改组合框数据?

    我有一个表 有 4 列 这 4 列中的两列是关于功能的 一个是特征 另一个是子特征 在每一列中 所有单元格都有组合框 我可以在这些单元格中打开txt 我想 当我选择电影院作为功能时 我只想看到子功能组合框中的电影名称 而不是我的 数据 中的
  • Python 是解释型的还是编译型的,或者两者兼而有之?

    据我了解 An 解释的语言是由解释器 将高级语言转换为机器代码然后执行的程序 实时运行和执行的高级语言 它一次处理一点程序 A compiled语言是一种高级语言 其代码首先由编译器 将高级语言转换为机器代码的程序 转换为机器代码 然后由执
  • Python 使用 Gstreamer 访问 USB 麦克风时遇到问题,以便在 Raspberry Pi 上使用 Pocketsphinx 执行语音识别

    所以Python的表现就好像它根本听不到我的麦克风发出的任何声音 问题就在这里 我有一个Python 2 7 假设使用的脚本Gstreamer通过以下方式访问我的麦克风并为我进行语音识别口袋狮身人面像 我在用着脉冲音频我的设备是树莓派 我的
  • Python中列表中两个连续元素的平均值

    我有一个偶数个浮点数的列表 2 34 3 45 4 56 1 23 2 34 7 89 我的任务是计算 1 和 2 个元素 3 和 4 5 和 6 等元素的平均值 在 Python 中执行此操作的快捷方法是什么 data 2 34 3 45
  • ValueError:不支持连续[重复]

    这个问题在这里已经有答案了 我正在使用 GridSearchCV 进行线性回归的交叉验证 不是分类器也不是逻辑回归 我还使用 StandardScaler 对 X 进行标准化 我的数据框有 17 个特征 X 和 5 个目标 y 观察 约11
  • Pandas 堆积条形图中元素的排序

    我正在尝试绘制有关某个地区 5 个地区的家庭在特定行业赚取的收入比例的信息 我使用 groupby 按地区对数据框中的信息进行排序 df df orig groupby District Portion of income value co
  • Python:我不明白 sum() 的完整用法

    当然 我明白你使用 sum 与几个数字 然后它总结所有 但我正在查看它的文档 我发现了这一点 sum iterable start 第二个参数 start 的作用是什么 这太尴尬了 但我似乎无法通过谷歌找到任何示例 并且对于尝试学习该语言的
  • 为什么我应该使用 WSGI?

    使用 mod python 一段时间了 我读了越来越多关于 WSGI 有多好的文章 但没有真正理解为什么 那么我为什么要切换到它呢 有什么好处 这很难吗 学习曲线值得吗 为了用 Python 开发复杂的 Web 应用程序 您可能会使用更全面
  • Django - 提交具有同一字段多个输入的表单

    预警 我对 Django 以及一般的 Web 开发 非常陌生 我使用 Django 托管一个基于 Web 的 UI 该 UI 将从简短的调查中获取用户输入 通过我用 Python 开发的一些分析来提供输入 然后在 UI 中呈现这些分析的可视
  • Flask 应用程序的测试覆盖率不起作用

    您好 想在终端的 Flask 应用程序中测试 删除路由 我可以看到测试已经过去 它说 test user delete test app LayoutTestCase ok 但是当我打开封面时 它仍然是红色的 这意味着没有覆盖它 请有人向我
  • Python对象初始化性能

    我只是做了一些快速的性能测试 我注意到一般情况下初始化列表比显式初始化列表慢大约四到六倍 这些可能是错误的术语 我不确定这里的行话 例如 gt gt gt import timeit gt gt gt print timeit timeit
  • 附加两个具有相同列、不同顺序的数据框

    我有两个熊猫数据框 noclickDF DataFrame 0 123 321 0 1543 432 columns click id location clickDF DataFrame 1 123 421 1 1543 436 colu
  • Python问题:打开和关闭文件返回语法错误

    大家好 我发现了这个有用的 python 脚本 它允许我从网站获取一些天气数据 我将创建一个文件和其中的数据集 有些东西不起作用 它返回此错误 File
  • 计算互相关函数?

    In R 我在用ccf or acf计算成对互相关函数 以便我可以找出哪个移位给我带来最大值 从它的外观来看 R给我一个标准化的值序列 Python 的 scipy 中是否有类似的东西 或者我应该使用fft模块 目前 我正在这样做 xcor
  • bs4 `next_sibling` VS `find_next_sibling`

    我在使用时遇到困难next sibling 并且类似地与next element 如果用作属性 我不会得到任何返回 但如果用作find next sibling or find next 然后就可以了 来自doc https www cru
  • 操作错误:(sqlite3.OperationalError) SQL 变量太多,同时将 SQL 与数据帧一起使用

    我有一个熊猫数据框 如下所示 activity User Id 0 VIEWED MOVIE 158d292ec18a49 1 VIEWED MOVIE 158d292ec18a49 2 VIEWED MOVIE 158d292ec18a4

随机推荐