如何在 TensorFlow 2.0 中使用 tf.Lambda 和 tf.Variable

2024-02-16

我对 TensorFlow 2.0 非常陌生。

我为循环 GAN 编写了如下代码(我仅提取用于构建生成器神经网络的代码):

def instance_norm(x, epsilon=1e-5):

    scale = tf.Variable(initial_value=np.random.normal(1., 0.02, x.shape[-1:]),
                        trainable=True,
                        name='SCALE',
                        dtype=tf.float32)
    offset = tf.Variable(initial_value=np.zeros(x.shape[-1:]),
                         trainable=True,
                         name='OFFSET',
                         dtype=tf.float32)
    mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
    inv = tf.math.rsqrt(variance + epsilon)
    normalized = (x - mean) * inv
    return scale * normalized + offset

def build_generator(options, name='Generator'):

    initializer = tf.random_normal_initializer(0., 0.02)

    inputs = Input(shape=(options.time_step,
                          options.pitch_range,
                          options.output_nc))

    x = inputs
    # (batch * 64 * 84 * 1)

    x = layers.Lambda(padding,
                      name='PADDING_1')(x)
    # (batch * 70 * 90 * 1)

    x = layers.Conv2D(filters=options.gf_dim,
                      kernel_size=7,
                      strides=1,
                      padding='valid',
                      kernel_initializer=initializer,
                      use_bias=False,
                      name='CONV2D_1')(x)
    x = layers.Lambda(instance_norm,
                      name='IN_1')(x)
    x = layers.ReLU()(x)

但是当我运行这段代码时,出现如下错误:

Traceback (most recent call last):
  File "tf2_main.py", line 50, in <module>
    model = CycleGAN(args)
  File "/Users/mhiro/PycharmProjects/music_gan/CycleGAN-Music-Style-Transfer-Refactorization-master/tf2_model.py", line 55, in __init__
    self._build_model(args)
  File "/Users/mhiro/PycharmProjects/music_gan/CycleGAN-Music-Style-Transfer-Refactorization-master/tf2_model.py", line 63, in _build_model
    name='Generator_A2B')
  File "/Users/mhiro/PycharmProjects/music_gan/CycleGAN-Music-Style-Transfer-Refactorization-master/tf2_module.py", line 154, in build_generator
    name='IN_1')(x)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py", line 773, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow_core/python/keras/layers/core.py", line 847, in call
    self._check_variables(created_variables, tape.watched_variables())
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow_core/python/keras/layers/core.py", line 873, in _check_variables
    raise ValueError(error_str)
ValueError: 
The following Variables were created within a Lambda layer (IN_1)
but are not tracked by said layer:
  <tf.Variable 'IN_1/SCALE:0' shape=(64,) dtype=float32>
  <tf.Variable 'IN_1/OFFSET:0' shape=(64,) dtype=float32>
The layer cannot safely ensure proper Variable reuse across multiple
calls, and consquently this behavior is disallowed for safety. Lambda
layers are not well suited to stateful computation; instead, writing a
subclassed Layer is the recommend way to define layers with
Variables.

看来我应该重写 tf.Lambda 和 tf.Variable 部分。

谁能教我如何重写这段代码?


Lambda https://www.tensorflow.org/api_docs/python/tf/keras/layers/Lambda层是无状态的,也就是说,您不能在其中定义变量。相反,你可以宁愿编写自定义层 https://www.tensorflow.org/guide/keras/custom_layers_and_models。大致如下:

import tensorflow as tf
from tensorflow.keras import layers

class InstanceNorm(layers.Layer):
    def __init__(self):
        super(InstanceNorm, self).__init__()

    def build(self, input_shape):
        self.scale = self.add_weight(shape=your_shape_1,
                                 initializer=your_initializer_1,
                                 trainable=True)
        self.offset = self.add_weight(shape=your_shape_2,
                                 initializer=your_initializer_2,
                                 trainable=True)

  def call(self, x, epsilon=1e-5):
        mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        inv = tf.math.rsqrt(variance + epsilon)
        normalized = (x - mean) * inv
        return self.scale * normalized + self.offset

现在可以按如下方式调用该层:

...
x = layers.Conv2D(filters=options.gf_dim,
                  kernel_size=7,
                  strides=1,
                  padding='valid',
                  kernel_initializer=initializer,
                  use_bias=False,
                  name='CONV2D_1')(x)
x = InstanceNorm()(x)
x = layers.ReLU()(x)
...

NOTE:未测试。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 TensorFlow 2.0 中使用 tf.Lambda 和 tf.Variable 的相关文章

  • 如何使用 opencv.omnidir 模块对鱼眼图像进行去扭曲

    我正在尝试使用全向模块 http docs opencv org trunk db dd2 namespacecv 1 1omnidir html用于对鱼眼图像进行扭曲处理Python 我正在尝试适应这一点C 教程 http docs op
  • 安装了 32 位的 Python,显示为 64 位

    我需要运行 32 位版本的 Python 我认为这就是我在我的机器上运行的 因为这是我下载的安装程序 当我重新运行安装程序时 它会将当前安装的 Python 版本称为 Python 3 5 32 位 然而当我跑步时platform arch
  • 将html数据解析成python列表进行操作

    我正在尝试读取 html 网站并提取其数据 例如 我想查看公司过去 5 年的 EPS 每股收益 基本上 我可以读入它 并且可以使用 BeautifulSoup 或 html2text 创建一个巨大的文本块 然后我想搜索该文件 我一直在使用
  • 处理 Python 行为测试框架中的异常

    我一直在考虑从鼻子转向行为测试 摩卡 柴等已经宠坏了我 到目前为止一切都很好 但除了以下之外 我似乎无法找出任何测试异常的方法 then It throws a KeyError exception def step impl contex
  • 用枢轴点拟合曲线 Python

    我有下面的图 我想用 2 条线来拟合它 使用 python 我设法适应上半部分 def func x a b x np array x return a x b popt pcov curve fit func up x up y 我想用另
  • Pandas 日期时间格式

    是否可以用零后缀表示 pd to datetime 似乎零被删除了 print pd to datetime 2000 07 26 14 21 00 00000 format Y m d H M S f 结果是 2000 07 26 14
  • 在Python中连接反斜杠

    我是 python 新手 所以如果这听起来很简单 请原谅我 我想加入一些变量来生成一条路径 像这样 AAAABBBBCCCC 2 2014 04 2014 04 01 csv Id TypeOfMachine year month year
  • datetime.datetime.now() 返回旧值

    我正在通过匹配日期查找 python 中的数据存储条目 我想要的是每天选择 今天 的条目 但由于某种原因 当我将代码上传到 gae 服务器时 它只能工作一天 第二天它仍然返回相同的值 例如当我上传代码并在 07 01 2014 执行它时 它
  • 为什么 PyYAML 花费这么多时间来解析 YAML 文件?

    我正在解析一个大约 6500 行的 YAML 文件 格式如下 foo1 bar1 blah name john age 123 metadata whatever1 whatever whatever2 whatever stuff thi
  • Python beautifulsoup 仅限 1 级文本

    我看过其他 beautifulsoup 得到相同级别类型的问题 看来我的有点不同 这是网站 我正试图拿到右边那张桌子 请注意表的第一行如何展开为该数据的详细细分 我不想要那个数据 我只想要最顶层的数据 您还可以看到其他行也可以展开 但在本例
  • 如何使用 Mysql Python 连接器检索二进制数据?

    如果我在 MySQL 中创建一个包含二进制数据的简单表 CREATE TABLE foo bar binary 4 INSERT INTO foo bar VALUES UNHEX de12 然后尝试使用 MySQL Connector P
  • 如何通过 TLS 1.2 运行 django runserver

    我正在本地 Mac OS X 机器上测试 Stripe 订单 我正在实现这段代码 stripe api key settings STRIPE SECRET order stripe Order create currency usd em
  • 如何使用 pybrain 黑盒优化训练神经网络来处理监督数据集?

    我玩了一下 pybrain 了解如何生成具有自定义架构的神经网络 并使用反向传播算法将它们训练为监督数据集 然而 我对优化算法以及任务 学习代理和环境的概念感到困惑 例如 我将如何实现一个神经网络 例如 1 以使用 pybrain 遗传算法
  • javascript 是否有等效的 __repr__ ?

    我最接近Python的东西repr这是 function User name password this name name this password password User prototype toString function r
  • python import inside函数隐藏现有变量

    我在我正在处理的多子模块项目中遇到了一个奇怪的 UnboundLocalError 分配之前引用的局部变量 问题 并将其精简为这个片段 使用标准库中的日志记录模块 import logging def foo logging info fo
  • 将 Python 中的日期与日期时间进行比较

    所以我有一个日期列表 datetime date 2013 7 9 datetime date 2013 7 12 datetime date 2013 7 15 datetime date 2013 7 18 datetime date
  • Scipy Sparse:SciPy/NumPy 更新后出现奇异矩阵警告

    我的问题是由大型电阻器系统的节点分析产生的 我基本上是在设置一个大的稀疏矩阵A 我的解向量b 我正在尝试求解线性方程A x b 为了做到这一点 我正在使用scipy sparse linalg spsolve method 直到最近 一切都
  • 在 JavaScript 函数的 Django 模板中转义字符串参数

    我有一个 JavaScript 函数 它返回一组对象 return Func id name 例如 我在传递包含引号的字符串时遇到问题 Dr Seuss ABC BOOk 是无效语法 I tried name safe 但无济于事 有什么解
  • cv2.VideoWriter:请求一个元组作为 Size 参数,然后拒绝它

    我正在使用 OpenCV 4 0 和 Python 3 7 创建延时视频 构造 VideoWriter 对象时 文档表示 Size 参数应该是一个元组 当我给它一个元组时 它拒绝它 当我尝试用其他东西替换它时 它不会接受它 因为它说参数不是
  • 张量流中的复杂卷积

    我正在尝试运行一个简单的卷积 但包含复数 r np random random 1 10 10 10 i np random random 1 10 10 10 x tf complex r i conv layer tf layers c

随机推荐