keras中使用Lambda封装使用tf的函数经验记录(包含传入多个参数遇到的坑)

2023-05-16

有很多函数是tf特有的,如果在keras模型中混用这些tf的函数,程序就会报错:提示你不是keras tensor
那么这个时候就需要我们利用Lambda对tf的函数进行封装,最后利用keras.layer.Lambda把结果转成keras tensor即可~

我当时还遇到了一个问题:Lambda函数传两个参数的时候,会报TypeError: __init__() takes 2 positional argument but 3 were given,倒腾了半天,最终解决方法是:传参的时候加中括号[],自定义函数里面通过索引访问传入的列表,比如:x[0],x[1],x[2],…

(这个方法是自己试出来的,因为在网上确实没搜到针对该问题的解决方案,真的是自己半猜半试出来的)

  • 使用前先导包
from keras.layer import Lambda
  • 下面是封装的tf的矩阵维度转换和矩阵乘法的函数
# 自定义封装矩阵维度转换
def my_permute(self, x):
        theta_d = tf.transpose(x, (0, 2, 1))
        return theta_d

# 自定义封装矩阵乘法
def my_matmul(self, x):
    res = tf.matmul(x[0], x[1])
    return res

theta_d = Lambda(self.my_permute)(theta_d)
SelfGuid = Lambda(self.my_matmul)([theta_d, phi_d])
  • 无关的小补充:keras中add的函数传入的参数需要用中括号括起来~
from keras.layer import add
res = add([x,y])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

keras中使用Lambda封装使用tf的函数经验记录(包含传入多个参数遇到的坑) 的相关文章

随机推荐