Tensorflow:ValueError:形状必须为 2 级,但为 3 级

2023-11-27

我是张量流新手,我正在尝试将双向 LSTM 的一些代码从旧版本的张量流更新到最新版本(1.0),但出现此错误:

形状必须为等级 2,但“MatMul_3”(操作:“MatMul”)的等级为 3,输入形状为:[100,?,400]、[400,2]。

错误发生在 pred_mod 上。

    _weights = {
    # Hidden layer weights => 2*n_hidden because of foward + backward cells
        'w_emb' : tf.Variable(0.2 * tf.random_uniform([max_features,FLAGS.embedding_dim], minval=-1.0, maxval=1.0, dtype=tf.float32),name='w_emb',trainable=False),
        'c_emb' : tf.Variable(0.2 * tf.random_uniform([3,FLAGS.embedding_dim],minval=-1.0, maxval=1.0, dtype=tf.float32),name='c_emb',trainable=True),
        't_emb' : tf.Variable(0.2 * tf.random_uniform([tag_voc_size,FLAGS.embedding_dim], minval=-1.0, maxval=1.0, dtype=tf.float32),name='t_emb',trainable=False),
        'hidden_w': tf.Variable(tf.random_normal([FLAGS.embedding_dim, 2*FLAGS.num_hidden])),
        'hidden_c': tf.Variable(tf.random_normal([FLAGS.embedding_dim, 2*FLAGS.num_hidden])),
        'hidden_t': tf.Variable(tf.random_normal([FLAGS.embedding_dim, 2*FLAGS.num_hidden])),
        'out_w': tf.Variable(tf.random_normal([2*FLAGS.num_hidden, FLAGS.num_classes]))}

    _biases = {
         'hidden_b': tf.Variable(tf.random_normal([2*FLAGS.num_hidden])),
         'out_b': tf.Variable(tf.random_normal([FLAGS.num_classes]))}


    #~ input PlaceHolders
    seq_len = tf.placeholder(tf.int64,name="input_lr")
    _W = tf.placeholder(tf.int32,name="input_w")
    _C = tf.placeholder(tf.int32,name="input_c")
    _T = tf.placeholder(tf.int32,name="input_t")
    mask = tf.placeholder("float",name="input_mask")

    # Tensorflow LSTM cell requires 2x n_hidden length (state & cell)
    istate_fw = tf.placeholder("float", shape=[None, 2*FLAGS.num_hidden])
    istate_bw = tf.placeholder("float", shape=[None, 2*FLAGS.num_hidden])
    _Y = tf.placeholder("float", [None, FLAGS.num_classes])

    #~ transfortm into Embeddings
    emb_x = tf.nn.embedding_lookup(_weights['w_emb'],_W)
    emb_c = tf.nn.embedding_lookup(_weights['c_emb'],_C)
    emb_t = tf.nn.embedding_lookup(_weights['t_emb'],_T)

    _X = tf.matmul(emb_x, _weights['hidden_w']) + tf.matmul(emb_c, _weights['hidden_c']) + tf.matmul(emb_t, _weights['hidden_t']) + _biases['hidden_b']

    inputs = tf.split(_X, FLAGS.max_sent_length, axis=0, num=None, name='split')

    lstmcell = tf.contrib.rnn.BasicLSTMCell(FLAGS.num_hidden, forget_bias=1.0, 
    state_is_tuple=False)

    bilstm = tf.contrib.rnn.static_bidirectional_rnn(lstmcell, lstmcell, inputs, 
    sequence_length=seq_len, initial_state_fw=istate_fw, initial_state_bw=istate_bw)


    pred_mod = [tf.matmul(item, _weights['out_w']) + _biases['out_b'] for item in bilstm]

任何帮助表示赞赏。


对于将来遇到此问题的任何人,上面的代码片段不应该使用。

From tf.contrib.rnn.static_bidirectional_rnnv1.1 文档:

Returns:

A tuple (outputs, output_state_fw, output_state_bw)其中:outputs 是长度为 T 的输出列表(每个输入一个),它们是深度连接的前向和后向输出。 output_state_fw 是前向 rnn 的最终状态。 output_state_bw 是后向 rnn 的最终状态。

上面的列表理解需要 LSTM 输出,获取这些输出的正确方法是:

outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(lstmcell, lstmcell, ...)
pred_mod = [tf.matmul(item, _weights['out_w']) + _biases['out_b'] 
            for item in outputs]

这会起作用,因为每个item in outputs有形状[batch_size, 2 * num_hidden]并可以乘以权重tf.matmul().


附加组件:从tensorflow v1.2+开始,推荐使用的函数位于另一个包中:tf.nn.static_bidirectional_rnn。返回的张量是相同的,因此代码没有太大变化:

outputs, _, _ = tf.nn.static_bidirectional_rnn(lstmcell, lstmcell, ...)
pred_mod = [tf.matmul(item, _weights['out_w']) + _biases['out_b'] 
            for item in outputs]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow:ValueError:形状必须为 2 级,但为 3 级 的相关文章

随机推荐

  • 打印包含“word”的行 python

    我只想打印以下输出中包含 Server 的行 Date Sun 16 Dec 2012 20 07 44 GMT Expires 1 Cache Control private max age 0 Content Type text htm
  • Laravel 表单不会 PATCH,只会 POST - 嵌套 RESTfull 控制器、MethodNotAllowedHttpException

    我正在尝试允许users编辑他们的playlist 但是 每当我尝试执行 PATCH 请求时 我都会得到MethodNotAllowedHttpException错误 它正在等待一个帖子 我已经设置了 RESTful 资源控制器 路线 ph
  • 如何在 bash 的 CURL 请求中使用变量?

    Goal 我正在使用 bash CURL 脚本连接到 Cloudflare APIv4 目标是更新 A 记录 我的脚本 Get current public IP current ip curl silent ipecho net plai
  • 如何在android中动态提供地图api密钥

    我的 Android 应用程序中有一个用例 我的应用程序的用户必须提供 API 密钥 以便他们可以使用地图 但我发现我必须在清单文件中提供 API 密钥 我无法在运行时编辑它 有没有其他方法可以动态地将地图 API 密钥提供给谷歌地图 我正
  • iframe 中 url 的基本身份验证

    我需要验证从页面上的 iframe 通过 javascript 创建 发送的请求 身份验证将通过基本的 http 身份验证完成 我试过做 http user password server 但显然由于安全异常 这在 IE 中不可用 http
  • 如何在 IIS 上设置反向代理,以允许 host1.mydomain.com 和 host2.mydomain.com 之间进行跨主机通信?

    我在 host1 mydomain com page from host1 jsp 上有一个页面 在 host2 mydomain com page from host2 html 上有一个 HTML 页面 host1 是 IIS7 Tom
  • 在 Android 4.4 中启用 TLS 1.2

    我使用 Retrofit 和 OkHttp3 来发出请求 我知道在 Android 4 4 中 默认情况下未启用 TLS 1 1 和 TLS 1 2 所以我正在尝试启用它们 但到目前为止我还没有成功 我读到这可能是 android stud
  • 如何移动google地图的中心位置

    我正在使用以下代码在脚本中创建谷歌地图 var mapElement parent mapOptions map marker latLong openMarker parent document getElementsByClassNam
  • Gitlab 端口 8080

    我目前正在尝试在我的私人 Debian 服务器上安装 Gitlab 综合总线 它在端口 80 上运行得很好 问题是我还有一个 Apache 服务器在监听端口 80 所以我正在尝试让 Nginx监听端口 8080 但由于某种原因我得到了 50
  • 为什么多态性在没有指针/引用的情况下不起作用?

    我确实在 StackOverflow 上发现了一些具有类似标题的问题 但是当我阅读答案时 他们关注的是问题的不同部分 这些部分非常具体 例如 STL 容器 有人可以告诉我 为什么必须使用指针 引用来实现多态性吗 我可以理解指针可能会有所帮助
  • 检测用户所在国家/地区的最快方法

    我需要检测用户的国家 地区并按他 她的国家 地区显示网站的语言 土耳其人用土耳其语 其他人用英语 我怎样才能以最快的方式做到这一点 表现对我来说很重要 我在看IPInfoDB 的 API 还有更好的选择吗 我使用的是PHP 对于可能在 20
  • 消息 8114,级别 16,状态 5,第 1 行将数据类型 varchar 转换为数字时出错

    Select CAST de ornum AS numeric 1 as ornum2 from Cpaym as de left outer join Cpaym as de1 on CAST de ornum AS numeric de
  • 毕加索实际上是如何缓存图像的

    我想知道毕加索图书馆到底是如何缓存应用程序内的图像的 我知道它使用 HttpHeaders 来检查天气以从网络获取图像 但是 它缓存图像有时间范围吗 比如一天后使缓存无效之类的 问题是我的项目正在从网络加载大量小图像 有时 新图像会反映在下
  • 预测精度:没有以两个向量作为参数的 MASE

    我正在使用accuracy函数从forecast包 计算精度测量 我使用它来计算拟合时间序列模型的度量 例如 ARIMA 或指数平滑 当我在不同维度和聚合级别上测试不同模型类型时 我使用 Hyndman 等人引入的 MASE 平均绝对比例误
  • ggplot2 的图像文件压缩选项

    是否可以使用压缩图形的文件大小ggsave 我尝试过使用compression lzw 参数 但文件大小保持不变 使用 R studio 98 501 OS X Yosemite My code ggsave Figure1 tiff wi
  • Selenium Phantomjs 浏览器在启动时挂起。我该如何调试它?

    我正在尝试帮助在其他人的设置上运行我的 selenium Python 绑定版本 2 测试 它可以与 Firefox esr 两台机器上 配合使用 也可以与我的机器上最新的 phantomjs 配合使用 它挂在他的机器上 唯一明显的区别是他
  • 如何根据用户输入动态构建并返回 linq 谓词

    在这件事上有点卡住了 基本上我有一个方法 我想返回一个谓词表达式 我可以将其用作Where 条件 我认为我需要做的与此类似 http msdn microsoft com en us library bb882637 aspx但我对我需要做
  • 如何加速嵌套循环?

    我正在 python 中执行一个嵌套循环 如下所示 这是搜索现有金融时间序列并在时间序列中寻找符合某些特征的周期的基本方法 在这种情况下 有两个独立的 大小相等的数组 分别代表 收盘价 即资产的价格 和 交易量 即一段时间内交换的资产数量
  • 如何通过 SendKeys 发送特殊字符?

    我正在尝试在 Selenium2 中填写表格 One input has an autocomplete that I want to close preferably by sending esc after the search ter
  • Tensorflow:ValueError:形状必须为 2 级,但为 3 级

    我是张量流新手 我正在尝试将双向 LSTM 的一些代码从旧版本的张量流更新到最新版本 1 0 但出现此错误 形状必须为等级 2 但 MatMul 3 操作 MatMul 的等级为 3 输入形状为 100 400 400 2 错误发生在 pr