tensorflow码源-运行流程
简介
通过分析用户构建的计算是如何在tensorflow中运行的,了解tensorflow中的基本元素和op、kernel和device之间的交互。
用户程序
matrix1 = tf.constant([[3., 3.]]
matrix2 = tf.constant([[2.],[2.]])
product = tf.matmul(matrix1, matrix2)
with tf.Session() as sess:
result = sess.run([product])
print(result)
如何构建计算流图的
Op的注册
tensroflow中的计算流图中的节点被称为ops,只记录了计算的节点的属性。计算有绑定该op的kernel来实现。
在python中,op等的注册实际上是有C++实现自动生成的。对于matmul属于math,它的注册在gen_math_ops.py(源码中看不到,需要在生成的python包中找)。
注册的ops保存在OpDefLibrary类型的实例_op_def_lib中,_op_def_lib在初始化阶段会接收一个op_list,其中包含了各math ops的定义。
def _InitOpDefLibrary():
op_list = _op_def_pb2.OpList()
_text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list)
_op_def_registry.register_op_list(op_list)
op_def_lib = _op_def_library.OpDefLibrary()
op_def_lib.add_op_list(op_list)
return op_def_lib
_InitOpDefLibrary.op_list_ascii = """op {
3072 name: "Abs"
3073 input_arg {
3074 name: "x"
3075 type_attr: "T"
3076 }
3077 output_arg {
3078 name: "y"
3079 type_attr: "T"
3080 }
3081 attr {
3082 name: "T"
3083 type: "type"
3084 allowed_values {
3085 list {
3086 type: DT_HALF
3087 type: DT_FLOAT
3088 type: DT_DOUBLE
3089 type: DT_INT32
3090 type: DT_INT64
3091 }
3092 }
3093 }
3094 }
.."
_op_def_lib = _InitOpDefLibrary()
//_op_def_registry.register_op_list(op_list)对应的是
tensorflow.core.framework.op_def_registry.register_op_list
//会把op有注册到到_registered_ops = {}中,可以通过get_registered_ops()获取
//两着的本质应该是一样的,用的地方应该不一样,
//_op_def_lib用于构建计算的图?
//_registered_ops 用于其他?
在C++中也有类似的过程,实现在C++中op的注册
MatMul的在 core/ops/math_ops.cc::1021
插入op到graph
tf.matmul等函数会在图中插入我们给定的操作。这些操作最终都会转化成
def matmul(a, b, transpose_a=False, transpose_b=False, adjoint_a=False, adjoint_b=False, a_is_sparse=False, b_is_sparse=False, name=None): #tensorflow.python.ops.math_ops
return gen_math_ops._mat_mul( a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
#tensorflow.python.ops. gen_math_ops(生成的包中)
def _mat_mul(a, b, transpose_a=None, transpose_b=None, name=None)
_op_def_lib.apply_op(“MatMul”,…)
#apply_op比较复杂,最终会在图中添加该op
331 g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
...
# Add Op to graph
766 op = g.create_op(op_type_name, inputs, output_types, name=scope,
767 input_types=input_types, attrs=attr_protos,
768 op_def=op_def)
return op
#如果参数没有提供图,则会在当前线程默认的图中添加操作。
#def get_graph_from_inputs(op_input_list, graph=None):
# """Returns the appropriate graph to use for the given inputs.
# 1. If `graph` is provided, we validate that all inputs in `op_input_list` are
# from the same graph.
# 2. Otherwise, we attempt to select a graph from the first Operation- or</