GraphSAGE 节点分类训练

GraphSAGE 是一种图神经网络 (GNN) 架构,可用作监督算法来预测图中节点的类别标签。本节提供了如何使用 GraphSAGE 端点通过 Neo4j Snowflake 图分析来训练节点分类模型的说明。

语法

本节介绍用于执行 GraphSAGE 节点分类训练算法的语法。

运行 GraphSAGE 节点分类训练。
CALL graph.gs_nc_train(
  'CPU_X64_XS',                    (1)
  {
    ['defaultTablePrefix': '...',] (2)
    'project': {...},              (3)
    'compute': {...},              (4)
  }
);
1 计算池选择器。
2 表引用可选前缀。
3 项目配置。
4 计算配置。
表 1. 参数
名称 类型 默认值 可选 描述

computePoolSelector

字符串

不适用

用于运行 GraphSAGE 节点分类训练作业的计算池选择器。

configuration

映射

{}

图项目、算法计算和结果回写配置。

对于此算法,我们强烈建议使用 GPU 计算池,除非数据集非常小且模型较浅。

配置映射包含以下三个条目。

有关以下项目配置的更多详细信息,请参阅项目文档
表 2. 项目配置
名称 类型

nodeTables

节点表列表。

relationshipTables

关系类型到关系表的映射。

请注意,为了使 GraphSAGE 能够正确传播节点嵌入的更新,每种类型的节点都必须是至少一种关系类型的目标。`orientation` 参数可用于为仅是关系源的节点类型添加反向关系(使用“REVERSE”或“UNDIRECTED”方向)。

表 3. 计算配置
名称 类型 默认值 可选 描述

target_label

字符串

不适用

要训练以进行预测的节点标签(即类型)

target_property

字符串

不适用

要训练以进行预测的节点属性,由指定“target_label”的输入节点表中的列表示

modelname

字符串

不适用

要训练的模型名称(必须唯一)

numEpochs

整数

不适用

训练模型的 epoch 数量

numSamples

整数列表

不适用

每层采样的邻居数量。请注意,这也决定了层数

hiddenChannels

整数

256

模型层输出的节点嵌入维度

activation

字符串

“relu”

要使用的激活函数。有效值为“relu”和“sigmoid”

aggregator

字符串

“mean”

要使用的邻域嵌入聚合器。有效值为“mean”和“max”

learningRate

浮点数

0.001

优化器的学习率

dropout

浮点数

0.1

每层的 dropout 概率。必须是 >= 0.0 且 < 1.0 的值

layerNormalization

布尔值

true

是否在模型层之间应用层归一化

epochsPerCheckpoint

整数

max(numEpochs / 10, 1)

保存模型检查点之间的 epoch 数量

randomSeed

整数

一个随机整数

用于为计算的所有随机性播种的数字

split_ratios

映射

{"TRAIN": 0.6, "TEST": 0.2, "VALID": 0.2}

将输入图的目标节点拆分为训练集、测试集和验证集的比率映射。键必须是“TRAIN”、“TEST”和“VALID”。值的总和必须为 1.0

epochs_per_val

整数

0

在验证集上评估模型之间的 epoch 数量。如果设置为 0,则模型不会在验证集上进行评估

train_batch_size

整数

自动推断

每个批次要训练的目标节点数量。如果未提供,算法将自动推断在可用内存限制内的最大允许批次大小

eval_batch_size

整数

训练批次大小

用于评估的批次大小

class_weights

布尔值或映射

false

是否使用类别权重来平衡训练数据。如果设置为 true,则将根据训练集中目标标签的分布计算类别权重。如果设置为映射,则该映射必须包含每个目标类别标签的类别权重

示例

在此示例中,我们将使用包含演员、导演、电影和类型(流派)的 IMDB 数据集。所有这些都关联有关键词,我们将把这些关键词用作节点的特征。它们通过关系连接,其中演员出演电影,导演执导电影。目标是预测电影的类型。

我们有一个名为 `imdb` 的数据库,其中包含以下表:

  • `actor` 表,包含列 `nodeid` 和 `plot_keywords`

  • `movie` 表,包含列 `nodeid`、`plot_keywords` 和 `genre`

  • `director` 表,包含列 `nodeid` 和 `plot_keywords`

  • `acted_in` 表,包含列 `sourcenodeid` 和 `targetnodeid`,分别表示 `actor` 和 `movie` 节点 ID

  • `directed_in` 表,包含列 `sourcenodeid` 和 `targetnodeid`,分别表示 `director` 和 `movie` 节点 ID

`plot_keywords` 列包含与节点关联的关键词,编码为浮点数向量。`genre` 列包含电影节点的目标类别标签,这是我们想要预测的内容。

您可以按照 GitHub 上的说明将此数据集上传到您的 Snowflake 账户:neo4j-product-examples/snowflake-graph-analytics

训练查询

在以下查询中,我们对数据集上的 GraphSAGE 模型进行节点分类训练。我们训练 10 个 epoch,使用两个隐藏层,并使用类别权重来平衡类别分布。

要运行此查询,需要为应用程序、您的消费者角色和您的环境设置必要的权限。有关更多信息,请参阅入门页面。

我们还假设应用程序名称是默认的 Neo4j_Graph_Analytics。如果您在安装过程中选择了不同的应用程序名称,请将其替换为该名称。

CALL Neo4j_Graph_Analytics.graph.gs_nc_train('GPU_NV_S', {
    'defaultTablePrefix': 'imdb.gml',
    'project': {
        'nodeTables': ['actor', 'director', 'movie'],
        'relationshipTables': {
            'acted_in': {
                'sourceTable': 'actor',
                'targetTable': 'movie',
                'orientation': 'UNDIRECTED'
            },
            'directed_in': {
                'sourceTable': 'director',
                'targetTable': 'movie',
                'orientation': 'UNDIRECTED'
            }
        }
    },
    'compute': {
        'modelname': 'nc-imdb',
        'numEpochs': 10,
        'numSamples': [20, 20],
        'targetLabel': 'movie',
        'targetProperty': 'genre',
        'classWeights': true
    }
});

上述查询应产生与以下结果相似的内容。数值结果可能有所不同。

作业 ID

作业开始时间

作业结束时间

作业结果

job_63b8083fc8ef463ab38cd95d2ac345ea

2025-04-29 12:06:28.791

2025-04-29 12:07:10.318

{ "metrics": { "test_acc": 0.7441860437393188, "test_f1_macro": 0.7236689925193787, "test_f1_micro": 0.7441860437393188, "train_acc": 0.9911160469055176, "train_f1_macro": 0.9900508522987366, "train_f1_micro": 0.9911160469055176 } }

© . All rights reserved.