GraphSAGE 节点分类训练
GraphSAGE 是一种图神经网络 (GNN) 架构,可用作监督算法来预测图中节点的类别标签。本节提供了如何使用 GraphSAGE 端点通过 Neo4j Snowflake 图分析来训练节点分类模型的说明。
语法
本节介绍用于执行 GraphSAGE 节点分类训练算法的语法。
CALL graph.gs_nc_train(
'CPU_X64_XS', (1)
{
['defaultTablePrefix': '...',] (2)
'project': {...}, (3)
'compute': {...}, (4)
}
);
1 | 计算池选择器。 |
2 | 表引用可选前缀。 |
3 | 项目配置。 |
4 | 计算配置。 |
名称 | 类型 | 默认值 | 可选 | 描述 |
---|---|---|---|---|
computePoolSelector |
字符串 |
|
否 |
用于运行 GraphSAGE 节点分类训练作业的计算池选择器。 |
configuration |
映射 |
|
否 |
图项目、算法计算和结果回写配置。 |
对于此算法,我们强烈建议使用 GPU 计算池,除非数据集非常小且模型较浅。
配置映射包含以下三个条目。
有关以下项目配置的更多详细信息,请参阅项目文档。 |
名称 | 类型 |
---|---|
nodeTables |
节点表列表。 |
relationshipTables |
关系类型到关系表的映射。 |
请注意,为了使 GraphSAGE 能够正确传播节点嵌入的更新,每种类型的节点都必须是至少一种关系类型的目标。`orientation` 参数可用于为仅是关系源的节点类型添加反向关系(使用“REVERSE”或“UNDIRECTED”方向)。
名称 | 类型 | 默认值 | 可选 | 描述 |
---|---|---|---|---|
target_label |
字符串 |
|
否 |
要训练以进行预测的节点标签(即类型) |
target_property |
字符串 |
|
否 |
要训练以进行预测的节点属性,由指定“target_label”的输入节点表中的列表示 |
modelname |
字符串 |
|
否 |
要训练的模型名称(必须唯一) |
numEpochs |
整数 |
|
否 |
训练模型的 epoch 数量 |
numSamples |
整数列表 |
|
否 |
每层采样的邻居数量。请注意,这也决定了层数 |
hiddenChannels |
整数 |
|
是 |
模型层输出的节点嵌入维度 |
activation |
字符串 |
|
是 |
要使用的激活函数。有效值为“relu”和“sigmoid” |
aggregator |
字符串 |
|
是 |
要使用的邻域嵌入聚合器。有效值为“mean”和“max” |
learningRate |
浮点数 |
|
是 |
优化器的学习率 |
dropout |
浮点数 |
|
是 |
每层的 dropout 概率。必须是 >= 0.0 且 < 1.0 的值 |
layerNormalization |
布尔值 |
|
是 |
是否在模型层之间应用层归一化 |
epochsPerCheckpoint |
整数 |
|
是 |
保存模型检查点之间的 epoch 数量 |
randomSeed |
整数 |
|
是 |
用于为计算的所有随机性播种的数字 |
split_ratios |
映射 |
|
是 |
将输入图的目标节点拆分为训练集、测试集和验证集的比率映射。键必须是“TRAIN”、“TEST”和“VALID”。值的总和必须为 1.0 |
epochs_per_val |
整数 |
|
是 |
在验证集上评估模型之间的 epoch 数量。如果设置为 0,则模型不会在验证集上进行评估 |
train_batch_size |
整数 |
|
是 |
每个批次要训练的目标节点数量。如果未提供,算法将自动推断在可用内存限制内的最大允许批次大小 |
eval_batch_size |
整数 |
|
是 |
用于评估的批次大小 |
class_weights |
布尔值或映射 |
|
是 |
是否使用类别权重来平衡训练数据。如果设置为 true,则将根据训练集中目标标签的分布计算类别权重。如果设置为映射,则该映射必须包含每个目标类别标签的类别权重 |
示例
在此示例中,我们将使用包含演员、导演、电影和类型(流派)的 IMDB 数据集。所有这些都关联有关键词,我们将把这些关键词用作节点的特征。它们通过关系连接,其中演员出演电影,导演执导电影。目标是预测电影的类型。
我们有一个名为 `imdb` 的数据库,其中包含以下表:
-
`actor` 表,包含列 `nodeid` 和 `plot_keywords`
-
`movie` 表,包含列 `nodeid`、`plot_keywords` 和 `genre`
-
`director` 表,包含列 `nodeid` 和 `plot_keywords`
-
`acted_in` 表,包含列 `sourcenodeid` 和 `targetnodeid`,分别表示 `actor` 和 `movie` 节点 ID
-
`directed_in` 表,包含列 `sourcenodeid` 和 `targetnodeid`,分别表示 `director` 和 `movie` 节点 ID
`plot_keywords` 列包含与节点关联的关键词,编码为浮点数向量。`genre` 列包含电影节点的目标类别标签,这是我们想要预测的内容。
您可以按照 GitHub 上的说明将此数据集上传到您的 Snowflake 账户:neo4j-product-examples/snowflake-graph-analytics。
训练查询
在以下查询中,我们对数据集上的 GraphSAGE 模型进行节点分类训练。我们训练 10 个 epoch,使用两个隐藏层,并使用类别权重来平衡类别分布。
要运行此查询,需要为应用程序、您的消费者角色和您的环境设置必要的权限。有关更多信息,请参阅入门页面。
我们还假设应用程序名称是默认的 Neo4j_Graph_Analytics。如果您在安装过程中选择了不同的应用程序名称,请将其替换为该名称。
CALL Neo4j_Graph_Analytics.graph.gs_nc_train('GPU_NV_S', {
'defaultTablePrefix': 'imdb.gml',
'project': {
'nodeTables': ['actor', 'director', 'movie'],
'relationshipTables': {
'acted_in': {
'sourceTable': 'actor',
'targetTable': 'movie',
'orientation': 'UNDIRECTED'
},
'directed_in': {
'sourceTable': 'director',
'targetTable': 'movie',
'orientation': 'UNDIRECTED'
}
}
},
'compute': {
'modelname': 'nc-imdb',
'numEpochs': 10,
'numSamples': [20, 20],
'targetLabel': 'movie',
'targetProperty': 'genre',
'classWeights': true
}
});
上述查询应产生与以下结果相似的内容。数值结果可能有所不同。
作业 ID |
作业开始时间 |
作业结束时间 |
作业结果 |
job_63b8083fc8ef463ab38cd95d2ac345ea |
2025-04-29 12:06:28.791 |
2025-04-29 12:07:10.318 |
{ "metrics": { "test_acc": 0.7441860437393188, "test_f1_macro": 0.7236689925193787, "test_f1_micro": 0.7441860437393188, "train_acc": 0.9911160469055176, "train_f1_macro": 0.9900508522987366, "train_f1_micro": 0.9911160469055176 } } |