给定一个包含整数(代表类)的未知维度 [?, ?] 的 2D 张量,我想获得一个相同形状的新张量,但用从查找表中获取的浮点数替换值(代表类权重)。
例如:
inputs = [ [1,3,3], [2,4,2] ]
lookup table: {1: 0.2, 2: 0.25, 3: 0.1, 4: 0.45}
output: [ [0.2, 0.1, 0.1], [0.25, 0.45, 0.25] ]
我尝试使用 tf.map_fn 链接两个 lambda 函数,迭代每一行,然后迭代每个元素:
elem_iter = lambda y: unknown_lookup_function(y)
row_iter = lambda x: elem_iter(x)
weights = tf.map_fn(row_iter, inputs, dtype=tf.float32)
但无法找到定义查找函数的正确方法。
关于如何实施这种行为有什么建议吗?是否有我可以使用的本机操作来代替 map_fn ?