如何在 TensorFlow 中有效地分配给张量的切片

2023-12-26

我想为 TensorFlow 2.x 中的一个模型中的输入张量切片分配一些值(我正在使用 2.2,但准备接受 2.1 的解决方案)。 我想做的一个非工作模板是:

import tensorflow as tf
from tensorflow.keras.models import Model

class AddToEven(Model):
    def call(self, inputs):
        outputs = inputs
        outputs[:, ::2] += inputs[:, ::2]
        return outputs

当然,在构建这个时(AddToEven().build(tf.TensorShape([None, None])))我收到以下错误:

TypeError: 'Tensor' object does not support item assignment

我可以通过以下方式实现这个简单的示例:

class AddToEvenScatter(Model):
    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        n = tf.shape(inputs)[-1]
        update_indices = tf.range(0, n, delta=2)[:, None]
        scatter_nd_perm = [1, 0]
        inputs_reshaped = tf.transpose(inputs, scatter_nd_perm)
        outputs = tf.tensor_scatter_nd_add(
            inputs_reshaped,
            indices=update_indices,
            updates=inputs_reshaped[::2],
        )
        outputs = tf.transpose(outputs, scatter_nd_perm)
        return outputs

(您可以通过以下方式进行健全性检查:

model = AddToEvenScatter()
model.build(tf.TensorShape([None, None]))
model(tf.ones([1, 10]))

)

但正如你所看到的,写起来非常复杂。这仅适用于 1D(+ 批量大小)张量的静态更新次数(此处为 1)。

我想做的是更复杂一点,我想用tensor_scatter_nd_add这将是一场噩梦。

目前该主题的许多 WA 都涵盖了变量的情况,但没有涵盖张量的情况(参见例如this https://stackoverflow.com/questions/39157723/how-to-do-slice-assignment-in-tensorflow or this https://stackoverflow.com/questions/39157723/how-to-do-slice-assignment-in-tensorflow/43139565#43139565)。 有提到here https://github.com/tensorflow/tensorflow/issues/33131#issuecomment-614379591pytorch 确实支持这一点,所以我很惊讶地看到最近没有任何 tf 成员就该主题做出回应。这个答案 https://github.com/tensorflow/tensorflow/issues/14132#issuecomment-483002522并没有真正帮助我,因为我需要某种面具生成,这也会很糟糕。

因此,问题是:如何有效地进行切片分配(计算方面、内存方面和代码方面)而不需要tensor_scatter_nd_add?诀窍是我希望它尽可能动态,这意味着inputs可能是可变的。

(对于任何好奇的人,我正在尝试翻译这段代码 https://github.com/lpj-github-io/MWCNNv2/blob/master/MWCNN_code/model/common.py#L80-L99 in tf).

这个问题最初发布在 GitHub 问题中 https://github.com/tensorflow/tensorflow/issues/36559#issuecomment-636084462.


这是另一种基于二进制掩码的解决方案。

"""Solution based on binary mask.
- We just add this mask to inputs, instead of multiplying."""
class AddToEven(tf.keras.Model):
    def __init__(self):
        super(AddToEven, self).__init__()        

    def build(self, inputshape):
        self.built = True # Actually nothing to build with, becuase we don't have any variables or weights here.

    @tf.function
    def call(self, inputs):
        w = inputs.get_shape()[-1]

        # 1-d mask generation for w-axis (activate even indices only)        
        m_w = tf.range(w)  # [0, 1, 2,... w-1]
        m_w = ((m_w%2)==0) # [True, False, True ,...] with dtype=tf.bool

        # Apply 1-d mask to 2-d input
        m_w = tf.expand_dims(m_w, axis=0) # just extend dimension as to be (1, W)
        m_w = tf.cast(m_w, dtype=inputs.dtype) # in advance, we need to convert dtype

        # Here, we just add this (1, W) mask to (H,W) input magically.
        outputs = inputs + m_w # This add operation is allowed in both TF and numpy!
        return tf.reshape(outputs, inputs.get_shape())

在这里进行健全性检查。

# sanity-check as model
model = AddToEven()
model.build(tf.TensorShape([None, None]))
z = model(tf.zeros([2,4]))
print(z)

结果(使用 TF 2.1)是这样的。

tf.Tensor(
[[1. 0. 1. 0.]
 [1. 0. 1. 0.]], shape=(2, 4), dtype=float32)

-------- 以下是之前的回答 --------

您需要在 build() 方法中创建 tf.Variable 。 它还允许通过 shape=(None,) 动态调整大小。 在下面的代码中,我将输入形状指定为(无,无)。

class AddToEven(tf.keras.Model):
    def __init__(self):
        super(AddToEven, self).__init__()

    def build(self, inputshape):
        self.v = tf.Variable(initial_value=tf.zeros((0,0)), shape=(None, None), trainable=False, dtype=tf.float32)

    @tf.function
    def call(self, inputs):
        self.v.assign(inputs)
        self.v[:, ::2].assign(self.v[:, ::2] + 1)
        return self.v.value()

我用 TF 2.1.0 和 TF1.15 测试了这段代码

# test
add_to_even = AddToEven()
z = add_to_even(tf.zeros((2,4)))
print(z)

Result:

tf.Tensor(
[[1. 0. 1. 0.]
 [1. 0. 1. 0.]], shape=(2, 4), dtype=float32)

附:还有其他一些方法,例如使用 tf.numpy_function(),或生成掩码函数。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 TensorFlow 中有效地分配给张量的切片 的相关文章

随机推荐

  • 重建后是否需要重新安装Windows服务

    如果我在进行更改后重建 Windows 服务 我是否可以仅复制并替换旧的程序集 exe 文件来运行这些更改 还是需要重新安装该服务 另外 在安装新版本之前是否必须先卸载该服务 您不必卸载并重新安装该服务 因为这只会添加有关可执行路径和启动选
  • 如果一个 div 具有“clear:right”,那么没有任何东西应该浮动到它的右侧,不是吗?

    我似乎对 css clear 关键字的含义感到困惑 我有许多 div 元素 全部带有 float left 倒数第二个 div 元素也有 clear right 我认为这会导致后续元素转到下一行 但对我来说 事实并非如此 这是我的例子 di
  • iPhone (iOS) 上的 HTML 上传:文件名始终相同(image.jpg、image.png...)

    我正在使用非常简单的代码在我的响应式网站中上传文件 但是当我使用上传图像时iPhone 图像名称始终是image jpg与实际图像名称无关 有解决这个问题的方法吗 我使用小代码创建了此示例页面以进行调试
  • 使用 Powershell 关闭特定 Excel 文件

    用于终止特定 Excel 文件的适当 PowerShell cmdlet 是什么 命令 Stop Process Name excel 关闭所有打开的 Excel 应用程序 提前致谢 其实excel打开 xlsx一个实例中的文件excel
  • C语言中指针的地址可以为0吗?

    我正在阅读书中某个问题的解决方案破解编码面试 http www crackingthecodinginterview com 问题 1 2 目标是实现一个功能void revers char str 在 C 中 反转空终止字符串 解决方案代
  • 在 Android 中动态创建 EditText

    我正在开发一个应用程序 我必须创建多个EditText和动态微调器 所以我开始寻找 解决方案 因为我没有使用权限无形的XML 文件中的属性 我搜索了很多并得到了很少的例子 仅在堆栈溢出 我跟随他们并创建了这个程序 MainActivity
  • 安装 aws php sdk - 意外变量

    我正在尝试使用 AWS php sdk 但在设置时遇到一些问题 当我运行需要自动加载器的 php 脚本时 出现此错误 Parse error syntax error unexpected value T VARIABLE in direc
  • 带有向量条件值的数据框的 r 下标

    这似乎很容易 但它让我忙了一段时间 我有一个包含 n 列的数据框 df 和一个具有相同数量 n 值的向量 向量中的值是数据帧中列中观测值的阈值 所以线索是 如何告诉 R 对每一列使用不同的阈值 我想将所有观察结果保留在数据框中 以满足每列的
  • MATLAB:获取文件的最后修改时间

    我正在寻找执行一些例程的 MATLAB 代码 更新file m if file csv最近编辑于file m 应该看起来像这样 Write time extraction tempC GetFileTime file csv Write t
  • Eclipse Scala IDE 中的 Scala-Lift 项目错误

    我安装了适用于 Eclipse 的 Scala IDE http www scala ide org 而且似乎工作正常 所以现在我正在尝试导入一个 Lift 项目 特别是 自动生成的 Lift 项目 斯塔克斯应用平台 http stax n
  • Plotly.js 模式栏,下载为 png,给 png 命名

    我的网页上有一个 Plotly 您可以通过单击模式栏中的图片图标将其下载为 png 但是 当我单击它时 它会将其下载为 png 格式 名称为 new plot 如何为其指定自定义名称 我当前的代码 var data 只是数据 所以将其省略
  • Windows 上 Ubuntu 上的 Bash 上的 pem 文件权限

    我尝试使用 pem 文件登录我的盒子 但收到错误消息 WARNING UNPROTECTED PRIVATE KEY FILE Permissions 0555 for arete server pem are too open It is
  • 如何让 Hibernate 调用我的自定义 typedef?

    我正在尝试定义 CompositeUserType 来处理 JPA Hibernate 应用程序中的特定类型 我有一个名为 ApplicationMessageType 的 CompositeUserType 旨在处理我的映射 根据我所读到
  • 将 PreBuiltTransportClient 与 elasticsearch 5 结合使用

    我正在尝试按照官方 Elasticsearch 5 文档来设置传输客户端 https www elastic co guide en elasticsearch client java api 5 0 transport client ht
  • PHP 和 Composer,如何组合composer.json 文件

    有人可以解释一下我应该如何将 Composer 与 php ini 一起使用吗 我的文档根目录中有一个composer json文件 它下载我的项目的核心包 但是当我想添加另一个项目 例如在这里找到的google php sdk 时http
  • [13]:Array 的未定义方法“assign_attributes”

    我的应用程序设置为 当 Product sold 属性为 1 时 表示商品已售出 并且不会显示在商店视图中 我正在尝试获取它 以便当客户签出时 购买商品时会更新product sold 属性 以下是我的控制器中应将 Product sold
  • 如何使用CSS在图像上添加覆盖颜色

    如果我有这样的图像 img src inshot1 jpg width 100px height 100px 悬停时我希望该块被某种颜色覆盖 例如 当您将鼠标悬停在其上时 您会看到一块具有相同高度和宽度的红色块 那么基本上是叠加吗 您可以通
  • 单个应用程序二进制文件如何支持 64 位和 32 位应用程序

    我们可以看到苹果的公告here https developer apple com news 根据这个文档 我们可以提交相同的二进制文件 支持 32 位和 64 位 我找到了一个堆栈溢出答案here https stackoverflow
  • 输入字段问题 - 关闭窗口但保持 Python 运行 [重复]

    这个问题在这里已经有答案了 有点长的问题 我正在创建一个输入字段 在 skrx 的主要帮助下 该字段显示在定制屏幕上 我已经对其进行了编程 以便当我按 Enter 键时 屏幕应该自行关闭 目前这个 pygame display quit 感
  • 如何在 TensorFlow 中有效地分配给张量的切片

    我想为 TensorFlow 2 x 中的一个模型中的输入张量切片分配一些值 我正在使用 2 2 但准备接受 2 1 的解决方案 我想做的一个非工作模板是 import tensorflow as tf from tensorflow ke