tensorflow:自定义op简单介绍

2023-05-16

本文只是简单的翻译了 https://www.tensorflow.org/extend/adding_an_op 的简单部分,高级部分请移步官网。

可能需要新定义 c++ operation 的几种情况:

  • 现有的 operation 组合不出来你想要的 op
  • 现有的 operation 组合 出来的 operation 十分低效
  • 如果你想要手动融合一些操作。

为了实现你的自定义操作,你需要做一下几件事:

  1. 在 c++ 文件中注册一个新opOp registration 定义了 op 的功能接口,它和 op 的实现是独立的。例如:op registration 定义了 op 的名字和 op的输出输出。它同时也定义了 shape 方法,被用于 tensorshape 接口。
  2. c++ 中实现 opop 的实现称之为 kernel ,它是op 的一个具体实现。对于不同的输入输出类型或者 架构(CPUs,GPUs)可以有不同的 kernel 实现 。
  3. 创建一个 python wrapper(可选的): 这个 wrapper 是一个 公开的 API,用来在 python中创建 opop registration 会生成一个默认的 wrapper,我们可以直接使用或者自己添加一个。
  4. 写一个计算 op 梯度的方法(可选)。
  5. 测试 op:为了方便,我们通常在 python 中测试 op,但是你也可以在 c++ 中进行测试。如果你定义了 gradients,你可以 通过 Python 的 gradient checker 验证他们。 这里有个例子relu_op_test.py ,测试 ReLU-likeop 的 前向和梯度过程。

Define the op’s interface

**You define the interface of an op by registering it with the TensorFlow system. **

在注册 op 的时候,你需要指定:

  • op 的名字
  • op 的输入(名字,类型),op 的输出(名字,类型)
  • docstrings
  • op 可能需要的 一些 attrs

为了演示这个到底怎么工作的,我们来看一个简单的例子:

  • 定义一个 op :输入是一个 int32tensor ,输出是输入的 拷贝,除了第一个元素保留,其它全都置零。

为了创建这个 op 的接口, 我们需要:

  • 创建一个文件,名字为 zero_out.cc. 然后调用 REGISTER_OP 宏,使用这个宏来定义 op 的接口 :
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
      .Input("to_zero: int32")
      .Output("zeroed: int32")
      .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
        c->set_output(0, c->input(0));
        return Status::OK();
      });

这个 ZeroOut op 接收一个 int 32tensor 作为输入,输出同样也是一个 int32tensor。这个 op 也使用了一个 shape 方法来确保输入和输出的维度是一样的。例如,如果输入的tensor 的shape 是 [10, 20],那么,这个 shape 方法保证输出的 shape 也是 [10, 20]

注意: op 的名字必须遵循驼峰命名法,而且要保证 op 的名字的唯一性。

Implement the kernel for the op

当你 定义了 op 的接口之后,你可以提供一个或多个 关于op 的实现。

为了实现这些 kernels

  • 创建一个类,继承 OpKernel
  • 重写 OpKernel 类的 Compute 方法
    • Compute 方法提供了一个 类型为 OpKernelContext*context 参数 ,从这里,我们可以访问到一些有用的信息,比如 输入 和 输出 tensor

kernel 代码也放到 之前创建的 zero_out.cc 文件中:

#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // 获取输入 tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // 创建输出 tensor, context->allocate_output 用来分配输出内存?
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // 执行计算操作。
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};

在实现了 kernel 之后,就可以将这个注册到 tensorflow 系统中去了。在注册时,你需要对 op 的运行环境指定一些限制。例如,你可能有一个 kernel 代码是给 CPU 用的,另一个是给 GPU用的。通过把下列代码添加到 zero_out.cc 中来完成这个功能:

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

注意:你实现的 OpKernel 的实例可能会被并行访问,所以,请确保 Compute方法是线程安全的。保证访问 类成员的 方法都加上 mutex。或者更好的选择是,不要通过 类成员来分享 状态。考虑使用 ResourceMgr 来追踪状态。

Multi-threaded CPU kernels

多线程主要由 work shard 搞定。work shard

GPU kernels

请移步官网

Build the op library

使用系统编译器 编译 定义的 op

我们可以使用 系统上的 c++ 编译器 g++ 或者 clang 来编译 zero_out.cc 。二进制的 PIP 包 已经将编译所需的 头文件 和 库 安装到了系统上。Tensorflowpython library 提供了一个用来获取 头文件目录的函数 get_include。下面是这个函数在 ubuntu 上的输出:

$ python
>>> import tensorflow as tf
>>> tf.sysconfig.get_include()
'/usr/local/lib/python2.7/site-packages/tensorflow/include'

假设你已经安装好了 g++ ,你可以使用 下面一系列的命令 将你的 op 编译成一个 动态库。

TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')

g++ -std=c++11 -shared zero_out.cc -o zero_out.so -fPIC -I $TF_INC -O2

如果你的 g++ 版本>5.0 的话,加上这个参数 -D_GLIBCXX_USE_CXX11_ABI=0

Use the op in Python

Tensorflow 的 python 接口提供了 tf.load_op_library 函数用来加载动态 library,同时将 op 注册到tensorflow 框架上。load_op_library 返回一个 python module,它包含了 opkernelpython wrapper 。因此,一旦你编译好了一个 op,就可以使用下列代码通过 python来执行它:

import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.so')
with tf.Session(''):
  zero_out_module.zero_out([[1, 2], [3, 4]]).eval()

# Prints
array([[1, 0], [0, 0]], dtype=int32)

记住:生成的函数的名字是 snake_case name。如果在c++文件中, op 的名字是ZeroOut,那么在python 中,名字是 zero_out

完整的代码在文章的最后

Verify that the op works

一个验证你的自定义的op是否正确工作的一个好的方法是 为它写一个测试文件。创建一个 zero_out_op_test.py 文件,内容为:

import tensorflow as tf

class ZeroOutTest(tf.test.TestCase):
  def testZeroOut(self):
    zero_out_module = tf.load_op_library('./zero_out.so')
    with self.test_session():
      result = zero_out_module.zero_out([5, 4, 3, 2, 1])
      self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])

if __name__ == "__main__":
  tf.test.main()

然后运行这个 test

代码

//zero_out.cc 文件
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // 将输入 tensor 从 context 中取出。
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // 创建一个 ouput_tensor, 使用 context->allocate_ouput() 给它分配空间。
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
#创建动态链接库的命令
g++ -std=c++11 -shared zero_out.cc -o zero_out.so -fPIC -D_GLIBCXX_USE_CXX11_ABI=0 -I $TF_INC -O2

总结

tensorflow 自定义 op 的方法可以总结为:

  1. 写个 diy_op.cc 文件
  2. g++ 把这个文件编译成动态链接库
  3. python 中使用 tf.load_op_library 将库导入。
  4. 就可以使用了。

还有一种方法是用 bazel 编译。

参考资料

https://www.tensorflow.org/extend/adding_an_op

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tensorflow:自定义op简单介绍 的相关文章

  • 线性时不变系统输出调节问题

    线性时不变系统输出调节问题 最近在学习 Nonlinear output regulation 中的linear output regulation时 xff0c 对于linear robust output regulation的问题时
  • MinGW-w64安装教程——著名C/C++编译器GCC的Windows版本

    MinGW w64安装教程 著名C C 43 43 编译器GCC的Windows版本 MinGW w64安装教程 著名C C 43 43 编译器GCC的Windows版本 本文主要讲述如何安装 C语言 编译器 MinGW w64 xff0c
  • RT-Thread实时操作系统简介

    目录 一 概述 二 架构 三 版本选择 四 内核启动流程 五 自动初始化机制 六 内核对象模型 七 I O设备模型 1 框架 2 设备驱动使用序列图 3 设备类型 八 FinSH控制台 九 ENV工具 1 menuconfig 2 Scon
  • PCIe RAS

    对于Linux系统针对RAS的AER错误处理机制完成 PCIe RAS简单来讲就是PCIe的错误检测 纠正以及汇报的机制 它可以方便我们准确的定位 xff0c 纠正和分析错误增强系统的健壮性和可靠性 PCIe错误的分类 PCIe错误分为可校
  • Linux下的regulator调试

    先看regulator使用的小demo 如 i2c8 touchscreen 64 28 vddcama supply 61 lt amp xxxxx gt int ret struct regulator power static int
  • 关于添加系统调用遇到 Unable to handle kernel paging request at virtual address 的解决

    Unable to handle kernel paging request at virtual address 是内存访问异常的错误 xff0c 原因通常有三种 xff1a virtual address 为 0x00000000 时
  • vscode安装配置clang-format插件及使用

    vscode安装配置clang format插件及使用 首先安装插件 在vscode扩展里搜索clang format xff0c 安装排名第一的xaver clang format 确认clang format可执行程序路径 window
  • 简历中项目描述怎么写啊

    http wenda tianya cn question 7ade6dc9324bed88
  • 树莓派(Raspberry Pi 3) - 系统烧录及系统使用

    树莓派 xff08 Raspberry pi xff09 是一块集成度极高的ARM开发板 xff0c 不仅包含了HDMI xff0c RCA xff0c CSI xff0c HDMI xff0c GPIO等端口 xff0c 还支持蓝牙以及无
  • flashcache原理

    介绍flashcache的文章很多 xff0c 我就不废话了 使用上 xff0c 有余峰老哥的 文章 xff1b 原理上 xff0c 有ningoo同学的 flashcache系列 但是ningoo同学漏掉了device mapper和fl
  • 无人机算法之PID

    xff08 未完成 xff09 一 PID介绍 xff08 百度百科 xff09 PID 控制器 xff08 比例 积分 微分控制器 xff09 是一个在工业控制应用中常见的反馈回路部件 xff0c 由比例单元 P 积分单元 I 和微分单元
  • java:接口、lambda表达式与内部类

    接口 xff08 interface 接口用来描述类应该做什么 xff0c 而不指定他们具体应该如何做 接口不是类 xff0c 而是对符合这个接口的类的一组需求 接口定义的关键词是interface span class token key
  • 卫星系统算法课程设计 - 第二部分 qt的安装与创建项目

    上一篇文章只讲了基本的东西 xff0c 这一篇要完成qt的安装 xff0c 构建项目 xff0c 并且将上一篇的代码导入进去 某比利比例搜qt安装 xff0c 看到qt5 14 2的下载安装 xff0c 跟着做 1 创建项目 创建新项目 x
  • 无人机-材料准备

    xff08 未完成 xff09 一 使用空心杯电机 xff0c 型号8520 xff0c 1S版本 xff0c 约5G每只 二 空心杯机架 xff0c 型号QX90 xff0c 约8 5g 三 使用55MM桨 四 1S 600MA电池 五
  • CMake中链接库的顺序问题

    原文链接 xff1a https blog csdn net lifemap article details 7586363 cmake中链接库的顺序是a依赖b xff0c 那么b放在a的后面 例如进程test依赖a库 b库 a库又依赖b
  • 鸿蒙wifi Demo运行

    title 鸿蒙Wi Fi Demo运行 date 2021 1 1 22 25 10 categories harmony 本文首发于LHM s notes 欢迎关注我的博客 坑有点多 由于之前没有看过wifi的内核态代码 xff0c 所
  • 将TensorFlow训练好的模型迁移到Android APP上(TensorFlowLite)

    将TensorFlow训练好的模型迁移到Android APP上 xff08 TensorFlowLite xff09 1 写在前面 最近在做一个数字手势识别的APP xff08 关于这个项目 xff0c 我会再写一篇博客仔细介绍 xff0
  • 汉诺塔代码图文详解(递归入门)

    游戏规则 xff1a 已知条件存在A B C三根柱子 xff0c A上套有N片圆盘 如下图 目的将A上的所有圆盘移到C上约束条件每次只能移动一片圆盘 xff0c 且整个过程中只能出现小圆盘在大圆盘之上的情况 首先我们模拟 N 61 2 xf
  • STM32 最小系统电路简析

    文章目录 一 最小系统的组成1 供电电路2 外部晶振3 BOOT选择4 复位电路 二 最小系统实例1 STM32F103C8T6最小系统 三 各部分组成简析1 供电电路设计2 外部晶振原理3 BOOT设计4 复位电路设计 一 最小系统的组成
  • 带参数的宏的问题

    include 34 iostream 34 using namespace std define COMPUTE XX a a a 43 a 2 int main int a 61 2 int test1 61 COMPUTE XX 43

随机推荐

  • python_imbalanced-learn非平衡学习包_02_Over-sampling过采样

    python imbalanced learn非平衡学习包 01 简介 python imbalanced learn非平衡学习包 02 Over sampling过采样 后续章节待定 希望各位认可前面已更 您的认可是我的动力 Over s
  • TX2+JetPack3.2.1+opencv3.3.1+caffe+realsense2.0环境配置教程

    TX2 开箱 一共6样 xff0c 开机之后自带ubuntu16 04LTS的系统 xff0c ARMv8的处理器 xff0c 所以有些指令 xff0c 安装包必须与arm结构保持一致 开机之后 xff0c 按照指示进入图形界面 xff1a
  • 初视openwrt

    openwrt是一个微型的嵌入式操作系统 在编译的时候需要安装许多的工具和库 预置环境 xff1a sudo apt get install g 43 43 libncurses5 dev zlib1g dev bison flex unz
  • 滑动窗口详解

    前言 滑动窗口是双指针的一种特例 xff0c 可以称为左右指针 xff0c 在任意时刻 xff0c 只有一个指针运动 xff0c 而另一个保持静止 滑动窗口路一般用于解决特定的序列中符合条件的连续的子序列的问题 滑动窗口的时间复杂度是线性的
  • RT-Thread入门教程,环境配置和第一个代码

    1 前言 RT Thread这一个操作系统获得很多工程师的好评 xff0c 使用简单 xff0c 支持多 xff0c 有软件包可以下载 xff0c 甚至未来会有更多MicroPython的支持 xff0c 能够兼容主流的一些MCU xff0
  • DHT12温湿度传感器IIC,I2C接口调试心得和代码说明

    来源 xff1a http www fuhome net bbs forum php mod 61 viewthread amp tid 61 2141 DHT11那个单总线的温湿度传感器用的很多了 xff0c aosong推出了DHT12
  • 升级windows11如何在电脑上启用TPM2.0

    本文适用于无法升级到 Windows 11 xff0c 因为他们的电脑当前未启用 TPM 2 0 或其电脑能够运行 TPM 2 0 xff0c 但并未设置为运行 TPM 2 0 1 下载微软电脑健康状况检查 下载地址为 xff1a Wind
  • python调用谷歌翻译

    from GoogleFreeTrans import Translator if name 61 61 39 main 39 translator 61 Translator translator src 61 39 en 39 dest
  • C++(4) 运算符重载

    C 43 43 学习心得 xff08 1 xff09 运算符重载 from 谭浩强 C 43 43 面向对象程序设计 第一版 2014 10 6 4 1什么是运算符重载 用户根据C 43 43 提供的运算符进行重载 xff0c 赋予它们新的
  • C++学习心得(3)多态性与虚函数

    C 43 43 学习心得 xff08 3 xff09 多态性与虚函数 from 谭浩强 C 43 43 面向对象程序设计 第一版 2014 10 13 6 1 多态性的概念 在C 43 43 中 xff0c 多态性是指具有不同功能的函数可以
  • C发送http请求

    C语言发送http请求和普通的socket通讯 原理是一样的 无非就三步connect 连上服务器 send 发送数据 recv 接收数据 只不过发送的数据有特定的格式 下面的是简单发送一个http请求的例子 span class hljs
  • tensorflow(四十七):tensorflow模型持久化

    模型保存 span class token keyword from span tensorflow span class token keyword import span graph util graph def span class
  • git subtree使用

    在一个git项目下引用另一个项目的时 xff0c 我们可以使用 git subtree 使用 git subtree 时 xff0c 主项目下包含子项目的所有代码 使用 git subtree 主要关注以下几个功能 一个项目下如何引入另一个
  • tensorflow(四十八): 使用tensorboard可视化训练出的文本embedding

    对应 tensorflow 1 15版本 log dir span class token operator 61 span span class token string 34 logdir 34 span metadata path s
  • java中数组之间的相互赋值

    前言 本文考虑的研究对象是数组 xff0c 需要明确的是在java中 xff0c 数组是一种对象 xff0c java的所有对象的定义都是放在堆当中的 xff0c 对象变量之间的直接赋值会导致引用地址的一致 在java中声明一个数组 spa
  • tensorflow学习笔记(十):sess.run()

    session run fetch1 fetch2 关于 session run fetch1 fetch2 xff0c 请看http stackoverflow com questions 42407611 how tensorflow
  • tensorflow学习笔记(二十三):variable与get_variable

    Variable tensorflow中有两个关于variable的op xff0c tf Variable 与tf get variable 下面介绍这两个的区别 tf Variable与tf get variable tf Variab
  • pytorch 学习笔记(一)

    pytorch是一个动态的建图的工具 不像Tensorflow那样 xff0c 先建图 xff0c 然后通过feed和run重复执行建好的图 相对来说 xff0c pytorch具有更好的灵活性 编写一个深度网络需要关注的地方是 xff1a
  • pytorch学习笔记(五):保存和加载模型

    span class hljs comment 保存和加载整个模型 span torch save model object span class hljs string 39 model pkl 39 span model 61 torc
  • tensorflow:自定义op简单介绍

    本文只是简单的翻译了 https www tensorflow org extend adding an op 的简单部分 xff0c 高级部分请移步官网 可能需要新定义 c 43 43 operation 的几种情况 xff1a 现有的