是否有任何文档描述 Keras 中的哪些字符串名称映射到哪些对象?例如,下面我创建了一个嵌入层tf.keras.layers
我可以用'uniform'
映射到tf.keras.initializers.RandomUniform
class.
tf.keras.layers.Embedding(1000, 64, embeddings_initializer='uniform')
但我只是通过查看该用法的示例才知道这一点。我认为受支持的字符串形式已在某处记录,但我似乎找不到此类文档,并且挖掘代码变得过于抽象而难以轻松理解。
版本:TF 1.13.1
TF 中的 keras 实现中没有可用的字符串常量列表(我想,在原始 keras 中也没有)。
For the 初始化器 https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/keras/initializers.py案例'uniform'
字符串被转换为配置,并在该配置上调用结构方法,并提示从初始化器命名空间创建对象(可以在此处找到)def deserialize_keras_object https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/keras/utils/generic_utils.py):
config = {'class_name': str(identifier), 'config': {}}
deserialize_keras_object(
config,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='initializer')
因此,我想不出比以下更好的方法,例如列出所有初始化程序:
import tensorflow as tf
for k, v in tf.keras.initializers.__dict__.items():
if not k[0].isupper() and not k[0] == "_":
print(k)
输出虽然有额外的值,但类似于:
constant
glorot_normal
glorot_uniform
identity
ones
orthogonal
zeros
he_normal
he_uniform
lecun_normal
lecun_uniform
normal
random_normal
random_uniform
uniform
truncated_normal
deserialize
get
serialize
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)