如何获取 lambda 层内的批量大小

2024-05-01

我正在尝试实现一个层(通过 lambda 层),它执行以下 numpy 过程:

def func(x, n):
    return np.concatenate((x[:, :n], np.tile(x[:, n:].mean(axis = 0), (x.shape[0], 1))), axis = 1)

我陷入困境,因为我不知道如何获取 x 第一个维度的大小(即批量大小)。后台功能int_shape(x)回报(None, ...).

所以,如果我知道batch_size,相应的Keras过程将是:

def func(x, n):
    return K.concatenate([x[:, :n], K.tile(K.mean(x[:, n:], axis=0), [batch_size, 1])], axis = 1)

正如@pitfall所说,第二个参数K.tile应该是一个张量。 并根据keras后端的文档 https://keras.io/backend/, K.shape返回一个张量并且K.int_shape返回 int 或 None 条目的元组。所以正确的方法是使用K.shape。以下是 MWE:

import keras.backend as K
from keras.layers import Input, Lambda
from keras.models import Model
import numpy as np

batch_size = 8
op_len = ip_len = 10

def func(X):
    return K.tile(K.mean(X, axis=0, keepdims=True), (K.shape(X)[0], 1))

ip = Input((ip_len,))
lbd = Lambda(lambda x:func(x))(ip)

model = Model(ip, lbd)
model.summary()

model.compile('adam', loss='mse')

X = np.random.randn(batch_size*100, ip_len)
Y = np.random.randn(batch_size*100, op_len)
#no parameters to train!
#model.fit(X,Y,batch_size=batch_size)

#prediction
np_result = np.tile(np.mean(X[:batch_size], axis=0, keepdims=True), 
                    (batch_size,1))
pred_result = model.predict(X[:batch_size])
print(np.allclose(np_result, pred_result))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何获取 lambda 层内的批量大小 的相关文章

随机推荐

  • 如何找到某个命令的目录?

    我知道 当您使用 shell 时 唯一可以使用的命令是可以在 PATH 上设置的某些目录中找到的命令 即使我不知道如何查看 PATH 变量上的目录 这是另一个可以回答的好问题 我想知道的是 我来到 shell 并写道 lshw 我想知道 s
  • 形状为 (N,1) 的数组与形状为 (N) 的数组有什么区别?以及两者之间如何转换?

    来自 MATLAB 背景的 Python 新手 我有一个 1 列数组 我想将该列移到 3 列数组的第一列中 如果我有 MATLAB 背景 我会这样做 import numpy as np A np zeros 150 3 three col
  • 重新创建 svn 存储库

    在一次重大服务器故障之后 svn 存储库被破坏 我的工作版本是最新版本 从我的工作版本重新创建 svn 存储库的方法是什么 在新服务器上安装 svn 并尝试我的工作副本之后 svn switch NEW SVN PATH 我收到一个错误 R
  • 在 Idle shell 中导入模块

    我正在尝试学习 python 但在导入模块时遇到问题 我有一个 pyc 文件 我正在尝试将其导入到名为 dfa pyc 的空闲 shell 中 我将该文件放在名为 xyz 的文件夹中 我使用以下命令导航到该文件夹 os chdir User
  • uWSGI 说:“ImportError:没有名为 wsgi 的模块”

    当uWSGI启动时 它会写入 ImportError No module named wsgi 我的 uwsgi xml
  • 数据网格中的主键始终为零

    我们正在VS2012中使用实体框架 DB First 开发WPF应用程序 我们在数据网格视图中遇到问题 我们从数据源中拖动了一个数据网格 这创建了一个绑定到该特定表的数据网格 该表有两列 一列是 TransporterID 它是 PK 是自
  • 在SSL模式下使用apache kafka

    我正在尝试在 SSL 1 way 模式下设置 kafka 我已经阅读了官方文档并成功生成了证书 我将记下两种不同情况的行为 此设置只有一名经纪人和一名动物园管理员 案例 1 经纪人间通信 明文 我的相关条目server properties
  • 将自定义值存储在 EKEvent(iPhone 日历)中

    我的应用程序与设备日历集成 当新项目添加到我的应用程序时 我们会为此项目创建一个日历条目 如果项目被编辑 我们需要更新日历项目 我现在所做的是将 GUID 放入 EKEvent Notes 中 但显然这对用户是可见的 因此我们添加文本 请勿
  • 使用 Selenium 的 Chrome 驱动程序错误:无法发现打开的页面

    运行 Selenium 测试时 我收到与 Chrome 驱动程序相关的错误 错误消息是 无法发现打开的页面 直到昨晚 Selenium 测试都运行良好 问题似乎是在前一天重新启动服务器后开始的 我无法在本地机器上重现此错误 从服务器上的命令
  • Mercurial hook 的操作类似于“changegroup”,但仅在推送时?

    我们已经构建了一个变更集传播机制 但它依赖于捆绑和解除捆绑新变更集 如果我们要使用changegroup钩子 那么它会导致循环行为 因为钩子是运行的在拉 推或解绑期间 http mercurial selenic com wiki Hook
  • 德尔福:idHttp+SSL

    请解释一下如何使用 SSL https 从服务器下载文件 我在互联网上没有找到合适的答案 每个人都说 TIdSSLIOHandlerSocket 但我只有 TIdSSLIOHandlerSocketOpenSSL 如果我使用 TIdSSLI
  • html或css中的倾斜对角线?

    I want to make a Table like this 是否可以添加一个倾斜的对角边框在表中 基于CSS3 线性渐变 http dev w3 org csswg css images 3 linear gradients解决方案
  • 秘密名称不支持特殊字符

    我有一个要求 需要将我的秘密名称存储为 fname lname 但是当我尝试使用下划线时 出现以下错误 为了暂时绕过该错误 我编写了一个实用程序来将下划线转换为连字符 反之亦然 有什么原因不支持下划线等基本特殊字符吗 az keyvault
  • 如何包含来自其他域的一个 php 文件

    我在同一台服务器上有两个域 www domain com www domain com 我有一个index1 php在拳头服务器中 现在我需要包含该文件index2 php驻留在域2中 如何使用 php 代码 包括 要求 不可能在另一台服务
  • 无论如何,为什么要处置一个肯定很快就会被处置的物体呢?

    假设我有一个程序 例如单击按钮 我创建了一个 Graphics 对象 显然我应该处理掉它 例如 using Graphics gr this CreateGraphics 或通过调用 Dispose in the finallytry ca
  • 资源 ID #4 PHP MYSQL

    result mysql query SELECT indvsum sum1 indvsum sum2 FROM SELECT SUM Cash AS sum1 SUM Bank AS sum2 FROM players indvsum e
  • PhysicsFS 是否独立于平台?

    我正在考虑在我的游戏引擎项目中使用PhysicsFS 但我想首先确保它完全独立于平台 这是因为我想在完成 Windows 代码后将我的引擎移植到一些相当不起眼的平台 例如 Wii Homebrew 根据开发者提供的官方规格他们的网站 htt
  • 无法使用php连接到mongodb数据库用户

    我有一个正在运行的 mongodauth true在我的服务器上 如果我登录到我的管理员用户 从管理数据库 则获取数据没有问题 但如果我将第一行替换为 connection new Mongo mongodb mydbadmin email
  • 将开始列和结束列合并为一列

    我已经上下搜索了好几个星期 试图找到解决我的问题的方法 我的问题如下 A 有一个表格 其中包含来自车辆遥测提供商的开始和结束坐标以及日期 我需要将它们合并到一列中 以便我们的报告解决方案能够绘制它们 一些示例数据如下 DECLARE Tbl
  • 如何获取 lambda 层内的批量大小

    我正在尝试实现一个层 通过 lambda 层 它执行以下 numpy 过程 def func x n return np concatenate x n np tile x n mean axis 0 x shape 0 1 axis 1