reshape 的输入是一个具有 2 *“batch_size”值的张量,但请求的形状具有“batch_size”

2024-02-05

我想使用带有张量流后端的 Keras 顺序模型制作 RNN。当我实现以下代码时:

batch_size = 8
batch_inputshape = (batch_size,x_train.shape[1],x_train.shape[2])
print(batch_inputshape) #(8, 600, 103)
​
model = Sequential()
model.add(LSTM(103, 
               batch_input_shape = batch_inputshape, 
               return_sequences = True,
              stateful = True))
model.add(Dropout(0.2))
​
model.add(LSTM(50, 
               return_sequences = True,
              stateful = True))
model.add(Dropout(0.2))
​
​
model.add(TimeDistributed(Dense(10)))
model.add(TimeDistributed(Dense(2)))
model.add(Activation('softmax'))
model.compile(loss= ncce, optimizer='adam')    ​
​
print (model.output_shape) #(8, 600, 2)

model.fit(x_train,y_train, batch_size = batch_size,
                           nb_epoch = 1, validation_split=0.25)

我收到以下错误消息:

reshape 的输入是一个有 16 个值的张量,但请求的形状有 8 个

但无论我将batch_size更改为错误,都将遵循以下公式:

重塑的输入是一个张量2 * batch_size值,但要求的形状有batch_size

我看过其他的,但我认为它们对我帮助不大。或者我对答案的理解不够好。

任何帮助将非常感激!

编辑: 根据要求输入和目标的形状:

print(x_train.shape) #(512,600,103)
print(y_train.shape) #(512,600,2)

EDIT 2:

from functools import partial
import keras.backend as K 
from itertools import product
​
def w_categorical_crossentropy(y_true, y_pred, weights):
    # https://github.com/fchollet/keras/issues/2115#issuecomment-274101310 #
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
    y_pred_max_mat = K.cast(K.equal(y_pred, y_pred_max), K.floatx())
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += (weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
    return K.categorical_crossentropy(y_pred, y_true) * final_mask
​
w_array = np.ones((2,2))
w_array[1, 0] = 100
​
​
print(w_array)
ncce = partial(w_categorical_crossentropy, weights=w_array)
ncce.__name__ ='w_categorical_crossentropy

编辑 3:更新

在@Nassim Ben 的帮助下,他发现问题出在损失函数中。他发布了带有常规损失函数的代码,然后它就可以正常工作了。然而,对于自定义损失函数,该代码不起作用。正如这个问题的任何读者都可以看到的,我在上面发布了我的服装损失函数,并且存在问题。目前我还不知道为什么会出现这个错误,但这就是当前状态。


编辑 : 这段代码对我有用,为了简单起见我只改变了损失。

import keras
from keras.layers import *
from keras.models import Sequential
from keras.objectives import *
import numpy as np

x_train = np.random.random((512,600, 103))
y_train = np.random.random((512,600,2))
batch_size = 8
batch_inputshape = (batch_size,x_train.shape[1],x_train.shape[2]) 
print(batch_inputshape) #(8, 600, 103)

model = Sequential()
model.add(LSTM(103,
           batch_input_shape = batch_inputshape,
           return_sequences = True,
          stateful = True))
model.add(Dropout(0.2))
model.add(LSTM(50,
           return_sequences = True,
          stateful = True))
model.add(Dropout(0.2))


model.add(TimeDistributed(Dense(10)))
model.add(TimeDistributed(Dense(2)))
model.add(Activation('softmax'))
model.compile(loss= "mse", optimizer='adam')

print (model.output_shape) #(8, 600, 2)

model.fit(x_train,y_train, batch_size = batch_size,
                       nb_epoch = 1, validation_split=0.25)

EDIT 2:

所以错误来自损失函数。在您从 github 复制的 ncce 损失代码中,它们的输出形状为 (batch,10)。您的输出形状为 (batch, 600, 2)。这是我对该函数的编辑:

def w_categorical_crossentropy(y_true, y_pred, weights):
# https://github.com/fchollet/keras/issues/2115#issuecomment-274101310 #
    nb_cl = len(weights)
    # Create a mask with zeroes
    final_mask = K.zeros_like(y_pred[:,:,0])
    # get the maximum probability value for every output (shape = (batch,600,1))
    y_pred_max = K.max(y_pred, axis=2, keepdims=True)
    # Get the actual predictions for every output (shape = (batch,600,2))
    # This K.equal uses broadcasting, we compare two tensors of different sizes but it works (magic)
    y_pred_max_mat = K.equal(y_pred, y_pred_max)
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        # Create the mask of weights to apply to the result of the cat_crossentropy
        final_mask += (weights[c_t, c_p] * K.cast(y_pred_max_mat[:,:, c_p], K.floatx()) * y_true[:,:, c_t])
    return K.categorical_crossentropy(y_pred, y_true) * final_mask

w_array = np.ones((2,2))
w_array[1, 0] = 100

正如你所看到的,由于你的特殊形状,我刚刚修改了索引玩法。 掩模必须成形(批量,600)。 最大值必须在第三维上完成,因为那里存在您想要输出的概率。 由于张量的形状,构建最大值的矩阵乘法也需要更新。

这应该有效。

如果您需要更详细的解释,请随时询问:-)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

reshape 的输入是一个具有 2 *“batch_size”值的张量,但请求的形状具有“batch_size” 的相关文章

随机推荐

  • play 2.4 中的插件、依赖项、模块和子项目有什么区别?

    我是 playframework 的新手 刚刚学习 我对依赖项 模块 插件和子项目有点困惑 它们有何不同 这是我的理解 可能是错的 依赖项 是播放应用程序运行所需的所有库 子项目 是另一个父应用程序内的播放应用程序 不确定 插件 和 模块
  • Rescue_from 不会从视图或助手中拯救 Timeout::Error

    我的应用程序控制器中有一个 around filter 用于将所有操作封装在超时块中 以便操作在达到 30 秒 Heroku 限制之前失败 我还有一个rescue from Timeout Error 来彻底挽救这些超时 不幸的是 resc
  • 允许在 React Native 中关注 TextInput 时点击/按下项目

    我有一个TextInput其功能是对某些结果进行搜索 过滤 结果显示在ScrollView 我遇到的问题是 虽然国家focus on the TextInput 用户必须点击两次才能选择该项目 这是一个TouchableOpacity 在里
  • Laravel 中的一次性自定义 cron 计划

    我想在用户在表单中输入的自定义日期和时间运行一次 cron 做这个的最好方式是什么 我发现可以像这样在 laravel 中安排自定义 cron gt cron 按照自定义 Cron 计划运行任务 但我找不到时间格式 的含义 或者更简单 可以
  • Numpy-convertible 类可以从序列内部正确转换为 ndarray?

    The array 方法允许自定义类型自动转换为 numpy 例如 gt gt gt class Convertible def array self return np zeros 7 gt gt gt np array Converti
  • 改变spacy NER中的beam_width

    我想将 nlp entity cfg beam width 默认情况下为 1 更改为 3 我尝试了 nlp entity cfg update beam width 3 但看起来 nlp 的东西在这次更改后被破坏了 如果我执行nlp str
  • 如何为 IP 地址签署 SSL 证书? [关闭]

    Closed 这个问题是与编程或软件开发无关 help closed questions 目前不接受答案 我有一台服务器 在我家里的一台机器上仅托管我网站的节点后端 我正在使用express 我想从另一个后端调用该服务器 我们正在尝试构建一
  • Java(Android Studio)libgdx中的代码,如何计算弹丸[关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 Java Android Studio libgdx中的代码 当您单击 触摸屏幕时 如何计算圆形 如球 的射弹以及如何显示它 就像打篮
  • 电子邮件模板位置绝对吗?

    使用安全吗position absolute在电子邮件模板中 取决于您的用户使用的邮件客户端 例如 Outlook 处理位置 绝对好 而 Thunderbird 则不然 我会尝试将您的邮件模板设计得尽可能 正常 例如 表格有很大帮助 恶心
  • NamedScope 和垃圾收集

    这个问题首先是在 Ninject Google Group 中提出的 但我现在发现 Stackoverflow 似乎更活跃 我使用 NamedScopeExtension 将相同的 ViewModel 注入到 View 和 Presente
  • 具有多个间隔的序列

    seq只能使用单个值by范围 有没有办法矢量化by 即使用多个间隔 像这样的事情 seq 1 10 by c 1 2 会回来c 1 2 4 5 7 8 10 现在 可以使用例如来做到这一点seq 1 10 by 1 c T T F 因为这是
  • 升级到 NPM 5.4.1 后,在不删除 node_modules 的情况下无法运行“npm install”

    我已将 NPM 从 5 3 0 升级到 5 4 1 之后 该命令似乎npm 安装仅当我删除后才有效节点模块 当我尝试重新运行安装时 收到以下错误消息 之后 如果我再次删除节点模块 命令运行安装作品 once PS C source webs
  • iOS 通讯软件 SDK

    我正在寻找在我当前的 iPhone android 应用程序中实现应用程序内消息程序 要求是它必须免费 实时并提供推送通知 我已经研究过自己创建系统 但注意到很多应用程序实现了非常相似的概念 所以我认为 SDK 包装器可用 以下是使用要实现
  • 我可以在 Web 配置中设置应用程序池吗?

    我使用 IIS 7 0 我想知道如何在 Web 配置文件中设置应用程序池 我认为这是不可能的 为您的应用程序选择应用程序池是一个 设置 问题 而不是一个 配置 问题
  • jQuery Mobile 导航栏中每行超过 5 个项目

    我未能成功地寻找一个变量来更改导航栏中单行中的最大项目数 我刚刚开始使用 jQuery Mobile 尝试创建一个包含大约 7 个单字母项目的导航栏 当存在超过 5 个项目时 导航栏会自动换行 这对于我的项目来说是不可取的 谁能指出我的代码
  • 简单的 Java Hangman 分配

    我被困在一个类的 Java 作业中 我们需要制作一个 Hangman 游戏 但是一个非常基本的游戏 这是 Java 类的介绍 基本上 我有一个由某人输入的单词 另一个人必须猜测该单词 但他们看不到该单词 因此它会像这样显示 如果该单词是 a
  • 如何在我们的应用程序中给出 zend 库路径? (在 zend 框架 2.3 中)

    我已经在本地计算机上安装了 zend 骨架应用程序 我正在ubuntu上工作 我是手动安装的 没有使用composer 我已经在我的 httpd conf 中给出了 ZF2 PATH zend 库路径 如下所示
  • 如何比较 Django 中的两个日期时间字段

    我用过datetime datetime now 用于存储datefield在我的模型中 另存为2016 06 27 15 21 17 248951 05 30 现在我想比较一下datefield与datetime从前端获取的值 例如Thu
  • 是否可以使用 NumPy 重现 MATLAB 的 randn() ?

    我想知道是否有可能准确地重现整个序列randn MATLAB 与 NumPy 的结合 我用 Python Numpy 编写了自己的例程 它给我的结果与其他人编写的 MATLAB 代码有些不同 而且由于随机抽取不同 我很难找出它的来源 我已经
  • reshape 的输入是一个具有 2 *“batch_size”值的张量,但请求的形状具有“batch_size”

    我想使用带有张量流后端的 Keras 顺序模型制作 RNN 当我实现以下代码时 batch size 8 batch inputshape batch size x train shape 1 x train shape 2 print b