图神经网络 PyTorch Geometric 入门教程

2023-11-06

简介

Graph Neural Networks 简称 GNN,称为图神经网络,是深度学习中近年来一个比较受关注的领域。近年来 GNN 在学术界受到的关注越来越多,与之相关的论文数量呈上升趋势,GNN 通过对信息的传递,转换和聚合实现特征的提取,类似于传统的 CNN,只是 CNN 只能处理规则的输入,如图片等输入的高、宽和通道数都是固定的,而 GNN 可以处理不规则的输入,如点云等。 可查看【GNN】万字长文带你入门 GCN

而 PyTorch Geometric Library (简称 PyG) 是一个基于 PyTorch 的图神经网络库,地址是:https://github.com/rusty1s/pytorch_geometric。它包含了很多 GNN 相关论文中的方法实现和常用数据集,并且提供了简单易用的接口来生成图,因此对于复现论文来说也是相当方便。用法大多数和 PyTorch 很相近,因此熟悉 PyTorch 的同学使用这个库可以很快上手。

torch_geometric.data.Data

节点和节点之间的边构成了图。所以在 PyG 中,如果你要构建图,那么需要两个要素:节点和边。PyG 提供了torch_geometric.data.Data (下面简称Data) 用于构建图,包括 5 个属性,每一个属性都不是必须的,可以为空。

  • x: 用于存储每个节点的特征,形状是[num_nodes, num_node_features]
  • edge_index: 用于存储节点之间的边,形状是 [2, num_edges]
  • pos: 存储节点的坐标,形状是[num_nodes, num_dimensions]
  • y: 存储样本标签。如果是每个节点都有标签,那么形状是[num_nodes, *];如果是整张图只有一个标签,那么形状是[1, *]
  • edge_attr: 存储边的特征。形状是[num_edges, num_edge_features]

实际上,Data对象不仅仅限制于这些属性,我们可以通过data.face来扩展Data,以张量保存三维网格中三角形的连接性。

需要注意的的是,在Data里包含了样本的 label,这意味和 PyTorch 稍有不同。在PyTorch中,我们重写Dataset__getitem__(),根据 index 返回对应的样本和 label。在 PyG 中,我们使用的不是这种写法,而是在get()函数中根据 index 返回torch_geometric.data.Data类型的数据,在Data里包含了数据和 label。

下面一个例子是未加权无向图 ( 未加权指边上没有权值 ),包括 3 个节点和 4 条边。


由于是无向图,因此有 4 条边:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)。每个节点都有自己的特征。上面这个图可以使用`torch_geometric.data.Data`来表示如下:
import torch
from torch_geometric.data import Data
# 由于是无向图,因此有 4 条边:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
# 节点的特征                           
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)

注意edge_index中边的存储方式,有两个list,第 1 个list是边的起始点,第 2 个list是边的目标节点。注意与下面的存储方式的区别。

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index.t().contiguous())

这种情况edge_index需要先转置然后使用contiguous()方法。关于contiguous()函数的作用,查看 PyTorch中的contiguous

最后再复习一遍,Data中最基本的 4 个属性是xedge_indexposy,我们一般都需要这 4 个参数。

有了Data,我们可以创建自己的Dataset,读取并返回Data了。

Dataset 与 DataLoader

PyG 的 Dataset继承自torch.utils.data.Dataset,自带了很多图数据集,我们以TUDataset为例,通过以下代码就可以加载数据集,root参数设置数据下载的位置。通过索引可以访问每一个数据。

from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
data = dataset[0]

在一个图中,由edge_indexedge_attr可以决定所有节点的邻接矩阵。PyG 通过创建稀疏的对角邻接矩阵,并在节点维度中连接特征矩阵和 label 矩阵,实现了在 mini-batch 的并行化。PyG 允许在一个 mini-batch 中的每个Data (图) 使用不同数量的节点和边。


自定义 Dataset

尽管 PyG 已经包含许多有用的数据集,我们也可以通过继承torch_geometric.data.Dataset使用自己的数据集。提供 2 种不同的Dataset

  • InMemoryDataset:使用这个Dataset会一次性把数据全部加载到内存中。
  • Dataset: 使用这个Dataset每次加载一个数据到内存中,比较常用。

我们需要在自定义的Dataset的初始化方法中传入数据存放的路径,然后 PyG 会在这个路径下再划分 2 个文件夹:

  • raw_dir: 存放原始数据的路径,一般是 csv、mat 等格式
  • processed_dir: 存放处理后的数据,一般是 pt 格式 ( 由我们重写process()方法实现)。

在 PyTorch 中,是没有这两个文件夹的。下面来说明一下这两个文件夹在 PyG 中的实际意义和处理逻辑。

torch_geometric.data.Dataset继承自torch.utils.data.Dataset,在初始化方法 __init__()中,会调用_download()方法和_process()方法。

def __init__(self, root=None, transform=None, pre_transform=None,
			 pre_filter=None):
	super(Dataset, self).__init__()

	if isinstance(root, str):
		root = osp.expanduser(osp.normpath(root))

	self.root = root
	self.transform = transform
	self.pre_transform = pre_transform
	self.pre_filter = pre_filter
	self.__indices__ = None

	# 执行 self._download() 方法
	if 'download' in self.__class__.__dict__.keys():
		self._download()
    # 执行 self._process() 方法
	if 'process' in self.__class__.__dict__.keys():
		self._process()

_download()方法如下,首先检查self.raw_paths列表中的文件是否存在;如果存在,则返回;如果不存在,则调用self.download()方法下载文件。

def _download(self):
	if files_exist(self.raw_paths):  # pragma: no cover
		return

	makedirs(self.raw_dir)
	self.download()

_process()方法如下,首先在self.processed_dir中有pre_transform,那么判断这个pre_transform和传进来的pre_transform是否一致,如果不一致,那么警告提示用户先删除self.processed_dir文件夹。pre_filter同理。

然后检查self.processed_paths列表中的文件是否存在;如果存在,则返回;如果不存在,则调用self.process()生成文件。

def _process(self):
	f = osp.join(self.processed_dir, 'pre_transform.pt')
	if osp.exists(f) and torch.load(f) != __repr__(self.pre_transform):
		warnings.warn(
			'The `pre_transform` argument differs from the one used in '
			'the pre-processed version of this dataset. If you really '
			'want to make use of another pre-processing technique, make '
			'sure to delete `{}` first.'.format(self.processed_dir))
	f = osp.join(self.processed_dir, 'pre_filter.pt')
	if osp.exists(f) and torch.load(f) != __repr__(self.pre_filter):
		warnings.warn(
			'The `pre_filter` argument differs from the one used in the '
			'pre-processed version of this dataset. If you really want to '
			'make use of another pre-fitering technique, make sure to '
			'delete `{}` first.'.format(self.processed_dir))

	if files_exist(self.processed_paths):  # pragma: no cover
		return

	print('Processing...')

	makedirs(self.processed_dir)
	self.process()

	path = osp.join(self.processed_dir, 'pre_transform.pt')
	torch.save(__repr__(self.pre_transform), path)
	path = osp.join(self.processed_dir, 'pre_filter.pt')
	torch.save(__repr__(self.pre_filter), path)

	print('Done!')

一般来说不用实现downloand()方法

如果你直接把处理好的 pt 文件放在了self.processed_dir中,那么也不用实现process()方法。

在 Pytorch 的dataset中,我们需要实现__getitem__()方法,根据index返回样本和标签。在这里torch_geometric.data.Dataset中,重写了__getitem__()方法,其中调用了get()方法获取数据。

def __getitem__(self, idx):
	if isinstance(idx, int):
		data = self.get(self.indices()[idx])
		data = data if self.transform is None else self.transform(data)
		return data
	else:
		return self.index_select(idx)

我们需要实现的是get()方法,根据index返回torch_geometric.data.Data类型的数据。

process()方法存在的意义是原始的格式可能是 csv 或者 mat,在process()函数里可以转化为 pt 格式的文件,这样在get()方法中就可以直接使用torch.load()函数读取 pt 格式的文件,返回的是torch_geometric.data.Data类型的数据,而不用在get()方法做数据转换操作 (把其他格式的数据转换为 torch_geometric.data.Data类型的数据)。当然我们也可以提前把数据转换为 torch_geometric.data.Data类型,使用 pt 格式保存在self.processed_dir中。

DataLoader

通过torch_geometric.data.DataLoader可以方便地使用 mini-batch。

from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
	# 对每一个 mini-batch 进行操作
	...

torch_geometric.data.Batch继承自torch_geometric.data.Data,并且多了一个属性:batchbatch是一个列向量,它将每个元素映射到每个 mini-batch 中的相应图:

batch = [ 0 ⋯ 0 1 ⋯ n − 2 n − 1 ⋯ n − 1 ] ⊤ =\left[\begin{array}{cccccccc}0 & \cdots & 0 & 1 & \cdots & n-2 & n-1 & \cdots & n-1\end{array}\right]^{\top} =[001n2n1n1]

我们可以使用它分别为每个图的节点维度计算平均的节点特征:

from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for data in loader:
    data
    #data: Batch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])

    x = scatter_mean(data.x, data.batch, dim=0)
    # x.size(): torch.Size([32, 21])

关于 batching 的流程细节,你可以点击Pytorch Geometric Documentation查看。关于scatter方法的说明,你可以查看torch-scatter说明文档

Transforms

transforms在计算机视觉领域是一种很常见的数据增强。PyG 有自己的transforms,输出是Data类型,输出也是Data类型。可以使用torch_geometric.transforms.Compose封装一系列的transforms。我们以 ShapeNet 数据集 (包含 17000 个 point clouds,每个 point 分类为 16 个类别的其中一个) 为例,我们可以使用transforms从 point clouds 生成最近邻图:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

还可以通过transform在一定范围内随机平移每个点,增加坐标上的扰动,做数据增强:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6),
                    transform=T.RandomTranslate(0.01))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

模型训练

这里只是展示一个简单的 GCN 模型构造和训练过程,没有用到DatasetDataLoader

我们将使用一个简单的 GCN 层,并在 Cora 数据集上实验。有关 GCN 的更多内容,请查看 关于 GCN 的理解

部分朋友反映:Cora 数据集无法下载。因此我上传到了百度云,并提供下载链接。
链接:https://pan.baidu.com/s/182L84pypm2e-ZXDuiRgFgA
提取码:2479

下载完成后解压,放置到你配置的路径,具体路径是在下面的dataset = Planetoid(root='/tmp/Cora', name='Cora')中的root参数和name参数定义的。

我们首先加载数据集:

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')

然后定义 2 层的 GCN:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

然后训练 200 个 epochs:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

最后在测试集上验证了模型的准确率:

model.eval()
_, pred = model(data).max(dim=1)
correct = float (pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))

至此,关于`Pytorch Geometric`的简单使用教程就讲完了。

回顾一下,在这篇文章中,在讲述使用Pytorch Geometric的过程中,花了较多篇幅分析了图数据是如何表示的,分析了Dataset的工作流程,让你明白图数据在Dataset里都经过了哪些步骤,才得以输入到模型,最终可以利用Dataset来构建自己的数据集。

如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。

我的文章会首发在公众号上,欢迎扫码关注我的公众号张贤同学


本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

图神经网络 PyTorch Geometric 入门教程 的相关文章

随机推荐

  • 一些网站1

    N1BOOK平台 Nu1L Team Nu1L Team 0004 Median of Two Sorted Arrays LeetCode Cookbook 题库 力扣 LeetCode 全球极客挚爱的技术成长平台
  • 解决shell断开后java进程被结束

    偶尔会碰到用SecureCRT在shell启动java进程并后台运行 命令最后加 的时候 因为断电死机等原因断开shell 然后进程被结束了 运维大佬也说用他们的工具启动进程后一断开连接进程就结束了 后来查到是因为shell在断开的时候会向
  • 漫谈数据挖掘从入门到进阶

    做数据挖掘也有些年头了 写这篇文一方面是写篇文 给有个朋友作为数据挖掘方面的参考 另一方面也是有抛砖引玉之意 希望能够和一些大牛交流 相互促进 让大家见笑了 入门 数据挖掘入门的书籍 中文的大体有这些 Jiawei Han的 数据挖掘概念与
  • Day_1 Part_4 Structures of R

    1 Vector Matrix Array 1 1 What are they Collection of observations Vector 1 dimensional Matrix 2 dimensional Array 3 dim
  • 常见web漏洞及防范(转)

    单个漏洞 需要进行排查与整改 借着别人的智慧 做一个简单的收集 最好能够将常见漏洞 不限于web类的 进行一个统一的整理 这是今年的任务 进行漏洞的工具的收集 为未来的工作做好基础 一 SQL注入漏洞 SQL注入攻击 SQL Injecti
  • MMDetection 3.x中的PackDetInputs

    MMDetection 3 X 里面对pipeline有一个重点修改是新增了 PackDetInputs 有利于统一 进行检测 语义分割 全景分割任务 从配置文件中我们可以看出包含LoadImageFromFile LoadAnnotati
  • electron在BrowserWindow中禁止右键菜单

    最近使用 electron vite solid js 做一个网络流量实时监控的小工具 其中需要禁止用户在获取 BrowserWindow 焦点后弹出默认右键菜单 解决方案 在 new BrowserWindow 后中添加以下代码 禁止右键
  • 静默执行bat文件

    让bat隐藏运行需要用vbs文件才能实现 方式一 使用vbs文件 新建一个 文本文档后缀改为 vbs 可以这样写 set ws WScript CreateObject WScript Shell ws Run d yy bat 0 其中d
  • 《区块链技术与应用》学习笔记2——BTC数据结构

    Hash pointer 哈希指针 指针 在程序运行过程中 需要用到数据 最简单的是直接获取数据 但当数据本身较大 需要占用较大空间时 明显会造成一定麻烦 因此可以引用指针 每次获取相应的数据即可 实际使用中 指针实际上存储的是逻辑地址更多
  • C语言用scanf来判断键盘输入数据类型

    2 可以用来判断是否和定义的类型一致 如 int n if scanf d n 1 else 可以用来判断键盘输入的数据是否是整数
  • 打不开微软自带的软件,或者初次安装sql server 提示下载不了,版本不支持的,一定检查一下这里

    检查这里 这里一定要把代理模式关掉 血的教训啊
  • 吉首大学_编译原理实验题_基于预测方法的语法分析程序的设计【通过代码】

    一 实验要求 实验二 基于预测方法的语法分析程序的设计 一 实验目的 了解预测分析器的基本构成及用自顶向下的预测法对表达式进行语法分析的方法 掌握预测语法分析程序的手工构造方法 二 实验内容 1 了解编译程序的基于预测方法的语法分析过程 2
  • GFS论文解读

    文章目录 1 设计概述 1 1 假设 1 2 GFS架构 1 3 读取流程 1 4 元数据 1 5 操作日志 1 6 一致性模型 2 系统交互 2 1 契约机制 2 2 数据写入过程 2 3 数据流 2 4 原子的记录追加 1 设计概述 1
  • Leetcode[链表] 反转链表 -- 双指针法

    0 题目描述 leetcode原题链接 反转链表 1 双指针法 定义两个指针 pre 和 cur pre 在前 cur 在后 每次让 pre 的 next 指向 cur 实现一次局部反转 局部反转完成之后 pre 和 cur 同时往前移动一
  • Kotlin 31. Kotlin 如何删除文件或文件夹

    Kotlin 如何删除文件或文件夹 比如 我们想要删除 Documents 年月日 文件夹下面的所有文件 包括这个文件夹 我们首先需要获得 Documents 的路径 val extDir File Environment getExter
  • ESP32——WIFI

    WiFi Wi Fi 库支持配置及监控 ESP32 Wi Fi 连网功能 WiFi工作模式 基站模式 即 STA 模式或 Wi Fi 客户端模式 此时 ESP32 连接到接入点 AP AP 模式 即 Soft AP 模式或接入点模式 此时基
  • 第二章 构造函数语意学 编译器何时合成拷贝构造函数?

    首先要清楚位拷贝 浅拷贝 和值拷贝 深拷贝 的区别 参考http blog sina com cn s blog a2aa00d70101gpvj html 位拷贝 及 bitwise copy 是指将一个对象的内存映像按位原封不动的复制给
  • STM32开发实例 基于STM32单片机的氛围灯

    一 系统设计 我想做的是个基于WIFI 的智能氛围灯 这个灯用app控制 首先这个灯在APP上面可以选择颜色 注 RGB 和亮度调节 音乐律动模式可跟随手机上播放的音乐改变亮度 光照模式白天关灯晚上开灯 人体感应模式有人时开灯反之关灯 智能
  • CSDN竞赛第45期题解

    CSDN竞赛第45期题解 1 题目名称 勾股数 勾股数是一组三个正整数 它们可以作为直角三角形的三条边 比如3 4 5就是一组勾股数 如果给出一组勾股数其中的 两个 你能找出余下的一个吗 ll a b cin gt gt a gt gt b
  • 图神经网络 PyTorch Geometric 入门教程

    简介 Graph Neural Networks 简称 GNN 称为图神经网络 是深度学习中近年来一个比较受关注的领域 近年来 GNN 在学术界受到的关注越来越多 与之相关的论文数量呈上升趋势 GNN 通过对信息的传递 转换和聚合实现特征的