关系嵌入模型

一种不太罕见的情况可能是,用户在图数据科学 (GDS) 库之外训练了知识图谱嵌入 (KGE) 模型,并将模型训练的输出存储在 Neo4j 数据库中。对于此类情况,GDS 支持使用此类 KGE 模型输出和 KGE 评分函数来推断 GDS 图投影的新关系。目前支持的评分函数是 TransE 和 DistMult。

下面我们将逐步介绍如何使用这些功能。首先查看方法及其签名,然后通过一个小玩具图上的端到端示例。

在下面的示例中,我们假设我们有一个名为 gds 的已实例化的 GraphDataScience 对象。在 入门 中了解更多相关信息。

1. 创建关系嵌入模型

使用预训练的 KGE 模型预测 GDS 中的新关系的工作流程的第一部分是创建关系嵌入模型。

有两种方法可以做到这一点,每种方法对应一个受支持的 KGE 评分函数

  • gds.model.transe.create 用于使用 TransE 评分函数创建模型,以及

  • gds.model.distmult.create 用于使用 DistMult 评分函数创建模型。

这两种方法都返回一个 SimpleRelEmbeddingModel,我们很快就会了解其用法。它们也采用相同的参数

表 1. 基于 KGE 的关系模型创建参数
名称 类型

G

表示模型训练所依据的图的对象

node_embedding_property

str

存储 KGE 模型嵌入的节点属性的名称

relationship_type_embeddings

dict[str, list[float]]

关系类型名称到 KGE 模型的关系类型嵌入的映射

2. 使用关系嵌入模型进行预测

SimpleRelEmbeddingModel 表示基于 KGE 模型的关系嵌入模型。它有三种方法可以预测新的关系。推断新嵌入的计算是相同的,但之后如何处理新关系有所不同。

此类具有三种方法

  • predict_stream 用于流式传输预测的关系,

  • predict_mutate 用于将关系添加到投影图中,

  • predict_write 用于将关系写回 Neo4j 数据库。

由于这些方法中预测部分的计算相同,因此这些方法共享一组参数

表 2. 共享的关系嵌入模型预测参数
名称 类型

source_node_filter

Union[str, int, list[int]]

要考虑的源节点的规范。节点标签、节点 ID 或节点 ID 列表

target_node_filter

Union[str, int, list[int]]

要考虑的源节点的规范。节点标签、节点 ID 或节点 ID 列表

relationship_type

str

将在计算中使用其嵌入的关系类型的名称

top_k

int

为每个源节点生成多少个关系。每个源节点将保留得分最高的 top_k 个目标节点

general_config

**dict[str, Any]

作为可选关键字参数的通用 GDS 算法配置

特别是,此算法支持作为关键字参数的通用算法配置参数为 concurrencyjobIdlogProgress。您可以在 GDS 手册的 此处 了解更多信息。

现在让我们概述这些预测方法之间的差异。

2.1. 流式传输预测的关系

predict_stream 方法返回一个 pandas.DataFrame,其中包含三列:sourceNodeIdtargetNodeIdscore。这些分别指的是源节点 ID、目标节点 ID 以及在节点对和关系类型上运行 KGE 模型评分函数得出的分数。

此方法除了 上面 概述的参数之外,没有其他额外参数。

2.2. 使用预测的关系修改图投影

predict_mutate 方法通过 mutate_relationship_type 参数指定的新类型将预测的关系添加到图投影中。此类关系将具有一个属性(通过 mutateProperty 参数指定),表示在节点对和关系类型上运行 KGE 模型评分函数的输出。该方法返回一个包含计算元数据的 pandas.Series

除了 上面 概述的共享参数之外,此方法还按顺序采用两个位置参数(在 top_k 参数之后)

表 3. .predict_mutate 特定的输入参数
名称 类型

mutate_relationship_type

str

预测关系的新关系类型的名称

mutate_property

str

将存储模型预测分数的新关系上的属性的名称

表 4. .predict_mutate 返回的 pandas.Series 对象的字段
名称 类型

relationshipsWritten

int

创建的关系数

mutateMillis

int

向投影图添加属性所用的毫秒数

postProcessingMillis

int

计算百分位数所用的毫秒数

preProcessingMillis

int

预处理数据所用的毫秒数

computeMillis

int

运行预测算法所用的毫秒数

configuration

dict[str, Any]

用于运行算法的配置

2.3. 将预测的关系写回数据库

predict_write 方法通过 write_relationship_type 参数指定的新类型将预测的关系写回 Neo4j 数据库。此类关系将具有一个属性(通过 writeProperty 参数指定),表示在节点对和关系类型上运行 KGE 模型评分函数的输出。

除了 上面 概述的共享参数之外,此方法还按顺序采用两个位置参数(在 top_k 参数之后)

表 5. .predict_write 特定的输入参数
名称 类型

write_relationship_type

str

预测关系的新关系类型的名称

write_property

str

将存储模型预测分数的新关系上的属性的名称

该方法返回一个包含计算元数据的 pandas.Series

表 6. .predict_write 返回的 pandas.Series 对象的字段
名称 类型

relationshipsWritten

int

创建的关系数

writeMillis

int

将结果数据写回 Neo4j 数据库所用的毫秒数

preProcessingMillis

int

预处理数据所用的毫秒数

computeMillis

int

运行预测算法所用的毫秒数

configuration

dict[str, Any]

用于运行算法的配置

3. 检查关系嵌入模型

SimpleRelEmbeddingModel 类上有一些方法可以让我们检查它。它们都不需要任何输入,只是返回有关模型的信息。它们列在下面。

表 7. 用于检查的 SimpleRelEmbeddingModel 获取器方法
名称 返回类型 描述

scoring_function

str

返回模型正在使用的评分函数的名称

graph_name

str

返回模型所依据的图的名称

node_embedding_property

str

返回存储图中嵌入的节点属性的名称

relationship_type_embeddings

dict[str, list[float]]

返回模型的关系类型嵌入

4. 示例

在本节中,我们将举例说明如何创建和使用基于使用 TransE 评分函数训练的 KGE 模型的关系嵌入模型。部分内容将包含一个 Graph,它表示包含 KGE 模型嵌入的投影。

因此,我们首先介绍一个小型道路网络图,其中包含一些居民

gds.run_cypher(
  """
  CREATE
    (a:City {name: "New York City", settled: 1624, emb: [0.52173235, 0.85803989, 0.31678055]}),
    (b:City {name: "Philadelphia", settled: 1682, emb: [0.61455845, 0.79957553, 0.83513986]}),
    (c:City:Capital {name: "Washington D.C.", settled: 1790, emb: [0.54354943, 0.64039515, 0.23094848]}),
    (d:City {name: "Baltimore", settled: 1729, emb: [0.67689553, 0.28851121, 0.43250516]}),
    (e:City {name: "Atlantic City", settled: 1854, emb: [0.79804478, 0.81980933, 0.9322812]}),
    (f:City {name: "Boston", settled: 1822, emb: [0.15583946, 0.16060805, 0.52078528]}),

    (g:Person {name: "Brian", emb: [0.4142066 , 0.18411476, 0.68245374]}),
    (h:Person {name: "Olga", emb: [0.61230904, 0.7735076 , 0.09668418]}),
    (i:Person {name: "Jacob", emb: [0.87470625, 0.63589938, 0.33536311]}),

    (a)-[:ROAD {cost: 50}]->(b),
    (a)-[:ROAD {cost: 50}]->(c),
    (a)-[:ROAD {cost: 100}]->(d),
    (b)-[:ROAD {cost: 40}]->(d),
    (c)-[:ROAD {cost: 40}]->(d),
    (c)-[:ROAD {cost: 80}]->(e),
    (d)-[:ROAD {cost: 30}]->(e),
    (d)-[:ROAD {cost: 80}]->(f),
    (e)-[:ROAD {cost: 40}]->(f),

    (g)-[:LIVES_IN]->(a),
    (h)-[:LIVES_IN]->(f),
    (i)-[:LIVES_IN]->(e);
  """
)
G, project_result = gds.graph.project(
    graph_name="road_graph",
    node_spec={"City": {"properties": ["emb"]}, "Person": {"properties": ["emb"]}},
    relationship_spec=["ROAD", "LIVES_IN"]
)

# Sanity check
assert G.relationship_count() == 12

此处的 "emb" 节点属性包含我们将用于计算中推断新关系的 TransE 节点嵌入。

4.1. 创建和检查我们的模型

使用我们的图 G 和我们预先计算的关系类型嵌入,我们现在可以构建一个 TransE 关系嵌入模型。

transe_model = gds.model.transe.create(
    G,
    node_embedding_property="emb",
    relationship_type_embeddings={
        "ROAD": [0.88355126, 0.15116676, 0.24225456],
        "LIVES_IN": [0.94185368, 0.60460752, 0.92028837]
    }
)

# Sanity check
assert transe_model.scoring_function() == "transe"

创建模型后,我们可以开始预测图的新关系。

4.2. 进行预测

让我们让我们的模型预测我们感兴趣的三个居民将来可能在哪里移动,并使用这些新关系修改由 G 表示的 GDS 投影。

result = transe_model.predict_mutate(
    source_node_filter="Person",
    target_node_filter="City",
    relationship_type="LIVES_IN",
    top_k=1,
    mutate_relationship_type="MIGHT_MOVE",
    mutate_property="likeliness_score"
)

# Let us make sure the number of new relationships makes sense
assert result["relationshipsWritten"] == 3
assert G.relationship_count() == 12 + 3

通过使用 TransE 嵌入和 GDS 的关系嵌入模型功能,我们能够推断出我们感兴趣的居民未来可能迁移到的位置。我们创建的新 "MIGHT_MOVE" 关系现在是 GDS 图投影(由 G 表示)的一部分。