【Graph Neural Network】 GraphSAGE 基本原理与tensorflow2.0实现

2023-10-31


GCN是一种利用图结构和邻居顶点属性信息学习顶点Embedding表示的方法,GCN是直推式学习(只能在一个已知的图上进行学习),不能直接泛化到未知节点,当网络结构改变以及新节点的出现,直推式学习需要重新训练(复杂度高且可能会导致embedding会偏移),很难落地在需要快速生成未知节点embedding的机器学习系统上。

**GraphSAGE(Graph SAmple and aggreGatE)**是一种能利用顶点的属性信息高效产生未知顶点embedding的一种归纳式(inductive)学习的框架。

与GCN类似,其核心思想:学习一个映射 f ( . ) f(.) f(.),通过该映射图中的节点 v i v_i vi可以聚合它自己的特征 x i x_i xi与它的邻居特征 x j    ( j ∈ N ( v i ) ) x_j \;(j \in N(v_i)) xj(jN(vi))来生成节点的新 v i v_i vi表示。 区别在于并未利用所有的邻居节点,聚合的方式也不同。GraphSAGE框架的核心是如何聚合节点邻居特征信息。

GraphSAGE 前向传播算法

下图是GraphSAGE的学习过程:

在这里插入图片描述

主要步骤如下:

(1)对邻居随机采样

(2)使用聚合函数将采样的邻居节点的Embeddin进行聚合,用于更新节点的embedding。

(3)根据更新后的embedding预测节点的标签。

更新过程:

(1)为了更新红色节点,首先在第一层(k=1)我们会将蓝色节点的信息聚合到红色节点上,将绿色节点的信息聚合到蓝色节点上。所有的节点都有了新的包含邻居节点的embedding。

(2)在第二层(k=2)红色节点的embedding被再次更新,不过这次用的是更新后的蓝色节点embedding,这样就保证了红色节点更新后的embedding包括蓝色和绿色节点的信息。这样,每个节点又有了新的embedding向量,且包含更多的信息。

算法细节如下:

在这里插入图片描述

需要注意以下几点:

1、 h v 0 h_v^0 hv0是每个节点的初始embedding特征向量

2、当 k = 1 k=1 k=1时,遍历所有的节点,求 h v 1 h_v^1 hv1,也就是算法的4-5行,也是最核心的部分。具体的:

(1)先对当前节点 v v v的邻居进行采样,得到邻居节点的集合 N ( v ) \mathcal N(v) N(v),对所有的邻居节点 { u ∈ N ( v ) } \{ u \in \mathcal N(v)\} {uN(v)} k − 1 k-1 k1层的embedding: h u ( k − 1 ) = h u 0 h_u^{(k-1)}=h_u^{0} hu(k1)=hu0 进行聚合,得到 v v v的邻居节点的代表向量 h N ( v ) k h_{\mathcal N(v)}^k hN(v)k。如何聚合后面会提到。

(2)concat操作,将的、邻居节点的代表向量 h N ( v ) k h_{\mathcal N(v)}^k hN(v)k 与自身的 h v k − 1 = h v 0 h_v^{k-1}=h_v^0 hvk1=hv0 进行连接,然后与权重变量 W W W相乘,并进行激活。其中 W W W用于控制在模型的不同层或“搜索深度”之间传播信息。

这样求出的 h v 1 h_v^1 hv1就包含了邻居节点的信息。以此类推,当求 h v 2 h_v^2 hv2时会用到 h u 1 , u ∈ N ( v ) h_u^1,u \in \mathcal N(v) hu1,uN(v),而从上面的描述可知 h u 1 h_u^1 hu1已经包含了 u u u的邻居节点信息。所以在每次迭代或搜索深度时,节点从它们的本地邻居处聚集信息,随着这个过程的迭代,节点从图的更远处获得越来越多的信息。

3、随着K增大,节点可以聚合更多的信息,K既是聚合器的数量,也是权重矩阵的数量,还是网络的层数,这是因为每一层网络中聚合器和权重矩阵是共享的。网络的层数可以理解为需要最大访问到的邻居的跳数(hops),比如在figure 1中,红色节点的更新拿到了它一、二跳邻居的信息,那么网络层数就是2。

采样算法&聚合(aggragator)操作

采样算法

GraphSAGE采用了定长抽样的方法。先确定需要采样的邻居数 N N N,然后采用有放回的重采样/负采样的方法达到 N N N,这样做可以方便后期训练。

聚合(aggragator)操作

聚合方式有:平均、GCN归纳式、LSTM、pooling聚合器。(因为邻居没有顺序,聚合函数需要满足排序不变量的特性,即输入顺序不会影响函数结果)

1,平均聚合:对邻居节点的embedding中的每个维度取平均,然后与自身节点的embedding拼接后进行非线性变换。
h N ( v ) k = mean ⁡ ( { h u k − 1 , u ∈ N ( v ) } ) h v k = σ ( W k ⋅ CONCAT ⁡ ( h v k − 1 , h N ( u ) k ) ) \begin{array}{c} h_{N(v)}^{k}=\operatorname{mean}\left(\left\{h_{u}^{k-1}, u \in N(v)\right\}\right) \\ h_{v}^{k}=\sigma\left(W^{k} \cdot \operatorname{CONCAT}\left(h_{v}^{k-1}, h_{N(u)}^{k}\right)\right) \end{array} hN(v)k=mean({huk1,uN(v)})hvk=σ(WkCONCAT(hvk1,hN(u)k))
2,归纳式聚合:直接对目标节点和所有邻居emebdding中每个维度取平均,后再非线性转换。
h v k = σ ( W k ⋅ mean ⁡ ( { h v k − 1 } ∪ { h u k − 1 , ∀ u ∈ N ( v ) } ) h_{v}^{k}=\sigma\left(W^{k} \cdot \operatorname{mean}\left(\left\{h_{v}^{k-1}\right\} \cup\left\{h_{u}^{k-1}, \forall u \in N(v)\right\}\right)\right. hvk=σ(Wkmean({hvk1}{huk1,uN(v)})
3,LSTM 聚合

LSTM函数不符合“排序不变量”的性质,需要先对邻居随机排序,然后将随机的邻居序列embedding作为LSTM输入。

4,Pooling聚合:先对每个邻居节点上一层embedding进行非线性转换,再按维度应用 max/mean pooling,捕获邻居集上在某方面的突出的/综合的表现 以此表示目标节点embedding。
h N ( v ) k = max ⁡ ( { σ ( W pool h u i k + b ) } , ∀ u i ∈ N ( v ) ) h v k = σ ( W k ⋅ CONCAT ⁡ ( h v k − 1 , h N ( u ) k − 1 ) ) \begin{aligned} h_{N(v)}^{k} &=\max \left(\left\{\sigma\left(W_{\text {pool}} h_{u i}^{k}+b\right)\right\}, \forall u_{i} \in N(v)\right) \\ h_{v}^{k} &=\sigma\left(W^{k} \cdot \operatorname{CONCAT}\left(h_{v}^{k-1}, h_{N(u)}^{k-1}\right)\right) \end{aligned} hN(v)khvk=max({σ(Wpoolhuik+b)},uiN(v))=σ(WkCONCAT(hvk1,hN(u)k1))

参数学习

GraphSAGE的参数主要是聚合器的参数和权重变量 W W W。为了获得最优参数就得定义合适的损失函数。

1、有监督学习

可以使用每个节点的预测label和真实label的交叉熵作为损失函数。

2、无监督学习

在这里插入图片描述

其中: z u z_u zu是节点 u u u通过GraphSAGE生成的embedding;

v v v是节点 u u u随机游走可到达的"邻居"节点。

v n ∼ p n ( v ) v_n \sim p_n(v) vnpn(v)表示 v n v_n vn是从节点u的负采样分 p n ( v ) p_n(v) pn(v)的采样。负采样指我们还需要一批不是 u u u邻居的节点作为负样本。

​ Q为采样样本数。

​ embedding之间相似度通过向量点积计算得到。

如何理解这个损失函数?

先看损失函数的蓝色部分,当节点 u、v 比较接近时,那么其 embedding 向量 z u , z v z_u, z_v zu,zv的距离应该比较近,因此二者的内积应该很大,经过σ函数后是接近1的数,因此取对数后的数值接近于0。

再看看紫色的部分,当节点 u、v 比较远时,那么其 embedding 向量 z u , z v z_u, z_v zu,zv的距离应该比较远,在理想情况下,二者的内积应该是很大的负数,乘上-1后再经过σ函数可以得到接近1的数,因此取对数后的数值接近于0。

基于tensorflow2.0实现Graph SAGE

主要实现图的无监督学习与分类。

完整项目实现

参考文章

【Graph Neural Network】GraphSAGE: 算法原理,实现和应用

GNN教程:GraghSAGE算法细节详解!

GraphSAGE: GCN落地必读论文

GraphSAGE 模型解读与tensorflow2.0代码实现

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【Graph Neural Network】 GraphSAGE 基本原理与tensorflow2.0实现 的相关文章

随机推荐

  • C/C++ 中的%d等意义

    d 整型输出 ld长整型输出 o 以八进制数形式输出整数 x 以十六进制数形式输出整数 或输出字符串的地址 u 以十进制数输出unsigned型数据 无符号数 注意 d与 u有无符号的数值范围 也就是极限的值 不然数值打印出来会有误 c 用
  • HTTP 错误 404.3 – Not Found 由于扩展配置问题而无法提供您请求的页面。如果该页面是脚本,请添加处理程序。如果应下载文件,请添加 MIME 映射。

    今天 在vs2013中新建了一个placard json文件 当我用jq读取它的时候 去提示404 直接在浏览器访问这个文件 提示 HTTP 错误 404 3 Not Found 由于扩展配置问题而无法提供您请求的页面 如果该页面是脚本 请
  • session 和 cookie 有什么区别?

    Session 和 Cookie 都是Web开发中非常重要的概念 它们用于保存Web应用程序状态和用户信息 但是它们有一些重要的区别 1 存储位置不同 Cookie 存储于客户端 浏览器 而 Session 存储于服务器端 对于服务器端存储
  • python训练模型、如何得到模型训练总时长_模型训练时间的估算

    模型训练时间的估算 昨天群里一个朋友训练一个BERT句子对模型 使用的是CPU来进行训练 由于代码是BERT官方代码 并没有显示训练需要的总时间 所以训练的时候只能等待 他截图发了基本的信息 想知道训练完整个模型需要多久 最开始跑BERT模
  • 如何做一个人工智能聊天机器人的毕业设计

    毕业设计是大学生的必修课程之一 许多学生在毕业设计中选择了人工智能方向的课题 人工智能聊天机器人是一个很好的毕业设计课题 它涉及到自然语言处理 机器学习 深度学习等人工智能的核心技术 做好一个聊天机器人的毕业设计需要考虑好聊天机器人的功能
  • 树莓派3 有线静态路由设置_配置树莓派为wifi热点(AP模式)

    该功能主要用于搭建一个小型的的网络 使得连接至热点的各个设备可以进行通信 用于构建物联网系统 如智能家居 或是无线控制指定设备 另外这一网络也是独立的 并未启用NAT连接至互联网 具有一定的安全性也为研究提供了一定的便利 目前已经有很多方案
  • 删除rabbitmq的队列和队列中的数据

    欢迎访问本人博客查看原文 http wangnan tech 访问http rabbitmq安装IP 15672 帐号guest 密码guest 也可以使用自己创建的帐号 登录后访问http rabbitmq安装IP 15672 queue
  • Error:() java: 程序包org.springframework.beans.factory.annotation不存在

    写在前面 很重要 idea的2019 2020版本确实是存在很多bug的 我也踩过几个坑 我推荐使用idea2018 1 8版本 这个版本比较稳定 我暂时没遇到什么bug 其实遇到这个bug我很高兴 因为之前就出现过这个bug 当时公司前辈
  • 华为30道Python面试题总结

    Python是目前编程领域最受欢迎的语言 在本文中 我将总结华为 阿里巴巴等互联网公司Python面试中最常见的30个问题 每道题都提供参考答案 希望能够帮助你在求职面试中脱颖而出 找到一份高薪工作 这些面试题涉及Python基础知识 Py
  • hutool json转map_记一个Jackson与Hutool混用的坑

    技术公众号 Java In Mind Java In Mind 欢迎关注 问题出现 最近遇到一个问题 Hutool从4 1 7升级到4 6 8之后 使用feign调用出现错误 Caused by feign codec EncodeExce
  • CXF java.lang.RuntimeException: Cannot create a secure XMLInputFactory

    刚开始接触cxf 照着网上的例子写了一个demo 在测试 编写客户端访问服务运行的时候后台报了 CXF java lang RuntimeException Cannot create a secure XMLInputFactory 的错
  • Android gradle配置抽取合并

    一 为什么要合并 当项目中model或library变多过后 比如用到组件化或者引入第三方库需要配置多个build gradle文件 一旦需要统一其SDK或者其他组件版本就需要同时修改多个文件 这确实很麻烦 所以抽取gradle配置非常有必
  • JAVA单元测试框架-9-testng.xml管理依赖

    在testng xml里配置依赖管理 先写个测试用例 Test description 测试分组 groups operation public void TestGroupAdd System out print String value
  • 对七牛云的简单了解

    一 初识七牛云 1 概述 七牛云是国内领先的企业级公有云服务商 致力于打造以数据为核心的场景化PaaS服务 围绕富媒体场景 七牛先后推出了对象存储 融合CDN加速 数据通用处理 内容反垃圾服务 以及直播云服务等 通俗来讲七牛云就是一个服务器
  • UE4基础学习笔记——— 材质编辑器04

    材质实例 原理 不用在原父级材质编辑器中去调节材质 我们把重要的调节值设置为 转换为参数 将材质实例化 要修改只要修改参数即可 选择父级材质右键 创建材质实例 注意标识颜色是 深绿 在实例编辑界面中 出现了之前设置的可变参数 材质实例化方便
  • 《Java Web程序设计——开发环境的搭建》

    Java Web程序设计 开发环境的搭建 一 前言 这里主要分享一下我搭建开发环境的过程以及遇到的问题 其中涉及到的软件都可以从官网获取 若官方访问过慢也可到镜像网站或者下面分享的网盘链接中下载 软件安装路径尽量不要有中文 否则可能会报错
  • 试题 算法训练 拿金币-蓝桥杯

    这里的关键字仍然是动态规划 动态规划核心 拆分子 记住过往 减少重复计算 计算结果 1 不难发现 对于某个确定的路径上的特定位置上的金币总数 总是由该位置的上方的值或左边的值确定的 所以遍历数组位置的上方和左边的 再 比较 递加 就能计算出
  • K8S之资源管理

    文章目录 一 K8S中的资源 二 YAML语言 三 资源管理方式 一 命令式对象管理 二 命令式对象配置 三 声明式对象配置 一 K8S中的资源 在kuberbnetes中 所有的内容都抽象为资源 用户需要通过操作资源来管理kubernet
  • 可视化埋点方案和实践-PC-WEB端(一)

    目录 一 什么是可视化埋点 1 圈选 点选 即标记页面元素 的逻辑代码 2 捕获监听标记的元素的逻辑代码 二 遇到的坑 1 标记元素兼容性难 2 监听难 三 优点 1 方便了测试人员和运营人员 2 埋点的变更是即时的 不需要更新系统代码 3
  • 【Graph Neural Network】 GraphSAGE 基本原理与tensorflow2.0实现

    文章目录 GraphSAGE 前向传播算法 采样算法 聚合 aggragator 操作 参数学习 基于tensorflow2 0实现Graph SAGE GCN是一种利用图结构和邻居顶点属性信息学习顶点Embedding表示的方法 GCN是