模型目录中的模型对象
模型目录中的 GDS 模型目录 模型以 Model
对象的形式在 Python 客户端中表示,类似于存在 图对象 的方式。Model
对象通常是通过训练 管道 或 GraphSAGE 模型 来构建的,在这种情况下,会返回对训练好的模型的引用,该引用以 Model
对象的形式表示。
创建后,Model
对象可以作为参数传递给 Python 客户端中的方法,例如 模型目录操作。此外,Model
对象还具有便利方法,允许在不显式涉及模型目录的情况下检查所表示的模型。
在下面的示例中,我们假设我们有一个名为 gds
的已实例化的 GraphDataScience
对象。有关此内容的更多信息,请阅读 入门。
1. 构建模型对象
构建模型对象的主要方法是通过训练模型。有两种类型的模型:管道模型和 GraphSAGE 模型。为了训练管道模型,必须首先创建和配置管道。有关如何操作管道的更多信息,请阅读 机器学习管道,其中包括使用管道模型的示例。在本节中,我们将举例说明创建和使用 GraphSAGE 模型对象。
首先,我们介绍一个小型的道路网络图
gds.run_cypher(
"""
CREATE
(a:City {name: "New York City", settled: 1624}),
(b:City {name: "Philadelphia", settled: 1682}),
(c:City:Capital {name: "Washington D.C.", settled: 1790}),
(d:City {name: "Baltimore", settled: 1729}),
(e:City {name: "Atlantic City", settled: 1854}),
(f:City {name: "Boston", settled: 1822}),
(a)-[:ROAD {cost: 50}]->(b),
(a)-[:ROAD {cost: 50}]->(c),
(a)-[:ROAD {cost: 100}]->(d),
(b)-[:ROAD {cost: 40}]->(d),
(c)-[:ROAD {cost: 40}]->(d),
(c)-[:ROAD {cost: 80}]->(e),
(d)-[:ROAD {cost: 30}]->(e),
(d)-[:ROAD {cost: 80}]->(f),
(e)-[:ROAD {cost: 40}]->(f);
"""
)
G, project_result = gds.graph.project(
"road_graph",
{"City": {"properties": ["settled"]}},
{"ROAD": {"properties": ["cost"]}}
)
assert G.relationship_count() == 9
现在我们可以使用图 G
来训练 GraphSage 模型。
model, train_result = gds.beta.graphSage.train(G, modelName="city-representation", featureProperties=["settled"], randomSeed=42)
assert train_result["modelInfo"]["metrics"]["ranEpochs"] == 1
其中 model
是模型对象,res
是一个包含来自底层过程调用的元数据的 pandas Series
。
类似地,我们还可以从训练 机器学习管道 获取模型对象。
要获取代表已训练好并存在于模型目录中的模型的模型对象,可以调用客户端端独有的 get
方法,并将模型名称作为参数传递给它
model = gds.model.get("city-representation")
assert model.name() == "city-representation"
|
2. 检查模型对象
所有模型对象上都有一些便利方法,让我们可以提取有关所表示模型的信息。
名称 | 参数 | 返回类型 | 描述 |
---|---|---|---|
|
|
|
模型在模型目录中显示的名称。 |
|
|
|
模型的类型,例如“graphSage”。 |
|
|
|
用于训练模型的配置。 |
|
|
|
训练模型时所用图的模式。 |
|
|
|
如果模型已 加载 到内存中的模型目录中,则为 |
|
|
|
如果模型已 存储 到磁盘上,则为 |
|
|
|
模型创建时间。 |
|
|
|
如果模型在用户之间 共享,则为 |
|
|
|
如果模型存在于 GDS 模型目录中,则为 |
|
|
|
例如,要获取我们上面创建的模型对象 model
的训练配置,可以执行以下操作
train_config = model.train_config()
assert train_config["concurrency"] == 4
3. 使用模型对象
此外,模型对象可以作为 GDS 模型目录操作 的输入使用。例如,假设我们有上面创建的模型对象 model
,我们可以执行以下操作
# Store the model on disk (GDS Enterprise Edition)
_ = gds.alpha.model.store(model)
gds.beta.model.drop(model) # same as model.drop()
# Load the model again for further use
gds.alpha.model.load(model.name())
3.1. GraphSAGE
如上面 构建模型对象 中所举例说明的,使用 Python 客户端训练 GraphSAGE 模型类似于 其 Cypher 对等项。
训练完成后,除了 上面的方法 之外,GraphSAGE 模型对象还将具有以下方法。
名称 | 参数 | 返回类型 | 描述 |
---|---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
返回训练时计算的指标的值。 |
因此,给定我们上面训练的 GraphSAGE 模型 model
,我们可以执行以下操作
# Make sure our training actually converged
metrics = model.metrics()
assert metrics["didConverge"]
# Predict on `G` and write embedding node properties back to the database
predict_result = model.predict_write(G, writeProperty="embedding")
assert predict_result["nodePropertiesWritten"] == G.node_count()