我正在运行一个具有非常大的词嵌入(> 2M 词)的模型。当我使用 tf.embedding_lookup 时,它需要一个很大的矩阵。当我运行时,我随后出现了 GPU 内存错误。如果我减小嵌入的大小,一切都会正常。
有没有办法处理更大的嵌入?
推荐的方法是使用分区器 https://www.tensorflow.org/versions/r0.12/api_docs/python/state_ops/variable_partitioners_for_sharding将这个大张量分成几个部分:
embedding = tf.get_variable("embedding", [1000000000, 20],
partitioner=tf.fixed_size_partitioner(3))
这会将张量沿 0 轴分成 3 个分片,但程序的其余部分会将其视为普通张量。最大的好处是使用分区器参数服务器复制, 像这样:
with tf.device(tf.train.replica_device_setter(ps_tasks=3)):
embedding = tf.get_variable("embedding", [1000000000, 20],
partitioner=tf.fixed_size_partitioner(3))
这里的关键函数是tf.train.replica_device_setter https://www.tensorflow.org/api_docs/python/tf/train/replica_device_setter。
它允许您运行 3 个不同的进程,称为参数服务器 https://www.tensorflow.org/deploy/distributed,存储所有模型变量。大的embedding
张量将被分割到这些服务器上,如图所示。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)