我正在使用 Tensorflow 2.0 并面临以下情况:
@tf.function
def my_fn(items):
.... #do stuff
return
如果 items 是张量的字典,例如:
item1 = tf.zeros([1, 1])
item2 = tf.zeros(1)
items = {"item1": item1, "item2": item2}
有没有办法使用 tf.function 的 input_signature 参数,以便我可以强制 tf2 避免在 item1 为示例时创建多个图形tf.zeros([2,1])
?
输入签名必须是一个列表,但列表中的元素可以是字典或张量规格列表。对于你的情况我会尝试:(name
属性是可选的)
signature_dict = { "item1": tf.TensorSpec(shape=[2], dtype=tf.int32, name="item1"),
"item2": tf.TensorSpec(shape=[], dtype=tf.int32, name="item2") }
# don't forget the brackets around the 'signature_dict'
@tf.function(input_signature = [signature_dict])
def my_fn(items):
.... # do stuff
return
# calling the TensorFlow function
my_fun(items)
但是,如果您想调用由my_fn
,你必须解压字典。您还必须提供name
属性在tf.TensorSpec
.
# creating a concrete function with an input signature as before but without
# brackets and with mandatory 'name' attributes in the TensorSpecs
my_concrete_fn = my_fn.get_concrete_function(signature_dict)
# calling the concrete function with the unpacking operator
my_concrete_fn(**items)
这很烦人,但应该在 TensorFlow 2.3 中得到解决。 (参见 TF 指南“具体函数”的末尾)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)