模型目录中的模型对象
GDS 模型目录中的模型在 Python 客户端中表示为 Model
对象,类似于 图对象。Model
对象通常通过训练 流水线 或 GraphSAGE 模型 来构建,在这种情况下,会返回一个 Model
形式的训练模型引用。
创建后,Model
对象可以作为参数传递给 Python 客户端中的方法,例如 模型目录操作。此外,Model
对象还具有便捷方法,允许在不明确涉及模型目录的情况下检查所表示的模型。
在下面的示例中,我们假设我们有一个名为 gds
的 GraphDataScience
实例化对象。有关此内容的更多信息,请参阅 入门。
1. 构建模型对象
构建模型对象的主要方式是通过训练模型。模型分为两种类型:流水线模型和 GraphSAGE 模型。为了训练流水线模型,必须首先创建和配置流水线。有关如何操作流水线的更多信息,包括使用流水线模型的示例,请参阅 机器学习流水线。在本节中,我们将举例说明如何创建和使用 GraphSAGE 模型对象。
首先,我们引入一个小型路网图
gds.run_cypher(
"""
CREATE
(a:City {name: "New York City", settled: 1624}),
(b:City {name: "Philadelphia", settled: 1682}),
(c:City:Capital {name: "Washington D.C.", settled: 1790}),
(d:City {name: "Baltimore", settled: 1729}),
(e:City {name: "Atlantic City", settled: 1854}),
(f:City {name: "Boston", settled: 1822}),
(a)-[:ROAD {cost: 50}]->(b),
(a)-[:ROAD {cost: 50}]->(c),
(a)-[:ROAD {cost: 100}]->(d),
(b)-[:ROAD {cost: 40}]->(d),
(c)-[:ROAD {cost: 40}]->(d),
(c)-[:ROAD {cost: 80}]->(e),
(d)-[:ROAD {cost: 30}]->(e),
(d)-[:ROAD {cost: 80}]->(f),
(e)-[:ROAD {cost: 40}]->(f);
"""
)
G, project_result = gds.graph.project(
"road_graph",
{"City": {"properties": ["settled"]}},
{"ROAD": {"properties": ["cost"]}}
)
assert G.relationship_count() == 9
现在我们可以使用图 G
来训练 GraphSage 模型。
model, train_result = gds.beta.graphSage.train(G, modelName="city-representation", featureProperties=["settled"], randomSeed=42)
assert train_result["modelInfo"]["metrics"]["ranEpochs"] == 1
其中 model
是模型对象,res
是一个 pandas Series
,包含来自底层过程调用的元数据。
类似地,我们也可以从训练 机器学习流水线 中获取模型对象。
要获取表示已训练并存在于模型目录中的模型对象,可以调用仅限客户端的 get
方法并为其传递一个名称
model = gds.model.get("city-representation")
assert model.name() == "city-representation"
|
2. 检查模型对象
所有模型对象上都有便捷方法,可以让我们提取有关所表示模型的信息。
名称 | 参数 | 返回类型 | 描述 |
---|---|---|---|
|
|
|
模型在模型目录中显示的名称。 |
|
|
|
模型的类型,例如“graphSage”。 |
|
|
|
用于训练模型的配置。 |
|
|
|
训练模型所用的图的模式。 |
|
|
|
如果模型已 加载 到内存模型目录中,则为 |
|
|
|
如果模型已 存储 到磁盘上,则为 |
|
|
|
模型创建时间。 |
|
|
|
如果模型在用户之间 共享,则为 |
|
|
|
如果模型存在于 GDS 模型目录中,则为 |
|
|
|
例如,要获取上面创建的模型对象 model
的训练配置,我们可以执行以下操作:
train_config = model.train_config()
assert train_config["concurrency"] == 4
3. 使用模型对象
此外,模型对象可以用作 GDS 模型目录操作 的输入。例如,假设我们有上面创建的模型对象 model
,我们可以
# Store the model on disk (GDS Enterprise Edition)
_ = gds.model.store(model)
gds.model.drop(model) # same as model.drop()
# Load the model again for further use
gds.model.load(model.name())
3.1. GraphSAGE
如上文 构建模型对象 中所示,使用 Python 客户端训练 GraphSAGE 模型类似于 其 Cypher 对应项。
训练完成后,除了 上述方法 之外,GraphSAGE 模型对象还将具有以下方法。
名称 | 参数 | 返回类型 | 描述 |
---|---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
返回训练时计算的指标值。 |
因此,给定我们上面训练的 GraphSAGE 模型 model
,我们可以执行以下操作:
# Make sure our training actually converged
metrics = model.metrics()
assert metrics["didConverge"]
# Predict on `G` and write embedding node properties back to the database
predict_result = model.predict_write(G, writeProperty="embedding")
assert predict_result["nodePropertiesWritten"] == G.node_count()