给定一个类实例列表A
, [A() for _ in range(5)]
,我想随机选择其中一个(示例请参见下面的代码)
class A:
def __init__(self, a):
self.a = a
def __call__(self):
return self.a
def f():
a_list = [A(i) for i in range(5)]
a = a_list[random.randint(0, 5)]()
return a
f()
有没有什么装饰的方法f
with @tf.function
不改变什么f
执行并且不调用所有项目a_list
?
注意直接装饰f
with @tf.function
不对上述代码进行任何其他更改是不可行的,因为它将始终返回相同的结果。另外,我知道这可以通过调用中的所有元素来实现a_list
首先,然后使用索引它们tf.gather_nd
。但是如果调用类型的对象,这会产生大量的开销A
涉及深度神经网络。
我现在正在做同样的事情。这是我到目前为止所得到的。如果有人知道更好的方法,我也有兴趣听听。当我在昂贵的调用上运行它时,它比计算并返回所有值要快。
@tf.function
def f2():
a_list = [A(i) for i in range(5)]
idx = tf.cast(tf.random.uniform(shape=[], maxval=4), tf.int32)
return tf.switch_case(idx, a_list)
为了进行速度比较,我使用了昂贵的矩阵代数的调用方法。然后考虑一个调用每个函数的替代函数:
@tf.function
def f3():
a_list = [A(i) for i in range(40)]
results = [a() for a in a_list]
return results
运行具有 40 个元素的 f2:0.42643 秒
运行具有 40 个元素的 f3:14.9153 秒
因此,仅选择一个分支即可实现 40 倍的预期加速,这看起来是正确的。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)