Keras 自定义指标迭代

2023-11-21

我对 Keras 还很陌生,我正在尝试定义自己的指标。它计算一致性指数,这是回归问题的度量。

def cindex_score(y_true, y_pred):
    sum = 0
    pair = 0    
    for i in range(1, len(y_true)):
        for j in range(0, i):
            if i is not j:
                if(y_true[i] > y_true[j]):
                  pair +=1
                  sum +=  1* (y_pred[i] > y_pred[j]) + 0.5 * (y_pred[i] == y_pred[j])
    if pair is not 0:
        return sum/pair
    else:
        return 0


def baseline_model(hidden_neurons, inputdim):
    model = Sequential()
    model.add(Dense(hidden_neurons, input_dim=inputdim, init='normal', activation='relu'))
    model.add(Dense(hidden_neurons, init='normal', activation='relu'))
    model.add(Dense(1, init='normal')) #output layer

    model.compile(loss='mean_squared_error', optimizer='adam', metrics=[cindex_score])
    return model

def run_model(P_train, Y_train, P_test, model):
    history = model.fit(numpy.array(P_train), numpy.array(Y_train), batch_size=50, nb_epoch=200)
    plotLoss(history)
    return model.predict(P_test)

benchmark_model、run_model 和 cindex_score 函数位于 one.py 中,以下函数位于 Two.py 中,我在其中调用模型,

def experiment():
    hidden_neurons = 250
    dmodel=baseline_model(hidden_neurons, train_pair.shape[1])
    predicted_Y = run_model(train_pair,train_Y, test_pair, dmodel)

但我收到以下错误:“‘Tensor’类型的对象没有 len()”。它也不适用于形状属性。

例如,y_true表示为Tensor("dense_4_target:0", shape=(?, ?), dtype=float32),其形状为Tensor("strided_slice:0", shape=(), dtype=int32)。

您能帮我了解如何在张量对象中进行迭代吗?

Best,


如果您使用起来方便tensorflow,那么您可以尝试使用以下代码:

def cindex_score(y_true, y_pred):

    g = tf.subtract(tf.expand_dims(y_pred, -1), y_pred)
    g = tf.cast(g == 0.0, tf.float32) * 0.5 + tf.cast(g > 0.0, tf.float32)

    f = tf.subtract(tf.expand_dims(y_true, -1), y_true) > 0.0
    f = tf.matrix_band_part(tf.cast(f, tf.float32), -1, 0)

    g = tf.reduce_sum(tf.multiply(g, f))
    f = tf.reduce_sum(f)

    return tf.where(tf.equal(g, 0), 0.0, g/f)

下面是一些代码,可验证这两种方法是否等效:

def _ref(J, K):
    _sum = 0
    _pair = 0
    for _i in range(1, len(J)):
        for _j in range(0, _i):
            if _i is not _j:
                if(J[_i] > J[_j]):
                  _pair +=1
                  _sum +=  1* (K[_i] > K[_j]) + 0.5 * (K[_i] == K[_j])
    return 0 if _pair == 0 else _sum / _pair

def _raw(J, K):

    g = tf.subtract(tf.expand_dims(K, -1), K)
    g = tf.cast(g == 0.0, tf.float32) * 0.5 + tf.cast(g > 0.0, tf.float32)

    f = tf.subtract(tf.expand_dims(J, -1), J) > 0.0
    f = tf.matrix_band_part(tf.cast(f, tf.float32), -1, 0)

    g = tf.reduce_sum(tf.multiply(g, f))
    f = tf.reduce_sum(f)

    return tf.where(tf.equal(g, 0), 0.0, g/f)


for _ in range(100):
    with tf.Session() as sess:
        inputs = [tf.placeholder(dtype=tf.float32),
                  tf.placeholder(dtype=tf.float32)]
        D = np.random.randint(low=10, high=1000)
        data = [np.random.rand(D), np.random.rand(D)]

        r1 = sess.run(_raw(inputs[0], inputs[1]),
                      feed_dict={x: y for x, y in zip(inputs, data)})
        r2 = _ref(data[0], data[1])

        assert np.isclose(r1, r2)

请注意,这仅适用于一维张量(在 keras 中很少出现这种情况)。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras 自定义指标迭代 的相关文章

  • Python setuptools:如何在 setup.py 中添加私有存储库 (gitlab)?

    我上传了 2 个包 它们位于我的 gitlab 存储库中 如果我想使用 pip 将它们安装在我的系统中 这很容易 因为 gitlab 可以帮助您 https docs gitlab com ee user packages pypi rep
  • JavaScript 相当于 Python 的参数化 string.format() 函数

    这是 Python 示例 gt gt gt Coordinates latitude longitude format latitude 37 24N longitude 115 81W Coordinates 37 24N 115 81W
  • 将 numpy 数组写入文本文件的速度

    我需要将一个非常 高 的两列数组写入文本文件 而且速度非常慢 我发现如果我将数组改造成更宽的数组 写入速度会快得多 例如 import time import numpy as np dataMat1 np random rand 1000
  • 在 macOS 中通过 Python 访问进程的压缩 RAM(顶部的 CMPRS)的方法?

    我试图弄清楚如何从 Python 访问任何给定进程占用的实际 RAM 量 我发现 psutil Process PID memory info rss 工作得很好 直到操作系统决定开始压缩某些进程的 RAM 然后 所有的 memory in
  • Mobilenet 与 SSD [关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 Locked 这个问题及其答案是locked help locked posts因为这个问题是题外话 但却具有历史意义 目前不接受新的答案
  • 可以用 Django 制作移动应用程序吗?

    我想知道我是否可以在我的网站上使用 Django 代码 并以某种方式在移动应用程序 Flutter 等框架中使用它 那么是否可以使用我现在拥有的 Django 后端并在移动应用程序中使用它 所以就像models views etc 是的 有
  • 如何将 self 传递给装饰器?

    我该如何通过self key下面进入装饰器 class CacheMix object def init self args kwargs super CacheMix self init args kwargs key func Cons
  • PySide6.1 与 matplotlib 3.4 不兼容

    当我只安装PySide6时 GUI程序运行良好 但是一旦我安装了matplotlib及其依赖包 包括pyqt5 则GUI程序将无法运行并输出以下错误消息 This application failed to start because no
  • Pandas 滚动窗口 Spearman 相关性

    我想使用滚动窗口计算 DataFrame 两列之间的 Spearman 和 或 Pearson 相关性 我努力了df corr df col1 rolling P corr df col2 P为窗口尺寸 但我似乎无法定义该方法 添加meth
  • 动态 __init_subclass__ 方法的参数绑定

    我正在尝试让类装饰器工作 装饰器会添加一个 init subclass 方法到它所应用的类 但是 当该方法动态添加到类中时 第一个参数不会绑定到子类对象 为什么会发生这种情况 举个例子 这是可行的 下面的静态代码是我试图最终得到的示例 cl
  • 在Python中计算内存碎片

    我有一个长时间运行的进程 不断分配和释放对象 尽管正在释放对象 但 RSS 内存使用量会随着时间的推移而增加 如何计算发生了多少碎片 一种可能性是计算 RSS sum of allocations 并将其作为指标 即便如此 我该如何计算分母
  • 乘以行并按单元格值附加到数据框

    考虑以下数据框 df pd DataFrame X a b c d Y a b d e Z a b c d 1 2 1 3 df 我想在 列中附加数字大于 1 的行 并在该行中的数字减 1 df 最好应该 然后看起来像这样 或者它可能看起来
  • Python 惰性迭代器

    我试图了解迭代器表达式如何以及何时被求值 以下似乎是一个懒惰的表达 g i for i in range 1000 if i 3 i 2 然而 这个在构造上失败了 g line strip for line in open xxx r if
  • 如何使用 paramiko 查看(日志)文件传输进度?

    我正在使用 Paramiko 的 SFTPClient 在主机之间传输文件 我希望我的脚本打印文件传输进度 类似于使用 scp 看到的输出 scp my file user host user host password my file 1
  • PyTorch DataLoader 对并行运行的批次使用相同的随机种子

    有一个bug https tanelp github io posts a bug that plagues thousands of open source ml projects 在 PyTorch Numpy 中 当并行加载批次时Da
  • 解析根元素内元素之间的 XML 文本

    我正在尝试用 Python 解析 XML 以下是 XML 结构的示例 a aaaa1 b bbbb b aaaa2 a
  • 是否可以将 pd.Series 分配给无序 pd.DataFrame 中的列而不映射到索引(即不重新排序值)?

    在 Pandas 中创建或分配新列时 我发现了一些意外的行为 当我对 pd DataFrame 进行过滤或排序 从而混合索引 然后从 pd Series 创建新列时 Pandas 会重新排序该系列以映射到 DataFrame 索引 例如 d
  • 处理大文件的最快方法?

    我有多个 3 GB 制表符分隔文件 每个文件中有 2000 万行 所有行都必须独立处理 任何两行之间没有关系 我的问题是 什么会更快 逐行阅读 with open as infile for line in infile 将文件分块读入内存
  • python sklearn中的fit方法

    我问自己关于 sklearn 中拟合方法的各种问题 问题1 当我这样做时 from sklearn decomposition import TruncatedSVD model TruncatedSVD svd 1 model fit X
  • 缓存 Flask-登录 user_loader

    我有这个 login manager user loader def load user id None return User query get id 在我引入 Flask Principal 之前它运行得很好 identity loa

随机推荐

  • 如果值匹配,则将单元格数据连接到另一个数据

    我有两个columns A and B在同一个 Excel 工作表中 我正在尝试如果在Column B两个值匹配 那么它应该复制相关值A在同一行 For e g Table Column A Column B xyz 1 abc 1 pqr
  • 将事件保存到用户的日历

    如何将事件添加到用户的日历中 然后允许用户选择日历等 我有这段有效的代码 但这会将事件添加到用户的默认日历中 如何允许用户更改日历 自定义警报等 我见过其他应用程序打开日历应用程序并预 先填写字段 add to calendar let e
  • 使用 Python 2.x 中的“is”运算符将对象与空元组进行比较

    我看惯了if obj is None 在Python中 我最近遇到if obj is 由于元组是不可变的 这听起来像是 Python 解释器中合理的内部优化 让空元组成为单例 因此允许使用is而不是要求 但这在某个地方得到保证吗 从哪个版本
  • 使用多个键排序时反转特定键

    当使用多个键排序时 如何反转单个键的顺序 例如 vec sort by key k foo k reverse bar k 您可以使用sort by与Ordering reverse代替sort by key use std cmp Ord
  • Rails 3:如何通过javascript触发表单提交?

    我有一个表单 大部分只是作为普通表单提交 所以我不想在 form tag 中设置 remote gt true 选项 然而 在某些情况下 我希望能够有一个 javascript 函数发布表单 就像它是通过 remote gt true 发布
  • ASP.NET Identity - 使用安全标记强制重新登录

    So from ASP NET Identity 的 IUserSecurityStampStore 接口是什么 我们了解到 ASP NET Identity 具有安全标记功能 用于使用户登录 cookie 无效 并强制他们重新登录 在我的
  • 如何将dispatch_data_t转换为NSData?

    这是正确的方法吗 convert const void buffer NULL size t size 0 dispatch data t new data file dispatch data create map data buffer
  • 如何在闪亮页面而不是弹出窗口中渲染 scatter3d

    我正在考虑在我闪亮的应用程序中实现 3D 交互式绘图 到目前为止我一直在使用plotly 然而 plotly 有一个主要缺点 渲染时速度非常慢 我已经完成了检查 尽管涉及大量数据集 但更新的 outplot plot 因此 我希望使用一个名
  • 同时运行延迟作业和 Sidekiq

    我目前使用延迟作业来异步处理作业 我没有创建工人 而是使用 delay方法很多 我想转到 Sidekiq 但我的工作类型太多 无法确保所有工作都是线程安全的 所以我想并行运行 Delayed Job 和 Sidekiq 并一次迁移一种类型的
  • 如何捕获 pg_connect() 函数错误?

    pg connect 以表格式显示错误 而不是以表格式显示错误消息 需要错误消息警报 错误信息警告 pg connect function pg connect 无法连接到 PostgreSQL 服务器 致命 第 41 行 home tes
  • 使用 log4j2 进行公共日志记录

    我正在使用 log4j 1 2 和 commons logging 现在我正在尝试将其升级到log4j2 但是如何使用 log4j2 和 commons logging 来初始化 log4j2 我尝试通过以下方式初始化公共日志记录 它工作正
  • 从 JSON 检索项目时出现“无法将 Newtonsoft.Json.Linq.JObject 转换为 Newtonsoft.Json.Linq.JToken”

    当有以下代码时 var TermSource token Value
  • 代表们快了?

    如何任命一名代表 即NSUserNotificationCenterDelegate快点 以下是关于两个视图控制器之间的委托的一些帮助 Step 1 在您将要删除 将发送数据的 UIViewController 中制定一个协议 protoc
  • 在没有故事板的情况下启动ios项目

    我在使用 xibs 而不是故事板启动 iOS 应用程序时遇到了一些麻烦 问题是我遇到黑屏并且第一个视图控制器没有被调用 添加断点viewDidLoad方法 在应用程序委托标头中 我声明了这一点 property strong nonatom
  • linq 异常:只能从 LINQ to Entities 调用此函数

    我正在尝试获取保存在缓存中的数据 但它在 select new FilterSsrsLog 行上引发异常 例外 此函数只能从 LINQ to Entities 调用 List
  • 自动导入包的顺序和歧义

    JLS 第 7 章 软件包 一个包由许多编译单元组成 第 7 3 节 一个编译单元自动有权访问其包中声明的所有类型并且自动导入预定义包 java lang 中声明的所有公共类型 让我们假设以下代码 package com example p
  • jQuery 日期时间选择器 MVC3

    我的模型中有这个字段 DataType DataType DateTime Required ErrorMessage Expire is required public DateTime Expire get set 在我看来 Html
  • 在 JavaScript 中定义全局对象的独立于实现的版本

    我正在尝试定义globalJavaScript 中的对象在一行中如下所示 var global this global this 上述声明是在全局范围内的 因此在浏览器中this指针是一个别名window目的 假设它是在当前网页上下文中执行
  • 将 key=value 对转换回 Python 字典

    有一个日志文件 其中的文本以空格分隔key value对 每行最初都是从 Python 字典中的数据序列化的 类似于 join f k v r for k v in d items 键始终只是字符串 这些值可以是任何值ast literal
  • Keras 自定义指标迭代

    我对 Keras 还很陌生 我正在尝试定义自己的指标 它计算一致性指数 这是回归问题的度量 def cindex score y true y pred sum 0 pair 0 for i in range 1 len y true fo