Tensorflow.Keras:自定义约束不起作用

2024-03-27

我正在尝试实现权重正交约束所示here https://towardsdatascience.com/build-the-right-autoencoder-tune-and-optimize-using-pca-principles-part-ii-24b9cca69bd6,在第 2.0 节中。当我尝试在 Keras 密集层上使用它时,会出现值错误。

当尝试实现同一篇文章第 3.0 部分中的自定义不相关特征约束时,也会发生这种情况。

import tensorflow as tf
import numpy as np

class WeightsOrthogonalityConstraint(tf.keras.constraints.Constraint):
    def __init__(self, encoding_dim, weightage = 1.0, axis = 0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage
        self.axis = axis

    def weights_orthogonality(self, w):
        if(self.axis==1):
            w = tf.keras.backend.transpose(w)
        if(self.encoding_dim > 1):
            m = tf.keras.backend.dot(tf.keras.backend.transpose(w), w) - tf.keras.backend.eye(self.encoding_dim)
            return self.weightage * tf.keras.backend.sqrt(tf.keras.backend.sum(tf.keras.backend.square(m)))
        else:
            m = tf.keras.backend.sum(w ** 2) - 1.
            return m

    def __call__(self, w):
        return self.weights_orthogonality(w)

rand_samples = np.random.rand(16, 4)
dummy_ds = tf.data.Dataset.from_tensor_slices((rand_samples, rand_samples)).shuffle(16).batch(16)

encoder = tf.keras.layers.Dense(2, "relu", input_shape=(4,), kernel_regularizer=WeightsOrthogonalityConstraint(2))
decoder = tf.keras.layers.Dense(4, "relu")

autoencoder = tf.keras.models.Sequential()
autoencoder.add(encoder)
autoencoder.add(decoder)

autoencoder.compile(metrics=['accuracy'],
                    loss='mean_squared_error',
                    optimizer='sgd')

autoencoder.summary()

autoencoder.fit(dummy_ds, epochs=1)

如果我停止使用约束,则不会出现错误,但是在使用时,会引发下一个错误:

2019-09-07 14:20:25.962610: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library nvcuda.dll
2019-09-07 14:20:26.997957: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1640] Found device 0 with properties: 
name: GeForce GTX 1060 major: 6 minor: 1 memoryClockRate(GHz): 1.733
pciBusID: 0000:01:00.0
2019-09-07 14:20:27.043016: I tensorflow/stream_executor/platform/default/dlopen_checker_stub.cc:25] GPU libraries are statically linked, skip dlopen check.        
2019-09-07 14:20:27.050749: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1763] Adding visible gpu devices: 0
2019-09-07 14:20:27.081369: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
2019-09-07 14:20:27.113598: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1640] Found device 0 with properties: 
name: GeForce GTX 1060 major: 6 minor: 1 memoryClockRate(GHz): 1.733
pciBusID: 0000:01:00.0
2019-09-07 14:20:27.144194: I tensorflow/stream_executor/platform/default/dlopen_checker_stub.cc:25] GPU libraries are statically linked, skip dlopen check.        
2019-09-07 14:20:27.151802: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1763] Adding visible gpu devices: 0
2019-09-07 14:20:27.800616: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1181] Device interconnect StreamExecutor with strength 1 edge matrix:
2019-09-07 14:20:27.817323: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1187]      0 
2019-09-07 14:20:27.840635: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1200] 0:   N 
2019-09-07 14:20:27.848536: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1326] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 4712 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1060, pci bus id: 0000:01:00.0, compute capability: 6.1)
Traceback (most recent call last):
  File "c:\Users\whitm\.vscode\extensions\ms-python.python-2019.9.34911\pythonFiles\ptvsd_launcher.py", line 43, in <module>
    main(ptvsdArgs)
  File "c:\Users\whitm\.vscode\extensions\ms-python.python-2019.9.34911\pythonFiles\lib\python\ptvsd\__main__.py", line 432, in main
    run()
  File "c:\Users\whitm\.vscode\extensions\ms-python.python-2019.9.34911\pythonFiles\lib\python\ptvsd\__main__.py", line 316, in run_file
    runpy.run_path(target, run_name='__main__')
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "c:\Users\whitm\Desktop\CodeProjects\ForestClassifier-DEC\Test.py", line 35, in <module>
    optimizer='sgd')
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\training\tracking\base.py", line 458, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\keras\engine\training.py", line 337, in compile
    self._compile_weights_loss_and_weighted_metrics()
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\training\tracking\base.py", line 458, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1494, in _compile_weights_loss_and_weighted_metrics
    self.total_loss = self._prepare_total_loss(masks)
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1601, in _prepare_total_loss
    custom_losses = self.get_losses_for(None) + self.get_losses_for(
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1209, in get_losses_for
    return [l for l in self.losses if l._unconditional_loss]
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 835, in losses
    return collected_losses + self._gather_children_attribute('losses')
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 2129, in _gather_children_attribute      
    getattr(layer, attribute) for layer in nested_layers))
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 2129, in <genexpr>
    getattr(layer, attribute) for layer in nested_layers))
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 832, in losses
    loss_tensor = regularizer()
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 907, in _tag_unconditional
    loss = loss()
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1659, in _loss_for_variable
    regularization = regularizer(v)
  File "c:\Users\whitm\Desktop\CodeProjects\ForestClassifier-DEC\Test.py", line 21, in __call__
    return self.weights_orthogonality(w)
  File "c:\Users\whitm\Desktop\CodeProjects\ForestClassifier-DEC\Test.py", line 14, in weights_orthogonality
    m = tf.keras.backend.dot(tf.keras.backend.transpose(w), w) - tf.keras.backend.eye(self.encoding_dim)
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\keras\backend.py", line 1310, in eye
    return variable(linalg_ops.eye(size, dtype=tf_dtype), dtype, name)
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\keras\backend.py", line 785, in variable
    constraint=constraint)
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\ops\variables.py", line 264, in __call__
    return super(VariableMetaclass, cls).__call__(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\ops\resource_variable_ops.py", line 464, in __init__
    shape=shape)
  File "C:\ProgramData\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\ops\resource_variable_ops.py", line 550, in _init_from_args
    raise ValueError("Tensor-typed variable initializers must either be "
ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable (e.g., `tf.Variable(lambda : tf.truncated_normal([10, 40]))`) when building functions. Please file a feature request if this restriction inconveniences you.

提前致谢!

PD: Here https://colab.research.google.com/drive/19jTi5jRaDKFey0QZ1FQgOUqiz3UqS3xm是一个显示错误的 Colab Notebook

PD2:我设法找到导致问题的线路,是这个:

m = tf.keras.backend.dot(tf.keras.backend.transpose(w), w) - tf.keras.backend.eye(self.encoding_dim)

特别是 keras 后端 eye() 函数导致了问题


我设法解决这个问题:

导致错误的函数是第 14 行的 tf.keras.backed.eye()。我在那里读到,该函数的 keras 后端中的实现使用 numpy 数组作为单位矩阵,但张量流和其他后端已经有了它们的实现使用张量来实现此函数。 tf2.0 上缺少张量导致错误,只需将 tf.keras.backed.eye() 更改为 tf.eye() 即可解决问题。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow.Keras:自定义约束不起作用 的相关文章

  • 在 Tensorflow 中使用队列将数据馈送到网络时分开验证和训练图

    我一直在做大量关于如何使用队列将数据正确输入网络的研究 但是 我在互联网上找不到任何解决方案 目前我的代码能够读取训练数据并执行训练 但无需验证和测试 这里有一些重要的行构成了我的代码 images volumes utils inputs
  • 以编程方式设置 mosquitto 中的访问控制限制

    我正在开发一个将使用 mqtt 的应用程序 我将使用 python 库 我一直倾向于使用 mosquitto 但找不到以编程方式为其设置访问控制限制的方法 我正在编写的应用程序需要能够区分用户 并且只允许他们订阅某些主题 当前的解决方案看起
  • 如何让电脑看起来像是在打字? [复制]

    这个问题在这里已经有答案了 我希望它看起来像是计算机正在尝试向用户输入信息 我尝试了一些代码 但是当我运行它时 它只是一次打印所有内容 即使我一次打印 1 个 A Random sentence for x in A time sleep
  • 如何在 Python 3.2 程序中优雅地包含 Python 3.3 from None 异常语法?

    我正在尝试重新引发异常 以便为用户提供有关实际错误的更好信息 Python 3 3 包括PEP 409 http www python org dev peps pep 0409 它添加了raise NewException from No
  • 使用 python 从 XSD 文件创建特定的 XML 文件

    我有一个现有的 xsd 架构 并且需要创建 希望使用 Python 带有一些特定输入的 XML 文件 最好的方法是什么 我尝试了 Element Tree 和 xmlschema 但我无法判断它们是否允许从已知的 XSD 架构开始生成 XM
  • dask groupby 不合并分区

    我有一组数据 我想要对其进行一些简单的 groupby count 操作 但我似乎无法使用 dask 来完成此操作 我很可能不理解 dask 中执行 groupby reduce 的方式 特别是当索引位于分组键中时 所以我将用玩具数据来说明
  • PyCharm 虚拟环境和 Anaconda 环境有什么区别?

    当我在 PyCharm 中创建新项目时 它会创建一个新的虚拟环境 我读到 当我执行Python脚本时 它们是使用此环境中的解释器而不是系统环境来执行的 因此 如果我需要安装一些软件包 我只能将它们安装在这个环境中 而不是在系统环境中 这很酷
  • pandas dataframe 视图与复制,我如何区分?

    有什么区别 pandas df loc col a col b and df loc col a col b 下面的链接没有提到后者 尽管它有效 两者都拉视图吗 第一个拉取视图 第二个拉取副本吗 http pandas pydata org
  • Python相对导入导致语法错误:无效语法

    我正在尝试安装这个很棒的 python 模块Python Chrono http oss codepoet no python chrono wiki Home我的 python 环境 但至少在 python 2 4 3 和 2 6 6 中
  • 单词和表情符号计数器

    我有一个包含 clear message 列的数据框 并且创建了一个用于计算每行中所有单词的列 history word count history clear message apply lambda x Counter x split
  • 适用于 Web 照片库的正确 NoSQL 数据架构

    我正在寻找为照片库的 NoSQL 存储构建合适的数据结构 在我的网络应用程序中 一张照片可以是一个或多个相册的一部分 我有使用 MySQL 的经验 但几乎没有使用键值存储的经验 使用 MySQL 我将设置 3 个表 如下所示 photos
  • GitPython 并向 Git 对象发送命令

    GitPython http gitorious org git python是一种从 python 与 git 交互的方式 我正在尝试访问基本的 git 命令 例如git commit m message 从此模块中 根据this htt
  • 将文件转换为 Ascii 抛出异常

    后果我之前的问题 https stackoverflow com questions 31742609 how to strip the leading unciode characters from a file 31742694 nor
  • 后视模式无效

    为什么这个正则表达式在 Python 中有效 但在 Ruby 中无效
  • 无需重新计算即可获取字典键哈希

    有没有办法从字典中提取现有的密钥哈希 而无需再次重新计算它们 暴露它们并因此通过哈希而不是密钥访问字典会有什么风险 我认为 Python 的字典对象没有任何公共 API 可以让您查看存储其对象的哈希值 您无法在 Python 代码中直接通过
  • 计算具有不均匀间隔点的 3D 梯度

    我目前有一个由几百万个不均匀间隔的粒子组成的体积 每个粒子都有一个属性 对于那些好奇的人来说是潜力 我想计算其局部力 加速度 np gradient 仅适用于均匀分布的数据 我在这里查看 numpy 中的二阶梯度 https stackov
  • 如何获得 GTK 中的默认颜色?

    Context 在 GTK 3 中 人们可以设置自己的主题 甚至默认主题 Adwaita 也提供两种变体 浅色和深色 当我编写自己的小部件 用Python 时 我需要获取这些颜色以避免在黑色上绘制黑色或在白色上绘制白色 Question 如
  • Scipy:在对整个表面进行集成时加快集成速度?

    I have a probability density function pdf f x y And to get its cumulative distribution function cdf F x y at point x y y
  • pandas 数据框中的 count 和 countif

    我有一个 DF 如下所示 trainee course completed days overdue Ava ABC Yes 0 Bob ABC Yes 1 Charlie DEF No 10 David DEF Yes 0 Emily D
  • 如何在Python中的滚动平均计算中忽略NaN

    对于时间序列销售预测任务 我想创建一个代表过去 3 天平均销售额的功能 当我想预测未来几天的销售额时遇到问题 因为这些数据点没有销售数据 NaN 值 Pandas 提供rolling mean 但当窗口中的任何数据点为 NaN 时 该函数会

随机推荐

  • 为什么回调在 Ruby on Rails 中使用符号

    我很难理解何时以及何时不应该在 Rails 中使用符号 我知道符号与没有许多方法的字符串并没有太大不同 我还知道这些符号是很好的键 因为同名的符号在内存中占据一个地址 我很难理解为什么 Rails 决定在某些情况下使用符号 如果我有回调 b
  • 主机名未使用 Winsock 转换为 IP 地址

    getaddrinfo 不会将主机名转换为 IP 地址 因此不会connect 到服务器 我的实现有问题吗 编译时没有警告消息 这个函数调用的是connect正确的 connect client result gt ai addr resu
  • 在 Python 中编写仅附加 gzip 日志文件

    我正在构建一项服务 在其中记录来自多个源的纯文本格式日志 每个源一个文件 我不打算轮换这些日志 因为它们必须永远存在 为了使这些永远存在的文件更小 我希望我可以在飞行中对它们进行 gzip 压缩 由于它们是日志数据 因此文件压缩得很好 在
  • 当对象包含 ng-repetate 时,如何使用 angularFire 保存 Firebase 对象 $asArray()

    我最近从 angularfire 0 6 切换到 0 8 0 我在保存包含数组本身的列表项时遇到问题 我的对象account看起来像这样 JQruasomekeys0nrXxH created 2014 03 23T22 00 10 176
  • Python 与格式 '%Y-%m-%dT%H:%M:%S%Z.%f' 不匹配

    我尝试在Python中将字符串转换为日期时间对象 但我找不到我的格式有任何问题 Y m dT H M S Z f import datetime datetime datetime strptime 2019 11 19T17 22 23
  • 使用 getFilesDir() 时应用程序上下文返回 null

    我不知道为什么会发生这种情况 当我检查 DDMS 时也没有文件目录 我正在尝试在我的应用程序子类中访问此文件夹 知道为什么会发生这种情况吗 我需要应用程序上下文是全局的 这样我就可以在不扩展 Activity 的类上使用 package m
  • Selenium-Webdriver:找到元素后获取属性

    我对自动化的东西还很陌生 所以这听起来像是一个愚蠢的问题 在发布问题之前 我确实用谷歌搜索了它 不管怎样 问题就在这里 我正在 Android 设备上进行自动化测试 其中一项测试是验证某个项目是否已被标记为 收藏夹 页面代码片段为 li c
  • Android Studio 2.3 错误:无法加载类“com.google.common.collect.ImmutableSet”

    大家 突然 当我打开现有项目时 出现错误 错误 无法加载类 com google common collect ImmutableSet 导致此意外错误的可能原因包括 格拉德尔的 依赖项缓存可能已损坏 这有时会在网络连接后发生 连接超时 重
  • 创建基类对象的向量并在其中存储派生类对象

    我正在尝试创建一个员工数据库 员工向量 有 3 种类型的员工 即 Employees 是基类 Manager Engg 和 Scientist 是派生类 每个员工都有名字和姓氏 除了名字之外 这 3 种类型的员工中的每一种都有独特的统计数据
  • javascript date.utc 问题

    我正在尝试使用 javascript 比较 2 个日期 月末 1 个 月初 1 个 我需要以秒为单位比较这两个日期 因此我使用 Date UTC javascript 函数 这是代码 var d Date UTC 2010 5 31 23
  • 实体框架中推荐的身份生成方法是什么?

    我对 StoreGeneratePattern 的最高效的方式感兴趣 过去我习惯让数据库为我生成ID 但我想知道设置是否有任何优势 StoreGeneratedPattern None 代替 StoreGeneratedPattern Id
  • Demean R 数据框

    我想贬低 R 中的多列data frame 使用来自的示例这个问题 https stats stackexchange com questions 46978 fixed effects using demeaned data why di
  • android maven插件在Eclipse中没有获取ANDROID_HOME环境变量

    我正在开发一个 Android 应用程序项目 它是一个 Maven 项目 当我尝试作为 maven install 运行时 这就是我得到的 无法在项目 android client 上执行目标 com jayway maven plugin
  • 如果给定空白正则表达式,则 regex_replace 中的 C++ Mac OS 无限循环

    执行后 std regex replace the string std regex doesn t matter 我的 Mac 将无限期挂起 我是 xcode 新手 但我认为我正确使用它 我在调试程序时点击 暂停 发现最后执行的代码位于正
  • 无法通过Java删除目录

    在我的应用程序中 我编写了从驱动器中删除目录的代码 但是当我检查文件的删除功能时 它不会删除该文件 我写过一些这样的东西 Code to delete the directory if it exists File directory ne
  • javaFX 表视图中的错误

    I make TableView在 javaFX 中包含两个TableColumns TableView Span 的宽度大于所有的宽度TableColumn 但这不是问题 我不明白的是 当我单击包含数据的行外部区域和列外部区域 红色区域
  • 在哪里可以找到已实施的耐心差异?

    这个网站上有很好的答案 Bram Cohen 的耐心 diff 在 bazaar 中作为默认 diff 和 git diff 的一个选项找到 但我发现很难找到一个独立的独立程序来实现这个特定的 diff 算法 例如 我想将 Patient
  • 根据列表中的值将列添加到数据框

    我有一个如下所示的数据框 df lt data frame A c a b c d e f g h i B c 1 1 1 2 2 2 3 3 3 C c 0 1 0 2 0 4 0 1 0 5 0 7 0 1 0 2 0 5 gt df
  • PHP 发送邮件表单到多个电子邮件地址

    我对 PHP 非常陌生 正在联系页面上使用基本模板 发送邮件 表单 当单击 提交 按钮时 要求我将电子邮件发送到多个电子邮件地址 我已经四处搜寻 但还没有找到我需要的东西 我需要在下面的表单中添加什么代码才能将其发送到多个电子邮件地址
  • Tensorflow.Keras:自定义约束不起作用

    我正在尝试实现权重正交约束所示here https towardsdatascience com build the right autoencoder tune and optimize using pca principles part