如何在张量流中将 2d 张量与 3d 张量相乘?

2024-02-09

In numpy您可以将 2d 数组与 3d 数组相乘,如下例所示:

>>> X = np.random.randn(3,5,4) # [3,5,4]
... W = np.random.randn(5,5) # [5,5]
... out = np.matmul(W, X) # [3,5,4]

据我了解,np.matmul() takes W并沿着第一维度广播它X。但在tensorflow不允许:

>>> _X = tf.constant(X)
... _W = tf.constant(W)
... _out = tf.matmul(_W, _X)

ValueError: Shape must be rank 2 but is rank 3 for 'MatMul_1' (op: 'MatMul') with input shapes: [5,5], [3,5,4].

那么是否有一个等价的东西np.matmul()上面做的tensorflow?最好的做法是什么tensorflow将 2d 张量与 3d 张量相乘?


尝试使用tf.tile https://www.tensorflow.org/api_docs/python/tf/tile匹配乘法之前矩阵的维度。 numpy的自动广播功能在tensorflow中似乎没有实现。你必须手动完成。

W_T = tf.tile(tf.expand_dims(W,0),[3,1,1])

这应该可以解决问题

import numpy as np
import tensorflow as tf

X = np.random.randn(3,4,5)
W = np.random.randn(5,5)

_X = tf.constant(X)
_W = tf.constant(W)
_W_t = tf.tile(tf.expand_dims(_W,0),[3,1,1])

with tf.Session() as sess:
    print(sess.run(tf.matmul(_X,_W_t)))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在张量流中将 2d 张量与 3d 张量相乘? 的相关文章

随机推荐