调用Variable类即可向Graph中添加变量。Variable在创建之后需要给定初始值,可以是任意type、shape的Tensor。一旦使用初始值完成了初始化,type和shape都固定,除非使用assign方法改变。
一、Variable基本使用方法
给定了init_value并没有真正进行赋值,还需要初始化,初始化方法有两种:
- 单变量手动初始化:每个Variable都有initializer操作,调用方法为:
my_tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]])
w = tf.Variable(initial_value=my_tensor, name="W")
with tf.Session() as sess:
sess.run(w.initializer)
- 全局变量统一初始化:一般在launch整个Graph之前进行全局所有变量的初始化。使用**global_variables_initializer()**方法做为一个Op添加进图中,也需要run:
init_op = tf.compat.v1.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
关于Variable还有两个重要的、关于collection的方法:
- tf.global_variables():在确定的Graph中,每创建一个变量,都会在GraphKeys.GLOBAL_VARIABLES中进行新增。因此调用此函数可以获取到此图中所有变量。
- tf.trainable_variables():ML中要区分可训练参数和不可训练参数(训练轮次),不可训练要求时需要在定义时设定trainable=False。默认可训练,系统自动将此变量添加在GraphKeys.TRAINABLE_VARIABLES中,通过此函数返回可训练参数集合,进一步送入Optimizer中进行指定。
二、Variable的构造函数
构造函数具有非常多的值传入,但一般只需要考虑initial_value和name即可,对参数具体说明如下:
__init__(
initial_value=None,
trainable=None,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None,
constraint=None,
use_resource=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.VariableAggregation.NONE,
shape=None
)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)