Node
Node类构造函数的各项参数如下(参考torch.fx下的node.py):
- graph:指明实例化的Node属于哪个Graph
- op:节点的类型。一共有如下的几种类型:
- placeholder:占位符,一般代表输入。
- call_method:表示一种操作,该操作表示让前驱节点的输出对象调用自己的方法。
- call_module:表示一种操作,该操作表示将前驱节点的输出输入到nn.Module中。
- call_function:表示一种操作,该操作表示将前驱节点的输出输入到一个函数中。
- get_attr:表示一个操作,该操作获取Module自己的一个属性,并保存到输出中。
- output:输出节点,表示该节点是所属Graph的输出节点,即在所属Graph中无后继节点。
- root:整个Node在底层的数据结构是一个循环双向链表,root代表这个双向链表的头,是一个空的链表节点,用于维护双向链表。
- name:节点的名字。
- target:该节点需要调用的对象。如果op是call_function,那么target必须是一个Callable,否则必须是str。
- args:需要传递给target的变长参数。
- kwargs:需要传递给target的位置参数。
- return_type:代表该节点的输出数据的数据类型。
除此之外,Node在初始化时会创建几个属性,其中有几个比较重要:
- _input_nodes:一个哈希表,key是Node,value是None:代表self在Graph模式下的所有前驱节点。
- users:一个哈希表,数据类型同_input_nodes,代表self在Graph模式下的所有后继节点。
- _prev:self底层存储逻辑的前驱节点。
- _next:self底层存储逻辑的后继节点。
剩下来对于双向循环链表的插入(往前还是往后)和删除在Node中都实现了。
总结一下,Node通过_input_nodes和users来表示原本计算图的拓扑结构。其管理和存储是通过双向循环链表来的(和CPython的堆变量管理类似)
#查找conv节点
model = models.resnet18()
fx_model = fx.symbolic_trace(model)
modules = dict(fx_model.named_modules())
for node in fx_model.graph.nodes:
if node.target in modules:
print(node.target)
if type(modules[node.target]) == nn.Conv2d:
print('conv node')
Graph
由于Node中已经定义了完整管理计算图的属性和方法,因此Graph更多是对Node的管理和封装。
首先Graph在初始化时会创建一个root节点(在一张图中,有且仅有一个root):
self._root : Node = Node(self, '', 'root', '', (), {})
这个root就是底层用于存储和管理Node的双向循环链表。并且指定了插入节点的方法为向前插入。
在Graph的create_node方法中。就是创建一个node,然后再插入到链表中。
除此之外,Graph还指定了一些用于描述上下文、所属模块的信息,这些和后续的原理关系不大,就不赘述了。
GraphModule
GraphModule是对Graph的封装,为啥还要封装呢?因为Graph的mro中没有nn.Module,为了工程规范,GraphModule继承了nn.Module并对Graph做了一个简单的封装。
Proxy
正如其名,Proxy类是对Node类的一层包裹,它允许用户在不修改原图的情况下,用自定义的函数代理其中的节点,从而完成重载。
Tracer
Tracer类是对符号跟踪的一层抽象,它的symbolic_trace(m)
等价于Tracer().trace(m)
。Tracer继承自TracerBase。
而symbolic_trace也只是对Tracer().trace的输出结果使用GraphModule进行了一次包装。因此,只需要看懂Tracer().trace的实现逻辑即可。