图注意力网络(GAT) TensorFlow解析

2023-05-16

论文

图注意力网络来自 Graph Attention Networks,ICLR 2018. https://arxiv.org/abs/1710.10903

注意力机制

在这里插入图片描述

在这里插入图片描述

代码

import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras import activations
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers


class GraphAttentionLayer(keras.layers.Layer):
    def compute_output_signature(self, input_signature):
        pass

    def __init__(self,
                 input_dim,
                 output_dim,
                 adj,
                 nodes_num,
                 dropout_rate=0.0,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 coef_dropout=0.0,
                 **kwargs):
        """
        :param input_dim: 输入的维度
        :param output_dim: 输出的维度,不等于input_dim
        :param adj: 具有自环的tuple类型的邻接表[coords, values, shape], 可以采用sp.coo_matrix生成
        :param nodes_num: 点数量
        :param dropout_rate: 丢弃率,防过拟合,默认0.5
        :param activation: 激活函数
        :param use_bias: 偏移,默认True
        :param kernel_initializer: 权值初始化方法
        :param bias_initializer: 偏移初始化方法
        :param kernel_regularizer: 权值正则化
        :param bias_regularizer: 偏移正则化
        :param activity_regularizer: 输出正则化
        :param kernel_constraint: 权值约束
        :param bias_constraint: 偏移约束
        :param coef_dropout: 互相关系数丢弃,默认0.0
        :param kwargs:
        """
        super(GraphAttentionLayer, self).__init__()
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.support = [tf.SparseTensor(indices=adj[0][0], values=adj[0][1], dense_shape=adj[0][2])]
        self.dropout_rate = dropout_rate
        self.coef_drop = coef_dropout
        self.nodes_num = nodes_num
        self.kernel = None
        self.mapping = None
        self.bias = None

    def build(self, input_shape):
        """
        只执行一次
        """
        self.kernel = self.add_weight(shape=(self.input_dim, self.output_dim),
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint,
                                      trainable=True)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.nodes_num, self.output_dim),
                                        initializer=self.kernel_initializer,
                                        regularizer=self.kernel_regularizer,
                                        constraint=self.kernel_constraint,
                                        trainable=True)
        print('[GAT LAYER]: GAT W & b built.')

    def call(self, inputs, training=True):
        # 完成输入到输出的映射关系
        # inputs = tf.nn.l2_normalize(inputs, 1)
        raw_shape = inputs.shape
        inputs = tf.reshape(inputs, shape=(1, raw_shape[0], raw_shape[1]))  # (1, nodes_num, input_dim)
        mapped_inputs = keras.layers.Conv1D(self.output_dim, 1, use_bias=False)(inputs)  # (1, nodes_num, output_dim)
        # mapped_inputs = tf.nn.l2_normalize(mapped_inputs)

        sa_1 = keras.layers.Conv1D(1, 1)(mapped_inputs)  # (1, nodes_num, 1)
        sa_2 = keras.layers.Conv1D(1, 1)(mapped_inputs)  # (1, nodes_num, 1)

        con_sa_1 = tf.reshape(sa_1, shape=(raw_shape[0], 1))  # (nodes_num, 1)
        con_sa_2 = tf.reshape(sa_2, shape=(raw_shape[0], 1))  # (nodes_num, 1)

        con_sa_1 = tf.cast(self.support[0], dtype=tf.float32) * con_sa_1  # (nodes_num, nodes_num) W_hi
        con_sa_2 = tf.cast(self.support[0], dtype=tf.float32) * tf.transpose(con_sa_2, [1, 0])  # (nodes_num, nodes_num) W_hj

        weights = tf.sparse.add(con_sa_1, con_sa_2)  # concatenation
        weights_act = tf.SparseTensor(indices=weights.indices,
                                      values=tf.nn.leaky_relu(weights.values),
                                      dense_shape=weights.dense_shape)  # 注意力互相关系数
        attention = tf.sparse.softmax(weights_act)  # 输出注意力机制
        inputs = tf.reshape(inputs, shape=raw_shape)
        if self.coef_drop > 0.0:
            attention = tf.SparseTensor(indices=attention.indices,
                                        values=tf.nn.dropout(attention.values, self.coef_dropout),
                                        dense_shape=attention.dense_shape)
        if training and self.dropout_rate > 0.0:
            inputs = tf.nn.dropout(inputs, self.dropout_rate)
        if not training:
            print("[GAT LAYER]: GAT not training now.")

        attention = tf.sparse.reshape(attention, shape=[self.nodes_num, self.nodes_num])
        value = tf.matmul(inputs, self.kernel)
        value = tf.sparse.sparse_dense_matmul(attention, value)

        if self.use_bias:
            ret = tf.add(value, self.bias)
        else:
            ret = tf.reshape(value, (raw_shape[0], self.output_dim))
        return self.activation(ret)

参考

https://blog.csdn.net/weixin_36474809/article/details/89401552

https://github.com/PetarV-/GAT

更多内容访问 omegaxyz.com
网站所有代码采用Apache 2.0授权
网站文章采用知识共享许可协议BY-NC-SA4.0授权
© 2020 • OmegaXYZ-版权所有 转载请注明出处

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

图注意力网络(GAT) TensorFlow解析 的相关文章

随机推荐

  • Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.

    Graphviz的可执行文件 http www graphviz org Download windows PHP 参考 xff1a http blog csdn net u014749291 article details 5489108
  • 计算机保研-中科院计算所霸面(笔试面试)

    基本情况 xff1a 学校 xff1a 末流211 排名 xff1a 1 70 绩点 xff1a 4 33 5 0 竞赛 xff1a 无ACM xff0c 有某水赛国奖 xff08 中国人工智能学会主办 xff09 科研 xff1a 一篇水
  • 计算机保研-中科大计算机

    Abstract 2019年中科大计算机夏令营比往年增加了不少难度 xff0c 统一增加了机试环节 xff0c 面试难度提高 xff08 陈恩红实验室和李向阳实验室向来包含机试 xff09 xff0c 最终录取率在60 左右 xff08 往
  • NSGA-II资料合集

    关于NSGA II的一些资料 NSGA II中文翻译 MATLAB代码 NSGA II的解释 简介 关于演化计算 生物系统中 xff0c 进化被认为是一种成功的自适应方法 xff0c 具有很好的健壮性 基本思想 xff1a 达尔文进化论是一
  • 简单区块链Python实现

    什么是区块链 区块链是一种数据结构 xff0c 也是一个分布式数据库 从技术上来看 xff1a 区块是一种记录交易的数据结构 xff0c 反映了一笔交易的资金流向 系统中已经达成的交易的区块连接在一起形成了一条主链 xff0c 所有参与计算
  • 复旦大学计算机保研夏令营

    Abstract 复旦的夏令营 xff1a 自由而无用 xff0c 一期招了200人入营 xff0c 不提供住宿 xff08 导致我租了个旅馆每天要骑单车来学校 xff0c 不过沿途环境不错 xff0c 有很多吃的地方 xff09 xff0
  • 计算机保研夏令营预推免

    夏令营与预推免个人情况 学校 xff1a 末流211 xff08 安徽大学 xff09 排名 xff1a 1 70绩点 xff1a 4 33 5 0竞赛 xff1a 无ACM xff0c 有某水赛国奖 xff08 中国人工智能学会主办 xf
  • 知识图谱嵌入的应用场景

    In KG应用 xff08 在 KG 范围内的应用 xff09 链接预测 xff08 Link prediction xff09 链接预测任务有时也称为实体预测或实体排序 xff0c 用来预测两个实体之间是否有特定的关系 即已知头实体h和关
  • Neo4j数据导入与可视化

    本文共1262个字 xff0c 预计阅读时间需要5分钟 简介 Neo4j是一个高性能的NoSQL图形数据库 xff0c 它将结构化数据存储在网络上而不是表中 它是一个嵌入式的 基于磁盘的 具备完全的事务特性的Java持久化引擎 xff0c
  • 用户身份链接方法——DeepLink

    论文 xff1a DeepLink A Deep Learning Approach for User Identity Linkage UIL xff08 User Identity Linkage xff09 xff1a 用户身份链接
  • 可视化图布局算法简介

    Fruchterman Reingold FR FR算法将所有的结点看做是电子 xff0c 每个结点收到两个力的作用 xff1a 其他结点的库伦力 xff08 斥力 xff09 f a d
  • Windows无法连接到打印机怎么办?快收藏这些正确做法!

    案例 xff1a Windows无法连接到打印机怎么办 xff1f 朋友们朋友们 xff0c 最近为了备考国考 xff0c 我特地买了个打印机回来打印资料 xff0c 但是我的Windows无法连接到打印机 xff0c 这是为什么呢 xff
  • Python爬虫Scrapy入门

    Scrapy组成 Scrapy是Python开发的一个快速 高层次的屏幕抓取和web抓取框架 xff0c 用于抓取web站点并从页面中提取结构化的数据 引擎 xff08 Scrapy xff09 xff1a 用来处理整个系统的数据流 xff
  • Mac下终端pip与pip3配置(软链接)

    缘起 今日Mac上的Python环境绝对是个asshole 系统自带一个Python2 7我官网下载一个3 6homebrew悄悄下了个3 xanaconda自带了一个3 x前天更新了一下Xcode命令行工具 xff0c 竟然给我偷偷下了个
  • 推荐系统摘要

    作为一个推荐系统的门外汉 xff0c 或者说是用户 xff0c 我觉得推荐系统有以下几个特性 推荐系统的真实目的并不是做到让用户满意 xff0c 而是提高销售能力 xff0c 业务水平和收益 一个好的推荐系统并不是推荐用户最喜爱 想要的东西
  • 数据分析岗位面试必备

    业务逻辑 数据分析遵循一定的流程 xff0c 不仅可以保证数据分析每一个阶段的工作内容有章可循 xff0c 而且还可以让分析最终的结果更加准确 xff0c 更加有说服力 一般情况下 xff0c 数据分析分为以下几个步骤 xff1a 业务理解
  • 基于LDA的文本主题聚类Python实现

    LDA简介 LDA xff08 Latent Dirichlet Allocation xff09 是一种文档主题生成模型 xff0c 也称为一个三层贝叶斯概率模型 xff0c 包含词 主题和文档三层结构 所谓生成模型 xff0c 就是说
  • Neo4j-import导入CSV的数据

    本文共1215个字 xff0c 预计阅读时间需要4分钟 最近有个上亿个关系 节点的数据需要导入到Neo4j xff0c 有以下几个工具可以导入 xff1a Cypher CREATE 语句 xff0c 为每一条数据写一个CREATECyph
  • Ajax与jQuery异步加载数据

    本文共1096个字 xff0c 预计阅读时间需要4分钟 简介 一次性从服务器数据库中读取数据并传送到前端页面上是不现实的 xff0c 一方面会加重服务器的压力 xff0c 另一方面客户的带宽资源也会被占用 Ajax刚好可以解决数据异步加载的
  • 图注意力网络(GAT) TensorFlow解析

    论文 图注意力网络来自 Graph Attention Networks xff0c ICLR 2018 https arxiv org abs 1710 10903 注意力机制 代码 span class token keyword im