持久化和共享机器学习模型
在 |
此示例演示了如何使用 模型目录 训练、保存、发布和删除机器学习模型。
设置
有关如何开始使用 Python 的更多信息,请参阅 使用 Python 连接 教程。
pip install graphdatascience
# Import the client
from graphdatascience import GraphDataScience
# Connection details for the AuraDS instance -- substitute your own values
AURA_CONNECTION_URI = "neo4j+s://xxxxxxxx.databases.neo4j.io"
AURA_USERNAME = "neo4j"
AURA_PASSWORD = ""

# Build the client; aura_ds=True selects the AuraDS-recommended settings
gds = GraphDataScience(AURA_CONNECTION_URI, auth=(AURA_USERNAME, AURA_PASSWORD), aura_ds=True)
有关如何开始使用 Cypher Shell 的更多信息,请参阅 Neo4j Cypher Shell 教程。
从安装 Cypher shell 的目录运行以下命令。
# Connection details for the AuraDS instance (replace the placeholders)
export AURA_CONNECTION_URI="neo4j+s://xxxxxxxx.databases.neo4j.io"
export AURA_USERNAME="neo4j"
export AURA_PASSWORD=""
# Quote the expansions so that values containing spaces or shell
# metacharacters (or an empty value) reach cypher-shell as one argument each
./cypher-shell -a "$AURA_CONNECTION_URI" -u "$AURA_USERNAME" -p "$AURA_PASSWORD"
有关如何开始使用 Python 的更多信息,请参阅 使用 Python 连接 教程。
pip install neo4j
# Import the driver
from neo4j import GraphDatabase
# Connection details for the AuraDS instance -- substitute your own values
AURA_CONNECTION_URI = "neo4j+s://xxxxxxxx.databases.neo4j.io"
AURA_USERNAME = "neo4j"
AURA_PASSWORD = ""

# Create the driver from which the sessions in this example are opened
driver = GraphDatabase.driver(AURA_CONNECTION_URI, auth=(AURA_USERNAME, AURA_PASSWORD))
# Import to prettify results
import json
# Import for the JSON helper function
from neo4j.time import DateTime
# Helper function for serializing Neo4j DateTime in JSON dumps
def default(o):
    """Convert a Neo4j ``DateTime`` for ``json.dumps``.

    Meant to be passed as the ``default=`` argument of ``json.dumps``.

    Args:
        o: An object the JSON encoder could not serialize on its own.

    Returns:
        The value of ``o.isoformat()`` when ``o`` is a ``DateTime``.

    Raises:
        TypeError: If ``o`` is not a ``DateTime``. Falling through instead
            would return ``None`` and make ``json.dumps`` silently write
            ``null`` for any unsupported value.
    """
    if isinstance(o, DateTime):
        return o.isoformat()
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
创建示例图
我们首先创建一些基本的图数据。
# Cypher statement that creates the example nodes and relationships
create_example_graph_query = """
MERGE (dan:Person:ExampleData {name: 'Dan', age: 20, heightAndWeight: [185, 75]})
MERGE (annie:Person:ExampleData {name: 'Annie', age: 12, heightAndWeight: [124, 42]})
MERGE (matt:Person:ExampleData {name: 'Matt', age: 67, heightAndWeight: [170, 80]})
MERGE (jeff:Person:ExampleData {name: 'Jeff', age: 45, heightAndWeight: [192, 85]})
MERGE (brie:Person:ExampleData {name: 'Brie', age: 27, heightAndWeight: [176, 57]})
MERGE (elsa:Person:ExampleData {name: 'Elsa', age: 32, heightAndWeight: [158, 55]})
MERGE (john:Person:ExampleData {name: 'John', age: 35, heightAndWeight: [172, 76]})
MERGE (dan)-[:KNOWS {relWeight: 1.0}]->(annie)
MERGE (dan)-[:KNOWS {relWeight: 1.6}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 0.1}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 3.0}]->(jeff)
MERGE (annie)-[:KNOWS {relWeight: 1.2}]->(brie)
MERGE (matt)-[:KNOWS {relWeight: 10.0}]->(brie)
MERGE (brie)-[:KNOWS {relWeight: 1.0}]->(elsa)
MERGE (brie)-[:KNOWS {relWeight: 2.2}]->(jeff)
MERGE (john)-[:KNOWS {relWeight: 5.0}]->(jeff)
RETURN True AS exampleDataCreated
"""

# Send the statement through the GDS client
gds.run_cypher(create_example_graph_query)
// Create the example Person nodes; each one also carries the ExampleData label
MERGE (dan:Person:ExampleData {name: 'Dan', age: 20, heightAndWeight: [185, 75]})
MERGE (annie:Person:ExampleData {name: 'Annie', age: 12, heightAndWeight: [124, 42]})
MERGE (matt:Person:ExampleData {name: 'Matt', age: 67, heightAndWeight: [170, 80]})
MERGE (jeff:Person:ExampleData {name: 'Jeff', age: 45, heightAndWeight: [192, 85]})
MERGE (brie:Person:ExampleData {name: 'Brie', age: 27, heightAndWeight: [176, 57]})
MERGE (elsa:Person:ExampleData {name: 'Elsa', age: 32, heightAndWeight: [158, 55]})
MERGE (john:Person:ExampleData {name: 'John', age: 35, heightAndWeight: [172, 76]})
// Connect the people with KNOWS relationships carrying a relWeight property
MERGE (dan)-[:KNOWS {relWeight: 1.0}]->(annie)
MERGE (dan)-[:KNOWS {relWeight: 1.6}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 0.1}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 3.0}]->(jeff)
MERGE (annie)-[:KNOWS {relWeight: 1.2}]->(brie)
MERGE (matt)-[:KNOWS {relWeight: 10.0}]->(brie)
MERGE (brie)-[:KNOWS {relWeight: 1.0}]->(elsa)
MERGE (brie)-[:KNOWS {relWeight: 2.2}]->(jeff)
MERGE (john)-[:KNOWS {relWeight: 5.0}]->(jeff)
// Return a single flag column so the statement produces a result
RETURN True AS exampleDataCreated
# Cypher statement that creates the example nodes and relationships
create_example_graph_on_disk_query = """
MERGE (dan:Person:ExampleData {name: 'Dan', age: 20, heightAndWeight: [185, 75]})
MERGE (annie:Person:ExampleData {name: 'Annie', age: 12, heightAndWeight: [124, 42]})
MERGE (matt:Person:ExampleData {name: 'Matt', age: 67, heightAndWeight: [170, 80]})
MERGE (jeff:Person:ExampleData {name: 'Jeff', age: 45, heightAndWeight: [192, 85]})
MERGE (brie:Person:ExampleData {name: 'Brie', age: 27, heightAndWeight: [176, 57]})
MERGE (elsa:Person:ExampleData {name: 'Elsa', age: 32, heightAndWeight: [158, 55]})
MERGE (john:Person:ExampleData {name: 'John', age: 35, heightAndWeight: [172, 76]})
MERGE (dan)-[:KNOWS {relWeight: 1.0}]->(annie)
MERGE (dan)-[:KNOWS {relWeight: 1.6}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 0.1}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 3.0}]->(jeff)
MERGE (annie)-[:KNOWS {relWeight: 1.2}]->(brie)
MERGE (matt)-[:KNOWS {relWeight: 10.0}]->(brie)
MERGE (brie)-[:KNOWS {relWeight: 1.0}]->(elsa)
MERGE (brie)-[:KNOWS {relWeight: 2.2}]->(jeff)
MERGE (john)-[:KNOWS {relWeight: 5.0}]->(jeff)
RETURN True AS exampleDataCreated
"""

# Execute the statement in its own session and keep the returned data
with driver.session() as session:
    records = session.run(create_example_graph_on_disk_query)
    result = records.data()

    # Print the data as indented JSON with sorted keys
    print(json.dumps(result, sort_keys=True, indent=2))
然后,我们从刚刚创建的数据中投影一个内存中的图。
# Node projection: ExampleData nodes exposed as "Person" with two properties
node_projection = {
    "Person": {
        "label": "ExampleData",
        "properties": ["age", "heightAndWeight"],
    },
}

# Relationship projection: KNOWS relationships, undirected, with their weight
relationship_projection = {
    "KNOWS": {
        "type": "KNOWS",
        "orientation": "UNDIRECTED",
        "properties": ["relWeight"],
    },
}

# Project the in-memory graph and show the result of the call
g, result = gds.graph.project(
    "example_graph_for_graphsage",
    node_projection,
    relationship_projection,
)
print(result)
// Project an in-memory graph named 'example_graph_for_graphsage' from the
// ExampleData nodes (exposed as Person) and the KNOWS relationships
CALL gds.graph.project(
'example_graph_for_graphsage',
{
Person: {
label: 'ExampleData',
properties: ['age', 'heightAndWeight']
}
},
{
KNOWS: {
type: 'KNOWS',
orientation: 'UNDIRECTED',
properties: ['relWeight']
}
}
)
# Cypher statement that projects the example data into an in-memory graph
create_example_graph_in_memory_query = """
CALL gds.graph.project(
'example_graph_for_graphsage',
{
Person: {
label: 'ExampleData',
properties: ['age', 'heightAndWeight']
}
},
{
KNOWS: {
type: 'KNOWS',
orientation: 'UNDIRECTED',
properties: ['relWeight']
}
}
)
"""

# Execute the statement in its own session and keep the returned data
with driver.session() as session:
    records = session.run(create_example_graph_in_memory_query)
    result = records.data()

    # Print the data as indented JSON with sorted keys
    print(json.dumps(result, sort_keys=True, indent=2))
训练模型
支持 train 模式的机器学习算法会生成经过训练的模型,这些模型存储在模型目录中。类似地,predict 过程可以使用此类经过训练的模型来生成预测。在此示例中,我们使用 train 模式训练 GraphSAGE 算法 的模型。
# GraphSAGE training configuration, including the name given to the model
graphsage_train_config = {
    "modelName": "example_graph_model_for_graphsage",
    "featureProperties": ["age", "heightAndWeight"],
    "aggregator": "mean",
    "activationFunction": "sigmoid",
    "sampleSizes": [25, 10],
}

# Train on the projected graph; the call returns the model and a result
model, result = gds.beta.graphSage.train(g, **graphsage_train_config)
// Train a GraphSAGE model on the projected graph and name it
// 'example_graph_model_for_graphsage'
CALL gds.beta.graphSage.train(
'example_graph_for_graphsage',
{
modelName: 'example_graph_model_for_graphsage',
featureProperties: ['age', 'heightAndWeight'],
aggregator: 'mean',
activationFunction: 'sigmoid',
sampleSizes: [25, 10]
}
)
// Keep only the model info and return the model name plus selected metrics
YIELD modelInfo as info
RETURN
info.name as modelName,
info.metrics.didConverge as didConverge,
info.metrics.ranEpochs as ranEpochs,
info.metrics.epochLosses as epochLosses
# Cypher statement that trains GraphSAGE and returns selected model metrics
train_graph_sage_on_in_memory_graph_query = """
CALL gds.beta.graphSage.train(
'example_graph_for_graphsage',
{
modelName: 'example_graph_model_for_graphsage',
featureProperties: ['age', 'heightAndWeight'],
aggregator: 'mean',
activationFunction: 'sigmoid',
sampleSizes: [25, 10]
}
)
YIELD modelInfo as info
RETURN
info.name as modelName,
info.metrics.didConverge as didConverge,
info.metrics.ranEpochs as ranEpochs,
info.metrics.epochLosses as epochLosses
"""

# Execute the statement in its own session and keep the returned data
with driver.session() as session:
    records = session.run(train_graph_sage_on_in_memory_graph_query)
    result = records.data()

    # Print the data as indented JSON with sorted keys
    print(json.dumps(result, sort_keys=True, indent=2))
查看模型目录
我们可以使用 gds.beta.model.list 过程获取有关当前目录中所有可用模型的信息。除了有关图模式、模型名称和训练配置的信息外,调用的结果还包含以下字段:
- loaded:表示模型是在内存中(true)还是在磁盘上可用(false)的标志
- stored:表示模型是否已持久化到磁盘的标志
- shared:表示模型是否已发布(即对所有用户可访问)的标志
# Fetch information about the models available in the model catalog
results = gds.beta.model.list()
# Show the returned listing
print(results)
// List the models available in the model catalog
CALL gds.beta.model.list()
# Cypher statement that lists the models in the model catalog
list_model_catalog_query = """
CALL gds.beta.model.list()
"""

# Execute the statement in its own session and keep the returned data
with driver.session() as session:
    records = session.run(list_model_catalog_query)
    results = records.data()

    # Print as indented JSON; the helper handles Neo4j DateTime values
    print(json.dumps(results, default=default, sort_keys=True, indent=2))
将模型保存到磁盘
gds.alpha.model.store 过程可用于将模型持久化到磁盘。这对于保留模型以供以后重用以及释放内存都很有用。
并非所有模型都可以保存到磁盘。受支持模型的列表可以在 GDS 手册 中找到。如果模型无法保存到磁盘,它将在 AuraDS 实例重新启动时丢失。
# Persist the trained model to disk
result = gds.alpha.model.store(model)
# Show the result of the store call
print(result)
// Persist the trained model to disk, referring to it by name
CALL gds.alpha.model.store("example_graph_model_for_graphsage")
# Cypher statement that persists the trained model to disk
save_graph_sage_model_to_disk_query = """
CALL gds.alpha.model.store("example_graph_model_for_graphsage")
"""

# Execute the statement in its own session and keep the returned data
with driver.session() as session:
    records = session.run(save_graph_sage_model_to_disk_query)
    result = records.data()

    # Print the data as indented JSON with sorted keys
    print(json.dumps(result, sort_keys=True, indent=2))
如果在持久化模型后再次列出模型目录,我们可以看到该模型的 stored 标志已设置为 true。
# List the catalog again; the stored flag of the persisted model
# should now be set to true
results = gds.beta.model.list()
print(results)
// List the catalog again to check the stored flag of the persisted model
CALL gds.beta.model.list()
# Cypher statement that lists the models in the model catalog
list_model_catalog_query = """
CALL gds.beta.model.list()
"""

# Execute the statement in its own session and keep the returned data
with driver.session() as session:
    records = session.run(list_model_catalog_query)
    results = records.data()

    # Print as indented JSON; the helper handles Neo4j DateTime values
    print(json.dumps(results, default=default, sort_keys=True, indent=2))
与其他用户共享模型
创建模型后,将其提供给其他用户以用于不同的用例可能很有用。
模型只能与同一 AuraDS 实例的其他用户共享。
创建新用户
为了查看这在 AuraDS 上是如何在实践中工作的,我们首先需要 创建另一个用户 来与之共享模型。
# "CREATE USER" is an admin command, so point the client at the
# "system" database before running it
gds.set_database("system")

# Statement that creates the test user unless it already exists
create_test_user_query = """
CREATE USER testUser IF NOT EXISTS
SET PASSWORD 'password'
SET PASSWORD CHANGE NOT REQUIRED
"""
gds.run_cypher(create_test_user_query)
// "CREATE USER" is an admin command, so switch to the "system" database first
// NOTE(review): cypher-shell documents `:use <database>` for switching
// databases -- confirm that `:connect system` works with your shell version
:connect system
// Create the test user unless it already exists; no password change is forced
CREATE USER testUser IF NOT EXISTS
SET PASSWORD 'password'
SET PASSWORD CHANGE NOT REQUIRED
# Admin statement that creates the test user unless it already exists
create_a_new_user_query = """
CREATE USER testUser IF NOT EXISTS
SET PASSWORD 'password'
SET PASSWORD CHANGE NOT REQUIRED
"""

# Admin commands are run in a session bound to the "system" database
with driver.session(database="system") as session:
    records = session.run(create_a_new_user_query)
    result = records.data()

    # Print the data as indented JSON with sorted keys
    print(json.dumps(result, sort_keys=True, indent=2))
发布模型
可以使用 gds.alpha.model.publish 过程发布模型(使其他用户能够访问)。发布后,模型名称会更新为在其原始名称后附加 _public。
# Switch back to the default "neo4j" database
# to publish the model
gds.set_database("neo4j")
# Publish the model so that other users can access it; keep the returned
# value, which the cleanup step uses to drop and delete the model
model_public = gds.alpha.model.publish(model)
print(model_public)
// Switch back to the default "neo4j" database to publish the model
:connect neo4j
// Publish the model so that other users can access it
CALL gds.alpha.model.publish('example_graph_model_for_graphsage')
# Cypher statement that publishes the trained model to other users
publish_graph_sage_model_to_disk_query = """
CALL gds.alpha.model.publish('example_graph_model_for_graphsage')
"""

# Execute the statement in its own session and keep the returned data
with driver.session() as session:
    records = session.run(publish_graph_sage_model_to_disk_query)
    result = records.data()

    # Print as indented JSON; the helper handles Neo4j DateTime values
    print(json.dumps(result, default=default, sort_keys=True, indent=2))
以不同用户身份查看模型
为了验证已发布的模型是否对我们刚刚创建的用户可见,我们需要创建一个新的客户端(或驱动程序)会话。然后,我们可以以新用户的身份再次运行 gds.beta.model.list 过程,并验证模型是否包含在列表中。
# Credentials of the test user created earlier in this example
test_user_auth = ("testUser", "password")

# Open a second client that connects as the test user
test_user_gds = GraphDataScience(AURA_CONNECTION_URI, auth=test_user_auth, aura_ds=True)

# List the model catalog as seen by the test user
results = test_user_gds.beta.model.list()
print(results)
// First, open a new Cypher shell as the test user with the following command:
//
// ./cypher-shell -a $AURA_CONNECTION_URI -u testUser -p password
//
// Then list the model catalog to check that the published model is included
CALL gds.beta.model.list()
# Credentials of the test user created earlier in this example
test_user_auth = ("testUser", "password")

# Open a second driver that connects as the test user
test_user_driver = GraphDatabase.driver(AURA_CONNECTION_URI, auth=test_user_auth)

# List the model catalog as seen by the test user
with test_user_driver.session() as session:
    records = session.run(list_model_catalog_query)
    results = records.data()

    # Print as indented JSON; the helper handles Neo4j DateTime values
    print(json.dumps(results, default=default, sort_keys=True, indent=2))
清理
内存中的图、Neo4j 数据库中的数据、模型和测试用户现在都可以删除了。
# Delete the example dataset
gds.run_cypher("""
MATCH (example:ExampleData)
DETACH DELETE example
""")
# Delete the projected graph from memory
gds.graph.drop(g)
# Drop the model from memory
gds.beta.model.drop(model_public)
# Delete the model from disk
gds.alpha.model.delete(model_public)
# Switch to the "system" database to delete the example user
gds.set_database("system")
gds.run_cypher("""
DROP USER testUser
""")
# Close both clients so their connections are released, matching the
# driver-based cleanup, which closes both of its drivers
gds.close()
test_user_gds.close()
// Cleanup: each statement below ends with ";" so they can be run in sequence
// Delete the example dataset from the database
MATCH (example:ExampleData)
DETACH DELETE example;
// Delete the projected graph from memory
CALL gds.graph.drop("example_graph_for_graphsage");
// Drop the published model (note the "_public" suffix) from memory
CALL gds.beta.model.drop("example_graph_model_for_graphsage_public");
// Delete the published model from disk
CALL gds.alpha.model.delete("example_graph_model_for_graphsage_public");
// Delete the example user
DROP USER testUser;
# Statement removing the example dataset from the database
delete_example_graph_query = """
MATCH (example:ExampleData)
DETACH DELETE example
"""

# Statement removing the projected graph from memory
drop_in_memory_graph_query = """
CALL gds.graph.drop("example_graph_for_graphsage")
"""

# Statement dropping the published model from memory
drop_example_models_query = """
CALL gds.beta.model.drop("example_graph_model_for_graphsage_public")
"""

# Statement deleting the published model from disk
delete_example_models_query = """
CALL gds.alpha.model.delete("example_graph_model_for_graphsage_public")
"""

# Statement removing the example user
drop_example_user_query = """
DROP USER testUser
"""

# Cleanup statements for the default database, in execution order
cleanup_queries = (
    delete_example_graph_query,
    drop_in_memory_graph_query,
    drop_example_models_query,
    delete_example_models_query,
)

# Run each cleanup statement in one session and print what it returns
with driver.session() as session:
    for cleanup_query in cleanup_queries:
        print(session.run(cleanup_query).data())

# The test user is dropped in a separate session on the system database
with driver.session(database="system") as session:
    print(session.run(drop_example_user_query).data())

# Close both drivers now that the example is finished
driver.close()
test_user_driver.close()
参考文献
Cypher
- 了解有关 Cypher 语法的更多信息
- 您可以使用 Cypher 速查表 作为所有可用 Cypher 功能的参考