模型目录中的模型对象

模型目录中的 GDS 模型目录 模型以 Model 对象的形式在 Python 客户端中表示,类似于存在 图对象 的方式。Model 对象通常是通过训练 管道GraphSAGE 模型 来构建的,在这种情况下,会返回对训练好的模型的引用,该引用以 Model 对象的形式表示。

创建后,Model 对象可以作为参数传递给 Python 客户端中的方法,例如 模型目录操作。此外,Model 对象还具有便利方法,允许在不显式涉及模型目录的情况下检查所表示的模型。

在下面的示例中,我们假设我们有一个名为 gds 的已实例化的 GraphDataScience 对象。有关此内容的更多信息,请阅读 入门

1. 构建模型对象

构建模型对象的主要方法是通过训练模型。有两种类型的模型:管道模型和 GraphSAGE 模型。为了训练管道模型,必须首先创建和配置管道。有关如何操作管道的更多信息,请阅读 机器学习管道,其中包括使用管道模型的示例。在本节中,我们将举例说明创建和使用 GraphSAGE 模型对象。

首先,我们介绍一个小型的道路网络图

gds.run_cypher(
  """
  CREATE
    (a:City {name: "New York City", settled: 1624}),
    (b:City {name: "Philadelphia", settled: 1682}),
    (c:City:Capital {name: "Washington D.C.", settled: 1790}),
    (d:City {name: "Baltimore", settled: 1729}),
    (e:City {name: "Atlantic City", settled: 1854}),
    (f:City {name: "Boston", settled: 1822}),

    (a)-[:ROAD {cost: 50}]->(b),
    (a)-[:ROAD {cost: 50}]->(c),
    (a)-[:ROAD {cost: 100}]->(d),
    (b)-[:ROAD {cost: 40}]->(d),
    (c)-[:ROAD {cost: 40}]->(d),
    (c)-[:ROAD {cost: 80}]->(e),
    (d)-[:ROAD {cost: 30}]->(e),
    (d)-[:ROAD {cost: 80}]->(f),
    (e)-[:ROAD {cost: 40}]->(f);
  """
)
G, project_result = gds.graph.project(
    "road_graph",
    {"City": {"properties": ["settled"]}},
    {"ROAD": {"properties": ["cost"]}}
)

assert G.relationship_count() == 9

现在我们可以使用图 G 来训练 GraphSage 模型。

model, train_result = gds.beta.graphSage.train(G, modelName="city-representation", featureProperties=["settled"], randomSeed=42)

assert train_result["modelInfo"]["metrics"]["ranEpochs"] == 1

其中 model 是模型对象,res 是一个包含来自底层过程调用的元数据的 pandas Series

类似地,我们还可以从训练 机器学习管道 获取模型对象。

要获取代表已训练好并存在于模型目录中的模型的模型对象,可以调用客户端端独有的 get 方法,并将模型名称作为参数传递给它

model = gds.model.get("city-representation")

assert model.name() == "city-representation"

get 方法不使用任何层级前缀,因为它不与任何层级相关联。它只存在于客户端中,没有相应的 Cypher 过程。

2. 检查模型对象

所有模型对象上都有一些便利方法,让我们可以提取有关所表示模型的信息。

表 1. 模型对象方法
名称 参数 返回类型 描述

name

-

str

模型在模型目录中显示的名称。

type

-

str

模型的类型,例如“graphSage”。

train_config

-

Series

用于训练模型的配置。

graph_schema

-

Series

训练模型时所用图的模式。

loaded

-

bool

如果模型已 加载 到内存中的模型目录中,则为 True,否则为 False

stored

-

bool

如果模型已 存储 到磁盘上,则为 True,否则为 False

creation_time

-

neo4j.time.Datetime

模型创建时间。

shared

-

bool

如果模型在用户之间 共享,则为 True,否则为 False

exists

-

bool

如果模型存在于 GDS 模型目录中,则为 True,否则为 False

drop

failIfMissing: Optional[bool]

Series

从 GDS 模型目录中删除模型

例如,要获取我们上面创建的模型对象 model 的训练配置,可以执行以下操作

train_config = model.train_config()

assert train_config["concurrency"] == 4

3. 使用模型对象

使用模型对象的主要方法是用于预测。有关如何对 GraphSAGE 执行此操作的说明,请参见 以下描述,以及有关管道的 机器学习管道 页面。

此外,模型对象可以作为 GDS 模型目录操作 的输入使用。例如,假设我们有上面创建的模型对象 model,我们可以执行以下操作

# Store the model on disk (GDS Enterprise Edition)
_ = gds.alpha.model.store(model)

gds.beta.model.drop(model)  # same as model.drop()

# Load the model again for further use
gds.alpha.model.load(model.name())

3.1. GraphSAGE

如上面 构建模型对象 中所举例说明的,使用 Python 客户端训练 GraphSAGE 模型类似于 其 Cypher 对等项

训练完成后,除了 上面的方法 之外,GraphSAGE 模型对象还将具有以下方法。

表 2. GraphSAGE 模型方法
名称 参数 返回类型 描述

predict_mutate

G: Graph,
config: **kwargs

Series

预测输入图的节点的嵌入,并使用预测结果对图进行变异.

predict_stream

G: Graph,
config: **kwargs

DataFrame

预测输入图的节点的嵌入,并流式传输结果.

predict_write

G: Graph,
config: **kwargs

Series

预测输入图的节点的嵌入,并将结果写回数据库.

metrics

-

Series

返回训练时计算的指标的值。

因此,给定我们上面训练的 GraphSAGE 模型 model,我们可以执行以下操作

# Make sure our training actually converged
metrics = model.metrics()
assert metrics["didConverge"]

# Predict on `G` and write embedding node properties back to the database
predict_result = model.predict_write(G, writeProperty="embedding")
assert predict_result["nodePropertiesWritten"] == G.node_count()