模型目录中的模型对象

GDS 模型目录中的模型在 Python 客户端中表示为 Model 对象,类似于 图对象Model 对象通常通过训练 流水线GraphSAGE 模型 来构建,在这种情况下,会返回一个 Model 形式的训练模型引用。

创建后,Model 对象可以作为参数传递给 Python 客户端中的方法,例如 模型目录操作。此外,Model 对象还具有便捷方法,允许在不明确涉及模型目录的情况下检查所表示的模型。

在下面的示例中,我们假设我们有一个名为 gdsGraphDataScience 实例化对象。有关此内容的更多信息,请参阅 入门

1. 构建模型对象

构建模型对象的主要方式是通过训练模型。模型分为两种类型:流水线模型和 GraphSAGE 模型。为了训练流水线模型,必须首先创建和配置流水线。有关如何操作流水线的更多信息,包括使用流水线模型的示例,请参阅 机器学习流水线。在本节中,我们将举例说明如何创建和使用 GraphSAGE 模型对象。

首先,我们引入一个小型路网图

gds.run_cypher(
  """
  CREATE
    (a:City {name: "New York City", settled: 1624}),
    (b:City {name: "Philadelphia", settled: 1682}),
    (c:City:Capital {name: "Washington D.C.", settled: 1790}),
    (d:City {name: "Baltimore", settled: 1729}),
    (e:City {name: "Atlantic City", settled: 1854}),
    (f:City {name: "Boston", settled: 1822}),

    (a)-[:ROAD {cost: 50}]->(b),
    (a)-[:ROAD {cost: 50}]->(c),
    (a)-[:ROAD {cost: 100}]->(d),
    (b)-[:ROAD {cost: 40}]->(d),
    (c)-[:ROAD {cost: 40}]->(d),
    (c)-[:ROAD {cost: 80}]->(e),
    (d)-[:ROAD {cost: 30}]->(e),
    (d)-[:ROAD {cost: 80}]->(f),
    (e)-[:ROAD {cost: 40}]->(f);
  """
)
G, project_result = gds.graph.project(
    "road_graph",
    {"City": {"properties": ["settled"]}},
    {"ROAD": {"properties": ["cost"]}}
)

assert G.relationship_count() == 9

现在我们可以使用图 G 来训练 GraphSage 模型。

model, train_result = gds.beta.graphSage.train(G, modelName="city-representation", featureProperties=["settled"], randomSeed=42)

assert train_result["modelInfo"]["metrics"]["ranEpochs"] == 1

其中 model 是模型对象,res 是一个 pandas Series,包含来自底层过程调用的元数据。

类似地,我们也可以从训练 机器学习流水线 中获取模型对象。

要获取表示已训练并存在于模型目录中的模型对象,可以调用仅限客户端的 get 方法并为其传递一个名称

model = gds.model.get("city-representation")

assert model.name() == "city-representation"

get 方法不使用任何层级前缀,因为它不与任何层级关联。它仅存在于客户端中,没有相应的 Cypher 过程。

2. 检查模型对象

所有模型对象上都有便捷方法,可以让我们提取有关所表示模型的信息。

表 1. 模型对象方法
名称 参数 返回类型 描述

name

-

str

模型在模型目录中显示的名称。

type

-

str

模型的类型,例如“graphSage”。

train_config

-

Series

用于训练模型的配置。

graph_schema

-

Series

训练模型所用的图的模式。

loaded

-

bool

如果模型已 加载 到内存模型目录中,则为 True,否则为 False

stored

-

bool

如果模型已 存储 到磁盘上,则为 True,否则为 False

creation_time

-

neo4j.time.Datetime

模型创建时间。

shared

-

bool

如果模型在用户之间 共享,则为 True,否则为 False

exists

-

bool

如果模型存在于 GDS 模型目录中,则为 True,否则为 False

drop

failIfMissing: Optional[bool]

Series

从 GDS 模型目录中移除模型

例如,要获取上面创建的模型对象 model 的训练配置,我们可以执行以下操作:

train_config = model.train_config()

assert train_config["concurrency"] == 4

3. 使用模型对象

使用模型对象的主要方式是进行预测。如何对 GraphSAGE 执行此操作 如下所述,以及在 机器学习流水线 页面上针对流水线进行了说明。

此外,模型对象可以用作 GDS 模型目录操作 的输入。例如,假设我们有上面创建的模型对象 model,我们可以

# Store the model on disk (GDS Enterprise Edition)
_ = gds.model.store(model)

gds.model.drop(model)  # same as model.drop()

# Load the model again for further use
gds.model.load(model.name())

3.1. GraphSAGE

如上文 构建模型对象 中所示,使用 Python 客户端训练 GraphSAGE 模型类似于 其 Cypher 对应项

训练完成后,除了 上述方法 之外,GraphSAGE 模型对象还将具有以下方法。

表 2. GraphSAGE 模型方法
名称 参数 返回类型 描述

predict_mutate

G: 图,
config: **kwargs

Series

预测输入图节点的嵌入并根据预测修改图.

predict_stream

G: 图,
config: **kwargs

DataFrame

预测输入图节点的嵌入并流式传输结果.

predict_write

G: 图,
config: **kwargs

Series

预测输入图节点的嵌入并将结果写回数据库.

metrics

-

Series

返回训练时计算的指标值。

因此,给定我们上面训练的 GraphSAGE 模型 model,我们可以执行以下操作:

# Make sure our training actually converged
metrics = model.metrics()
assert metrics["didConverge"]

# Predict on `G` and write embedding node properties back to the database
predict_result = model.predict_write(G, writeProperty="embedding")
assert predict_result["nodePropertiesWritten"] == G.node_count()