tensorflow py_func 很方便,但使我的训练步骤非常慢。

2024-01-03

我在使用张量流函数 py_func 时遇到一些效率问题。

Context

在我的项目中,我有一批张量input_features大小的[? max_items m]。第一个维度设置为?因为它是动态形状(为自定义张量流读取器读取批次,并使用 tf.train.shuffle_batch_join() 进行混洗)。第二个维度对应于上限(我的示例中可以采用的最大项目数),第三个维度对应于特征维度空间。我也有一个张量num_items具有批量大小的尺寸(因此形状是(?,)),表示示例中的项数,其他设置为0(采用numpy的写法input_feature[k, num_items[k]:, :] = 0)

Issue

我的工作流程需要一些自定义 python 操作(特别是为了处理索引,我需要或实例对某些示例块执行聚类操作),并且我使用了一些包装在其中的 numpy 函数py_func功能。这效果很好,但是训练变得非常非常慢(比没有这个 py_func 的模型慢大约 50 倍),而且函数本身并不耗时。

问题

1 - 计算时间增加正常吗?该函数包裹在py_func给我一个新的张量,该张量在此过程中进一步相乘。它能解释计算时间吗? (我的意思是用这样的函数计算梯度可能更困难)。

2 - 我正在尝试修改我的处理并避免使用py_func功能。然而,使用 numpy 索引提取数据非常方便(尤其是我的数据格式化),并且我在以 TF 方式传递它时遇到一些困难。例如,如果我有一个张量t1有形状[-1, n_max, m](第一个维度是batch_size,它是动态的)并且t2有形状[-1,2]包含整数。有没有一种简单的方法可以在张量流中执行平均运算,从而导致t_mean_chunk有形状(-1, m)其中(在 numpy 公式中):t_mean_chunk[i,:] = np.mean(t1[i, t2[i,0]:t2[i,1], :], axis=0)? 这是(以及其他操作)我在包装函数中所做的事情。


如果没有确切的 py_func,问题 1 很难回答,但正如 hpaulj 在他的评论中提到的那样,它减慢了速度也就不足为奇了。作为最坏情况的后备方案,tf.scan or tf.while_loop with a TensorArray可能会快一些。然而,最好的情况是使用 TensorFlow 操作提供矢量化解决方案,我认为在这种情况下这是可能的。

至于问题 2,我不确定它是否算简单,但这里有一个计算索引表达式的函数:

import tensorflow as tf

def range_mean(index_ranges, values):
  """Take the mean of `values` along ranges specified by `index_ranges`.

  return[i, ...] = tf.reduce_mean(
    values[i, index_ranges[i, 0]:index_ranges[i, 1], ...], axis=0)

  Args:
    index_ranges: An integer Tensor with shape [N x 2]
    values: A Tensor with shape [N x M x ...].
  Returns:
    A Tensor with shape [N x ...] containing the means of `values` having
    indices in the ranges specified.
  """
  m_indices = tf.range(tf.shape(values)[1])[None]
  # Determine which parts of `values` will be in the result
  selected = tf.logical_and(tf.greater_equal(m_indices, index_ranges[:, :1]),
                            tf.less(m_indices, index_ranges[:, 1:]))
  n_indices = tf.tile(tf.range(tf.shape(values)[0])[..., None],
                      [1, tf.shape(values)[1]])
  segments = tf.where(selected, n_indices + 1, tf.zeros_like(n_indices))
  # Throw out segment 0, since that's our "not included" segment
  segment_sums = tf.unsorted_segment_sum(
      data=values,
      segment_ids=segments, 
      num_segments=tf.shape(values)[0] + 1)[1:]
  divisor = tf.cast(index_ranges[:, 1] - index_ranges[:, 0],
                    dtype=values.dtype)
  # Pad the shape of `divisor` so that it broadcasts against `segment_sums`.
  divisor_shape_padded = tf.reshape(
      divisor,
      tf.concat([tf.shape(divisor), 
                 tf.ones([tf.rank(values) - 2], dtype=tf.int32)], axis=0))
  return segment_sums / divisor_shape_padded

用法示例:

index_range_tensor = tf.constant([[2, 4], [1, 6], [0, 3], [0, 9]])
values_tensor = tf.reshape(tf.range(4 * 10 * 5, dtype=tf.float32), [4, 10, 5])
with tf.Session():
  tf_result = range_mean(index_range_tensor, values_tensor).eval()
  index_range_np = index_range_tensor.eval()
  values_np = values_tensor.eval()

for i in range(values_np.shape[0]):
  print("Slice {}: ".format(i),
        tf_result[i],
        numpy.mean(values_np[i, index_range_np[i, 0]:index_range_np[i, 1], :],
                   axis=0))

Prints:

Slice 0:  [ 12.5  13.5  14.5  15.5  16.5] [ 12.5  13.5  14.5  15.5  16.5]
Slice 1:  [ 65.  66.  67.  68.  69.] [ 65.  66.  67.  68.  69.]
Slice 2:  [ 105.  106.  107.  108.  109.] [ 105.  106.  107.  108.  109.]
Slice 3:  [ 170.  171.  172.  173.  174.] [ 170.  171.  172.  173.  174.]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tensorflow py_func 很方便,但使我的训练步骤非常慢。 的相关文章

  • sphinx 中的分组方法文档字符串

    是否可以使用 sphinx 的 autodoc 功能将多个方法文档字符串分组 以便将它们列在一起 class Test object def a self A method of group foo def b self A method
  • HoughLinesP后如何合并线?

    My task is to find coordinates of lines startX startY endX endY and rectangles 4 lines Here is input file 我使用下一个代码 img c
  • Python pandas:删除字符串中分隔符之后的所有内容

    我有数据框 其中包含例如 vendor a ProductA vendor b ProductA vendor a Productb 我需要删除所有内容 包括 两个 以便我最终得到 vendor a vendor b vendor a 我尝
  • 检查多维 numpy 数组的所有边是否都是零数组

    n 维数组有 2n 个边 1 维数组有 2 个端点 2 维数组有 4 个边或边 3 维数组有 6 个 2 维面 4 维数组有 8 个边 ETC 这类似于抽象 n 维立方体发生的情况 我想检查 n 维数组的所有边是否仅由零组成 以下是边由零组
  • 使用Python进行图像识别[关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我有一个想法 就是我想识别图像中的字母 可能是 bmp或 jpg 例如 这是一个包含字母 S 的 bmp 图像 我想做的是使用Pyth
  • 将分布拟合到直方图

    I want to know the distribution of my data points so first I plotted the histogram of my data My histogram looks like th
  • 覆盖现有的 django-admin 命令

    除了编写自定义 django admin 命令之外 这是有详细记录的 https docs djangoproject com en 1 9 howto custom management commands 我希望能够覆盖现有命令 例如ma
  • 来自数据框 groupby 的条形图

    import pandas as pd import numpy as np import matplotlib pyplot as plt df pd read csv arrests csv df df replace np nan 0
  • 什么时候用==,什么时候用is?

    奇怪的是 gt gt gt a 123 gt gt gt b 123 gt gt gt a is b True gt gt gt a 123 gt gt gt b 123 gt gt gt a is b False Seems a is b
  • Python代码执行时自动打开浏览器

    我正在 Python Flask 中实现 GUI Flask 的设计方式是 必须 手动 打开本地主机以及端口号 有没有一种方法可以使其自动化 以便在运行代码时自动打开浏览器 本地主机 我尝试使用 webbrowser 包 但它在会话终止后打
  • 如何将一串Python代码编译成一个可以调用函数的模块?

    在 Python 中 我有一串 Python 源代码 其中包含以下函数 mySrc def foo print foo def bar print bar 我想将这个字符串编译成某种形式类似模块的对象这样我就可以调用代码中包含的函数 这是我
  • 如何将 pip 指向 Mercurial 分支?

    我正在尝试通过 pip 将我的应用程序安装到 virtualenv 进行测试 安装时效果很好default or tip像这样 pip install e hg https email protected cdn cgi l email p
  • Bottle 是否可以处理没有并发的请求?

    起初 我认为 Bottle 会并发处理请求 所以我编写了如下测试代码 import json from bottle import Bottle run request response get post import time app B
  • 当我打印“查询”时获取 PY_VAR1

    我正在制作一个简单的网络抓取代码 当我尝试打印一个值时 它给了我其他东西 def PeopleSearch query SearchTerm query what is query print str query SearchTerm St
  • Python `concurrent.futures`:根据完成顺序迭代 future

    我想要类似的东西executor map 除了当我迭代结果时 我想根据完成的顺序迭代它们 例如首先完成的工作项应该首先出现在迭代中 等等 这样 当且仅当序列中的每个工作项尚未完成时 迭代就会阻塞 我知道如何使用队列自己实现这一点 但我想知道
  • 在Python中确定句子中2个单词之间的邻近度

    我需要确定 Python 句子中两个单词之间的接近度 例如 在下面的句子中 the foo and the bar is foo bar 我想确定单词之间的距离foo and bar 确定之间出现的单词数foo and bar 请注意 该词
  • 如何让 Python 找到 ffprobe?

    I have ffmpeg and ffprobe安装在我的 mac macOS Sierra 上 并且我已将它们的路径添加到 PATH 中 我可以从终端运行它们 我正在尝试使用ffprobe使用以下代码获取视频文件的宽度和高度 impor
  • 为什么 tesseract 无法从这个简单的图像中读取文本?

    我在 pytesseract 上阅读了大量的帖子 但我无法让它从一个简单的图像中读取文本 它返回一个空字符串 这是图像 我尝试过缩放它 灰度化它 调整对比度 阈值 模糊 以及其他帖子中所说的一切 但我的问题是我不知道 OCR 想要更好地工作
  • 在Python中将罗马数字转换为整数

    根据 user2486 所说 这是我当前的代码 def romanMap map M 1000 CM 900 D 500 CD 400 C 100 XC 90 L 50 XL 40 X 10 IX 9 V 5 V 4 I 1 return
  • 将自定义属性添加到 Tk 小部件

    我的主要目标是向小部件添加隐藏标签或字符串之类的内容 以在其上保存简短信息 我想到创建一个新的自定义 Button 类 在本例中我需要按钮 它继承所有旧选项 这是代码 form tkinter import class NButton Bu

随机推荐