GCN是一种利用图结构和邻居顶点属性信息学习顶点Embedding表示的方法,GCN是直推式学习(只能在一个已知的图上进行学习),不能直接泛化到未知节点,当网络结构改变以及新节点的出现,直推式学习需要重新训练(复杂度高且可能会导致embedding会偏移),很难落地在需要快速生成未知节点embedding的机器学习系统上。
**GraphSAGE(Graph SAmple and aggreGatE)**是一种能利用顶点的属性信息高效产生未知顶点embedding的一种归纳式(inductive)学习的框架。
与GCN类似,其核心思想:学习一个映射
f
(
.
)
f(.)
f(.),通过该映射图中的节点
v
i
v_i
vi可以聚合它自己的特征
x
i
x_i
xi与它的邻居特征
x
j
(
j
∈
N
(
v
i
)
)
x_j \;(j \in N(v_i))
xj(j∈N(vi))来生成节点的新
v
i
v_i
vi表示。 区别在于并未利用所有的邻居节点,聚合的方式也不同。GraphSAGE框架的核心是如何聚合节点邻居特征信息。
GraphSAGE 前向传播算法
下图是GraphSAGE的学习过程:
主要步骤如下:
(1)对邻居随机采样
(2)使用聚合函数将采样的邻居节点的Embeddin进行聚合,用于更新节点的embedding。
(3)根据更新后的embedding预测节点的标签。
更新过程:
(1)为了更新红色节点,首先在第一层(k=1)我们会将蓝色节点的信息聚合到红色节点上,将绿色节点的信息聚合到蓝色节点上。所有的节点都有了新的包含邻居节点的embedding。
(2)在第二层(k=2)红色节点的embedding被再次更新,不过这次用的是更新后的蓝色节点embedding,这样就保证了红色节点更新后的embedding包括蓝色和绿色节点的信息。这样,每个节点又有了新的embedding向量,且包含更多的信息。
算法细节如下:
需要注意以下几点:
1、
h
v
0
h_v^0
hv0是每个节点的初始embedding特征向量
2、当
k
=
1
k=1
k=1时,遍历所有的节点,求
h
v
1
h_v^1
hv1,也就是算法的4-5行,也是最核心的部分。具体的:
(1)先对当前节点
v
v
v的邻居进行采样,得到邻居节点的集合
N
(
v
)
\mathcal N(v)
N(v),对所有的邻居节点
{
u
∈
N
(
v
)
}
\{ u \in \mathcal N(v)\}
{u∈N(v)}的
k
−
1
k-1
k−1层的embedding:
h
u
(
k
−
1
)
=
h
u
0
h_u^{(k-1)}=h_u^{0}
hu(k−1)=hu0 进行聚合,得到
v
v
v的邻居节点的代表向量
h
N
(
v
)
k
h_{\mathcal N(v)}^k
hN(v)k。如何聚合后面会提到。
(2)concat操作,将的、邻居节点的代表向量
h
N
(
v
)
k
h_{\mathcal N(v)}^k
hN(v)k 与自身的
h
v
k
−
1
=
h
v
0
h_v^{k-1}=h_v^0
hvk−1=hv0 进行连接,然后与权重变量
W
W
W相乘,并进行激活。其中
W
W
W用于控制在模型的不同层或“搜索深度”之间传播信息。
这样求出的
h
v
1
h_v^1
hv1就包含了邻居节点的信息。以此类推,当求
h
v
2
h_v^2
hv2时会用到
h
u
1
,
u
∈
N
(
v
)
h_u^1,u \in \mathcal N(v)
hu1,u∈N(v),而从上面的描述可知
h
u
1
h_u^1
hu1已经包含了
u
u
u的邻居节点信息。所以在每次迭代或搜索深度时,节点从它们的本地邻居处聚集信息,随着这个过程的迭代,节点从图的更远处获得越来越多的信息。
3、随着K增大,节点可以聚合更多的信息,K既是聚合器的数量,也是权重矩阵的数量,还是网络的层数,这是因为每一层网络中聚合器和权重矩阵是共享的。网络的层数可以理解为需要最大访问到的邻居的跳数(hops),比如在figure 1中,红色节点的更新拿到了它一、二跳邻居的信息,那么网络层数就是2。
采样算法&聚合(aggragator)操作
采样算法
GraphSAGE采用了定长抽样的方法。先确定需要采样的邻居数
N
N
N,然后采用有放回的重采样/负采样的方法达到
N
N
N,这样做可以方便后期训练。
聚合(aggragator)操作
聚合方式有:平均、GCN归纳式、LSTM、pooling聚合器。(因为邻居没有顺序,聚合函数需要满足排序不变量的特性,即输入顺序不会影响函数结果)
1,平均聚合:对邻居节点的embedding中的每个维度取平均,然后与自身节点的embedding拼接后进行非线性变换。
h
N
(
v
)
k
=
mean
(
{
h
u
k
−
1
,
u
∈
N
(
v
)
}
)
h
v
k
=
σ
(
W
k
⋅
CONCAT
(
h
v
k
−
1
,
h
N
(
u
)
k
)
)
\begin{array}{c} h_{N(v)}^{k}=\operatorname{mean}\left(\left\{h_{u}^{k-1}, u \in N(v)\right\}\right) \\ h_{v}^{k}=\sigma\left(W^{k} \cdot \operatorname{CONCAT}\left(h_{v}^{k-1}, h_{N(u)}^{k}\right)\right) \end{array}
hN(v)k=mean({huk−1,u∈N(v)})hvk=σ(Wk⋅CONCAT(hvk−1,hN(u)k))
2,归纳式聚合:直接对目标节点和所有邻居emebdding中每个维度取平均,后再非线性转换。
h
v
k
=
σ
(
W
k
⋅
mean
(
{
h
v
k
−
1
}
∪
{
h
u
k
−
1
,
∀
u
∈
N
(
v
)
}
)
h_{v}^{k}=\sigma\left(W^{k} \cdot \operatorname{mean}\left(\left\{h_{v}^{k-1}\right\} \cup\left\{h_{u}^{k-1}, \forall u \in N(v)\right\}\right)\right.
hvk=σ(Wk⋅mean({hvk−1}∪{huk−1,∀u∈N(v)})
3,LSTM 聚合
LSTM函数不符合“排序不变量”的性质,需要先对邻居随机排序,然后将随机的邻居序列embedding作为LSTM输入。
4,Pooling聚合:先对每个邻居节点上一层embedding进行非线性转换,再按维度应用 max/mean pooling,捕获邻居集上在某方面的突出的/综合的表现 以此表示目标节点embedding。
h
N
(
v
)
k
=
max
(
{
σ
(
W
pool
h
u
i
k
+
b
)
}
,
∀
u
i
∈
N
(
v
)
)
h
v
k
=
σ
(
W
k
⋅
CONCAT
(
h
v
k
−
1
,
h
N
(
u
)
k
−
1
)
)
\begin{aligned} h_{N(v)}^{k} &=\max \left(\left\{\sigma\left(W_{\text {pool}} h_{u i}^{k}+b\right)\right\}, \forall u_{i} \in N(v)\right) \\ h_{v}^{k} &=\sigma\left(W^{k} \cdot \operatorname{CONCAT}\left(h_{v}^{k-1}, h_{N(u)}^{k-1}\right)\right) \end{aligned}
hN(v)khvk=max({σ(Wpoolhuik+b)},∀ui∈N(v))=σ(Wk⋅CONCAT(hvk−1,hN(u)k−1))
参数学习
GraphSAGE的参数主要是聚合器的参数和权重变量
W
W
W。为了获得最优参数就得定义合适的损失函数。
1、有监督学习
可以使用每个节点的预测label和真实label的交叉熵作为损失函数。
2、无监督学习
其中:
z
u
z_u
zu是节点
u
u
u通过GraphSAGE生成的embedding;
v
v
v是节点
u
u
u随机游走可到达的"邻居"节点。
v
n
∼
p
n
(
v
)
v_n \sim p_n(v)
vn∼pn(v)表示
v
n
v_n
vn是从节点u的负采样分
p
n
(
v
)
p_n(v)
pn(v)的采样。负采样指我们还需要一批不是
u
u
u邻居的节点作为负样本。
Q为采样样本数。
embedding之间相似度通过向量点积计算得到。
如何理解这个损失函数?
先看损失函数的蓝色部分,当节点 u、v 比较接近时,那么其 embedding 向量
z
u
,
z
v
z_u, z_v
zu,zv的距离应该比较近,因此二者的内积应该很大,经过σ函数后是接近1的数,因此取对数后的数值接近于0。
再看看紫色的部分,当节点 u、v 比较远时,那么其 embedding 向量
z
u
,
z
v
z_u, z_v
zu,zv的距离应该比较远,在理想情况下,二者的内积应该是很大的负数,乘上-1后再经过σ函数可以得到接近1的数,因此取对数后的数值接近于0。
基于tensorflow2.0实现Graph SAGE
主要实现图的无监督学习与分类。
完整项目实现
参考文章
【Graph Neural Network】GraphSAGE: 算法原理,实现和应用
GNN教程:GraghSAGE算法细节详解!
GraphSAGE: GCN落地必读论文
GraphSAGE 模型解读与tensorflow2.0代码实现