Keras 自定义损失函数:形状为batch_size (y_true) 的变量

2024-04-11

在 Keras 中实现自定义损失函数时,我需要tf.Variable与我的输入数据的批量大小的形状(y_true, y_pred).

def custom_loss(y_true, y_pred):

    counter = tf.Variable(tf.zeros(K.shape(y_true)[0], dtype=tf.float32))
    ...

但是,这会产生错误:

You must feed a value for placeholder tensor 'dense_17_target' with dtype float and shape [?,?]

如果我将batch_size固定为一个值:

def custom_loss(y_true, y_pred):

    counter = tf.Variable(tf.zeros(batch_size, dtype=tf.float32))
    ...

这样|training_set| % batch_size and |val_set| % batch_size等于零,一切正常。

是否有任何建议,为什么根据输入的形状分配批量大小的变量(y_true and y_pred)不起作用?

SOLUTION

我找到了一个令人满意且有效的解决方案。 我用可能的最大batch_size(在模型构建时指定)初始化了变量,并使用了K.shape(y_true)[0]仅用于切片变量。这样它就完美地工作了。这里是代码:

def custom_loss(y_true, y_pred):
    counter = tf.Variable(tf.zeros(batch_size, dtype=tf.float32))
    ...
    true_counter = counter[:K.shape(y_true)[0]]
    ...

它不起作用,因为K.shape返回一个符号形状,它本身就是一个张量,而不是一个 int 值的元组。要从张量获取值,您必须在会话下对其进行评估。看文档 https://keras.io/backend/#shape为了这。要在评估时间之前获得实际值,请使用K.int_shape: https://keras.io/backend/#int_shape https://keras.io/backend/#int_shape

然而,K.int_shape在这里也不起作用,因为它只是一些静态元数据,通常不会反映当前批量大小,但有一个占位符值None.

您找到的解决方案(控制批量大小并在损失中使用它)确实是一个很好的解决方案。

我认为问题是因为您需要在定义时知道批量大小来构建变量,但只有在会话运行时才会知道它。

如果您像使用张量一样使用它,应该没问题,请参阅此example https://github.com/pierluigiferrari/ssd_keras/blob/master/keras_loss_function/keras_ssd_loss.py#L128.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras 自定义损失函数:形状为batch_size (y_true) 的变量 的相关文章

  • 根据 pandas 中的条件交换列值

    我想按条件重新定位列 如果国家 地区是 日本 我需要将姓氏和名字反向重新定位 df pd DataFrame France Kylian Mbappe Japan Hiroyuki Tajima Japan Shiji Kagawa Eng
  • 为什么方法无法访问类变量?

    我试图理解Python中的变量作用域 除了我不明白为什么类变量不能从其方法访问的部分之外 大多数事情对我来说都很清楚 在下面的例子中mydef1 无法访问a 但如果a可以在全局范围 类定义之外 声明 class MyClass1 a 25
  • 如何有条件地组合两个相同形状的 numpy 数组

    这听起来很简单 但我想我把它想得太复杂了 我想创建一个数组 其元素是从两个形状相同的源数组生成的 具体取决于源数组中哪个元素更大 为了显示 import numpy as np array1 np array 2 3 0 array2 np
  • 如何在算术表达式的结果上添加 SQLAlchemy 标签?

    我如何将这样的东西翻译成 SQLAlchemy select x y as difference 我知道该怎么做 x label foo 但我不确定在哪里放置下面的 label 方法调用 select table c x table c y
  • 返回不包括指定键的字典副本

    我想创建一个函数 返回字典的副本 不包括列表中指定的键 考虑这本词典 my dict keyA 1 keyB 2 keyC 3 致电without keys my dict keyB keyC 应该返回 keyA 1 我想用一行简洁的字典理
  • 无法安装时间模块

    我试过了pip install time and sudo H pip install time 但我不断收到错误 找不到满足要求时间的版本 从 版本 未找到时间匹配的发行版 我正在 PyCharm 中工作 但真正没有意义的是我可以在 Py
  • 如何用xlrd读取公式

    我正在尝试做一个解析器 它读取几个 Excel 文件 我通常需要位于行底部的值 您可以在其中找到所有上部元素的总和 因此 单元格值实际上是 sum 或 A5 0 5 可以说 对于使用 Excel 打开此文件的用户来说 它看起来像一个数字 这
  • 为 PyCharm 中的所有配置设置相同的环境变量

    我有一个与 Celery 和很多不同的工作人员一起的项目 如何避免每次将 PyCharm 中的环境变量复制粘贴到每个运行 调试配置 有什么方法可以在项目设置中设置它们吗 找到解决方案here https stackoverflow com
  • 一起使用 Argparse 和 Json

    我是 Python 初学者 我想知道 Argparse 和 JSON 是否可以一起使用 说 我有变量p q r 我可以将它们添加到 argparse 中 parser add argument p param1 help x variabl
  • 如何使用注释和聚合在 Django 的 ORM 中执行此 GROUP BY 查询

    我真的不知道如何翻译GROUP BY and HAVING到姜戈的QuerySet annotate and QuerySet aggregate 我正在尝试将这个 SQL 查询转换为 ORM 语言 SELECT EXTRACT year
  • 别碰我的女人

    我讨厌的一件事迪斯图尔斯 http docs python org distutils 我猜他是邪恶的人 他这样做了 https github com python cpython blob 300dd552b15825abfe0e367a
  • 如何在python中递归复制目录并覆盖全部?

    我正在尝试复制 home myUser dir1 及其所有内容 及其内容等 home myuser dir2 在Python中 此外 我希望副本覆盖中的所有内容dir2 It looks like distutils dir util co
  • 如何将reportlab与Google应用程序引擎一起使用

    我无法在谷歌应用程序引擎下正确导入reportlab 根据以下guide http blog notdot net 2010 04 Generating PDFs on App Engine Python and introducing M
  • 从 Apache 运行 python 脚本的最简单方法

    我花了很长时间试图弄清楚这一点 我基本上正在尝试开发一个网站 当用户单击特定按钮时 我必须在其中执行 python 脚本 在研究了 Stack Overflow 和 Google 之后 我需要配置 Apache 以便能够运行 CGI 脚本
  • Scikit Learn - K-Means - 肘部 - 标准

    今天我想学习一些关于 K means 的知识 我已经了解该算法并且知道它是如何工作的 现在我正在寻找正确的 k 我发现肘部准则作为检测正确的 k 的方法 但我不明白如何将它与 scikit learn 一起使用 在 scikit learn
  • Pandas DataFrame:如何计算组中第一行和最后一行的差异?

    这是我的熊猫数据框 import pandas as pd import numpy as np data column1 338 519 871 1731 2693 2963 3379 3789 3910 4109 4307 4800 4
  • LSTM 批次与时间步

    我按照 TensorFlow RNN 教程创建了 LSTM 模型 然而 在这个过程中 我对 批次 和 时间步长 之间的差异 如果有的话 感到困惑 并且我希望得到帮助来澄清这个问题 教程代码 见下文 本质上是根据指定数量的步骤创建 批次 wi
  • Django - 缺少 1 个必需的位置参数:'request'

    我收到错误 get indiceComercioVarejista 缺少 1 个必需的位置参数 要求 当尝试访问 get indiceComercioVarejista 方法时 我不知道这是怎么回事 views from django ht
  • 如何将 Pandas Dataframe 中的字符串转换为字符列表或数组?

    我有一个名为的数据框data 其中一列包含字符串 我想从字符串中提取字符 因为我的目标是对它们进行一次性编码并使之可用于分类 包含字符串的列存储在预测因子如下 predictors pd DataFrame data columns Seq
  • 从 Flask 中的 S3 返回 PDF

    我正在尝试在 Flask 应用程序的浏览器中返回 PDF 我使用 AWS S3 来存储文件 并使用 boto3 作为与 S3 交互的 SDK 到目前为止我的代码是 s3 boto3 resource s3 aws access key id

随机推荐