持久化和共享机器学习模型

Colab:在 Google Colab 中打开本示例的笔记本进行学习。

此示例演示了如何使用 模型目录 训练、保存、发布和删除机器学习模型。

设置

有关如何开始使用 Python 的更多信息,请参阅 使用 Python 连接 教程。

# Shell command (run in a terminal or notebook cell, not as Python):
# installs the Graph Data Science Python client.
pip install graphdatascience
# Import the client
from graphdatascience import GraphDataScience

# Replace with the actual URI, username, and password
AURA_CONNECTION_URI = "neo4j+s://xxxxxxxx.databases.neo4j.io"
AURA_USERNAME = "neo4j"
AURA_PASSWORD = ""

# Configure the client with AuraDS-recommended settings.
# `gds` is the client object used by every later `gds.*` snippet in this example.
gds = GraphDataScience(
    AURA_CONNECTION_URI,
    auth=(AURA_USERNAME, AURA_PASSWORD),
    aura_ds=True
)

在以下代码示例中,我们使用 print 函数打印 Pandas DataFrame 和 Series 对象。您可以尝试不同的方法来打印 Pandas 对象,例如使用 to_string 或 to_json 方法;如果您使用 JSON 表示形式,在某些情况下,您可能需要提供一个默认处理函数(即 json.dumps 的 default 参数)来处理 Neo4j DateTime 对象。有关一些示例,请查看 Python 连接 部分。

有关如何开始使用 Cypher Shell 的更多信息,请参阅 Neo4j Cypher Shell 教程。

从安装 Cypher shell 的目录运行以下命令。
# Replace with the actual URI, username, and password
export AURA_CONNECTION_URI="neo4j+s://xxxxxxxx.databases.neo4j.io"
export AURA_USERNAME="neo4j"
export AURA_PASSWORD=""

# Open an interactive Cypher shell against the AuraDS instance.
# The expansions are quoted so that cypher-shell receives each value as
# exactly one argument. Unquoted, an empty value (such as the placeholder
# password above) would vanish, and values containing spaces or shell
# metacharacters would be split or glob-expanded.
./cypher-shell -a "$AURA_CONNECTION_URI" -u "$AURA_USERNAME" -p "$AURA_PASSWORD"

有关如何开始使用 Python 的更多信息,请参阅 使用 Python 连接 教程。

# Shell command (run in a terminal or notebook cell, not as Python):
# installs the official Neo4j Python driver.
pip install neo4j
# Import the driver
from neo4j import GraphDatabase

# Replace with the actual URI, username, and password
AURA_CONNECTION_URI = "neo4j+s://xxxxxxxx.databases.neo4j.io"
AURA_USERNAME = "neo4j"
AURA_PASSWORD = ""

# Instantiate the driver.
# Every later `driver.session()` snippet uses this driver; the cleanup step closes it.
driver = GraphDatabase.driver(
    AURA_CONNECTION_URI,
    auth=(AURA_USERNAME, AURA_PASSWORD)
)
# Import to prettify results (json.dumps with indentation)
import json

# Import for the JSON helper function that serializes Neo4j DateTime values
from neo4j.time import DateTime

# Helper function for serializing Neo4j DateTime in JSON dumps
def default(o):
    """Convert values that ``json.dumps`` cannot serialize natively.

    Intended usage: ``json.dumps(obj, default=default)``.

    Args:
        o: The object that ``json`` could not encode on its own.

    Returns:
        The ISO-8601 string produced by ``o.isoformat()`` when ``o`` is a
        Neo4j ``DateTime``.

    Raises:
        TypeError: For any other type, as the ``json`` module's contract
            for ``default`` hooks requires.
    """
    if isinstance(o, DateTime):
        return o.isoformat()
    # Returning None here (the previous behaviour) makes json.dumps write
    # `null`, silently hiding the unsupported value. Raising TypeError lets
    # json report exactly which type it could not serialize.
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")

创建示例图

我们首先创建一些基本的图数据。

# Create seven :Person:ExampleData nodes and nine KNOWS relationships
# carrying a relWeight property. MERGE only creates a pattern when it does
# not already exist, so re-running this cell does not duplicate the data.
# The extra :ExampleData label lets the cleanup step delete exactly these nodes.
gds.run_cypher("""
    MERGE (dan:Person:ExampleData   {name: 'Dan',   age: 20, heightAndWeight: [185, 75]})
    MERGE (annie:Person:ExampleData {name: 'Annie', age: 12, heightAndWeight: [124, 42]})
    MERGE (matt:Person:ExampleData  {name: 'Matt',  age: 67, heightAndWeight: [170, 80]})
    MERGE (jeff:Person:ExampleData  {name: 'Jeff',  age: 45, heightAndWeight: [192, 85]})
    MERGE (brie:Person:ExampleData  {name: 'Brie',  age: 27, heightAndWeight: [176, 57]})
    MERGE (elsa:Person:ExampleData  {name: 'Elsa',  age: 32, heightAndWeight: [158, 55]})
    MERGE (john:Person:ExampleData  {name: 'John',  age: 35, heightAndWeight: [172, 76]})

    MERGE (dan)-[:KNOWS {relWeight: 1.0}]->(annie)
    MERGE (dan)-[:KNOWS {relWeight: 1.6}]->(matt)
    MERGE (annie)-[:KNOWS {relWeight: 0.1}]->(matt)
    MERGE (annie)-[:KNOWS {relWeight: 3.0}]->(jeff)
    MERGE (annie)-[:KNOWS {relWeight: 1.2}]->(brie)
    MERGE (matt)-[:KNOWS {relWeight: 10.0}]->(brie)
    MERGE (brie)-[:KNOWS {relWeight: 1.0}]->(elsa)
    MERGE (brie)-[:KNOWS {relWeight: 2.2}]->(jeff)
    MERGE (john)-[:KNOWS {relWeight: 5.0}]->(jeff)

    RETURN True AS exampleDataCreated
""")
// Create seven :Person:ExampleData nodes and nine KNOWS relationships
// (each with a relWeight). MERGE makes the statement safe to re-run.
MERGE (dan:Person:ExampleData   {name: 'Dan',   age: 20, heightAndWeight: [185, 75]})
MERGE (annie:Person:ExampleData {name: 'Annie', age: 12, heightAndWeight: [124, 42]})
MERGE (matt:Person:ExampleData  {name: 'Matt',  age: 67, heightAndWeight: [170, 80]})
MERGE (jeff:Person:ExampleData  {name: 'Jeff',  age: 45, heightAndWeight: [192, 85]})
MERGE (brie:Person:ExampleData  {name: 'Brie',  age: 27, heightAndWeight: [176, 57]})
MERGE (elsa:Person:ExampleData  {name: 'Elsa',  age: 32, heightAndWeight: [158, 55]})
MERGE (john:Person:ExampleData  {name: 'John',  age: 35, heightAndWeight: [172, 76]})

MERGE (dan)-[:KNOWS {relWeight: 1.0}]->(annie)
MERGE (dan)-[:KNOWS {relWeight: 1.6}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 0.1}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 3.0}]->(jeff)
MERGE (annie)-[:KNOWS {relWeight: 1.2}]->(brie)
MERGE (matt)-[:KNOWS {relWeight: 10.0}]->(brie)
MERGE (brie)-[:KNOWS {relWeight: 1.0}]->(elsa)
MERGE (brie)-[:KNOWS {relWeight: 2.2}]->(jeff)
MERGE (john)-[:KNOWS {relWeight: 5.0}]->(jeff)

RETURN True AS exampleDataCreated
# Driver variant: create the example :Person:ExampleData graph.
# Cypher query (MERGE keeps it idempotent when re-run)
create_example_graph_on_disk_query = """
    MERGE (dan:Person:ExampleData   {name: 'Dan',   age: 20, heightAndWeight: [185, 75]})
    MERGE (annie:Person:ExampleData {name: 'Annie', age: 12, heightAndWeight: [124, 42]})
    MERGE (matt:Person:ExampleData  {name: 'Matt',  age: 67, heightAndWeight: [170, 80]})
    MERGE (jeff:Person:ExampleData  {name: 'Jeff',  age: 45, heightAndWeight: [192, 85]})
    MERGE (brie:Person:ExampleData  {name: 'Brie',  age: 27, heightAndWeight: [176, 57]})
    MERGE (elsa:Person:ExampleData  {name: 'Elsa',  age: 32, heightAndWeight: [158, 55]})
    MERGE (john:Person:ExampleData  {name: 'John',  age: 35, heightAndWeight: [172, 76]})

    MERGE (dan)-[:KNOWS {relWeight: 1.0}]->(annie)
    MERGE (dan)-[:KNOWS {relWeight: 1.6}]->(matt)
    MERGE (annie)-[:KNOWS {relWeight: 0.1}]->(matt)
    MERGE (annie)-[:KNOWS {relWeight: 3.0}]->(jeff)
    MERGE (annie)-[:KNOWS {relWeight: 1.2}]->(brie)
    MERGE (matt)-[:KNOWS {relWeight: 10.0}]->(brie)
    MERGE (brie)-[:KNOWS {relWeight: 1.0}]->(elsa)
    MERGE (brie)-[:KNOWS {relWeight: 2.2}]->(jeff)
    MERGE (john)-[:KNOWS {relWeight: 5.0}]->(jeff)

    RETURN True AS exampleDataCreated
"""

# Create the driver session (the `with` block closes it afterwards)
with driver.session() as session:
    # Run query; .data() returns the records as a list of dicts
    result = session.run(create_example_graph_on_disk_query).data()

    # Prettify the result
    print(json.dumps(result, indent=2, sort_keys=True))

然后,我们从刚刚创建的数据中投影一个内存中的图。

# Project the :ExampleData nodes (as "Person", with their age and
# heightAndWeight properties) and the KNOWS relationships (treated as
# undirected, with relWeight) into the in-memory graph
# "example_graph_for_graphsage".
# `g` is the graph handle passed to training and, later, to gds.graph.drop;
# `result` is the projection summary printed below.
g, result = gds.graph.project(
    "example_graph_for_graphsage",
    {
        "Person": {
            "label": "ExampleData",
            "properties": ["age", "heightAndWeight"]
        }
    },
    {
        "KNOWS": {
            "type": "KNOWS",
            "orientation": "UNDIRECTED",
            "properties": ["relWeight"]
        }
    }
)

print(result)
// Project :ExampleData nodes (as Person, with age and heightAndWeight) and
// undirected KNOWS relationships (with relWeight) into memory.
CALL gds.graph.project(
  'example_graph_for_graphsage',
  {
    Person: {
      label: 'ExampleData',
      properties: ['age', 'heightAndWeight']
    }
  },
  {
    KNOWS: {
      type: 'KNOWS',
      orientation: 'UNDIRECTED',
      properties: ['relWeight']
    }
  }
)
# Driver variant: project "example_graph_for_graphsage" into memory by
# calling the gds.graph.project procedure through Cypher.
# Cypher query
create_example_graph_in_memory_query = """
    CALL gds.graph.project(
      'example_graph_for_graphsage',
      {
        Person: {
          label: 'ExampleData',
          properties: ['age', 'heightAndWeight']
        }
      },
      {
        KNOWS: {
          type: 'KNOWS',
          orientation: 'UNDIRECTED',
          properties: ['relWeight']
        }
      }
    )
"""

# Create the driver session
with driver.session() as session:
    # Run query
    result = session.run(create_example_graph_in_memory_query).data()

    # Prettify the result.
    # NOTE(review): no `default=` hook here, which assumes the projection
    # result has no Neo4j temporal values — confirm.
    print(json.dumps(result, indent=2, sort_keys=True))

训练模型

支持 train 模式的机器学习算法会生成经过训练的模型,这些模型存储在模型目录中。类似地,predict 过程可以使用此类经过训练的模型来生成预测。在此示例中,我们使用 train 模式训练 GraphSAGE 算法 的模型。

# Train a GraphSAGE model on the projected graph `g`, using the age and
# heightAndWeight node properties as features. The trained model is
# registered in the model catalog under `modelName`. The returned `model`
# handle is reused by the store and publish steps.
model, result = gds.beta.graphSage.train(
    g,
    modelName="example_graph_model_for_graphsage",
    featureProperties=["age", "heightAndWeight"],
    aggregator="mean",
    activationFunction="sigmoid",
    sampleSizes=[25, 10]
)
  // Train GraphSAGE on the in-memory graph and return the model name plus
  // the convergence and per-epoch loss metrics from modelInfo.
  CALL gds.beta.graphSage.train(
    'example_graph_for_graphsage',
    {
      modelName: 'example_graph_model_for_graphsage',
      featureProperties: ['age', 'heightAndWeight'],
      aggregator: 'mean',
      activationFunction: 'sigmoid',
      sampleSizes: [25, 10]
    }
  )
  YIELD modelInfo as info
  RETURN
    info.name as modelName,
    info.metrics.didConverge as didConverge,
    info.metrics.ranEpochs as ranEpochs,
    info.metrics.epochLosses as epochLosses
# Driver variant: train the GraphSAGE model and print its name and training
# metrics (didConverge, ranEpochs, epochLosses).
# Cypher query
train_graph_sage_on_in_memory_graph_query  = """
    CALL gds.beta.graphSage.train(
      'example_graph_for_graphsage',
      {
        modelName: 'example_graph_model_for_graphsage',
        featureProperties: ['age', 'heightAndWeight'],
        aggregator: 'mean',
        activationFunction: 'sigmoid',
        sampleSizes: [25, 10]
      }
    )
    YIELD modelInfo as info
    RETURN
      info.name as modelName,
      info.metrics.didConverge as didConverge,
      info.metrics.ranEpochs as ranEpochs,
      info.metrics.epochLosses as epochLosses
"""

# Create the driver session
with driver.session() as session:
    # Run query
    result = session.run(train_graph_sage_on_in_memory_graph_query).data()

    # Prettify the result
    print(json.dumps(result, indent=2, sort_keys=True))

查看模型目录

我们可以使用 gds.beta.model.list 过程获取有关当前目录中所有可用模型的信息。除了有关图模式、模型名称和训练配置的信息外,调用的结果还包含以下字段:

  • loaded:表示模型是否在内存中 (true) 或在磁盘上可用 (false) 的标志

  • stored:表示模型是否已持久化到磁盘的标志

  • shared:表示模型是否已发布,使其对所有用户可访问的标志

# List every model in the catalog; the output includes the loaded, stored
# and shared flags described above.
results = gds.beta.model.list()

print(results)
// List every model in the model catalog
CALL gds.beta.model.list()
# Driver variant: list the model catalog.
# Cypher query (also reused later when listing models as testUser)
list_model_catalog_query  = """
    CALL gds.beta.model.list()
"""

# Create the driver session
with driver.session() as session:
    # Run query
    results = session.run(list_model_catalog_query).data()

    # Prettify the results; `default` converts Neo4j DateTime values,
    # which plain json.dumps cannot serialize
    print(json.dumps(results, indent=2, sort_keys=True, default=default))

将模型保存到磁盘

gds.alpha.model.store 过程可用于将模型持久化到磁盘。这对于保留模型以供以后重用以及释放内存都很有用。

并非所有模型都可以保存到磁盘。可以在 GDS 手册 上找到受支持模型的列表。

未保存到磁盘的模型(包括不支持保存的模型)会在 AuraDS 实例重新启动时丢失。

# Persist the trained model to disk so it survives an instance restart.
result = gds.alpha.model.store(model)

print(result)
// Persist the trained model to disk
CALL gds.alpha.model.store("example_graph_model_for_graphsage")
# Driver variant: persist the trained model to disk.
# Cypher query
save_graph_sage_model_to_disk_query = """
    CALL gds.alpha.model.store("example_graph_model_for_graphsage")
"""

# Create the driver session
with driver.session() as session:
    # Run query
    result = session.run(save_graph_sage_model_to_disk_query).data()

    # Prettify the result
    print(json.dumps(result, indent=2, sort_keys=True))

如果在持久化模型后再次列出模型目录,我们可以看到该模型的 stored 标志已设置为 true。

# List the catalog again; the model's `stored` flag should now be true.
results = gds.beta.model.list()

print(results)
// List the catalog again; the model's stored flag should now be true
CALL gds.beta.model.list()
# Driver variant: list the catalog again after storing the model.
# Cypher query
list_model_catalog_query  = """
    CALL gds.beta.model.list()
"""

# Create the driver session
with driver.session() as session:
    # Run query
    results = session.run(list_model_catalog_query).data()

    # Prettify the results (`default` handles Neo4j DateTime values)
    print(json.dumps(results, indent=2, sort_keys=True, default=default))

与其他用户共享模型

创建模型后,将其提供给其他用户以用于不同的用例可能很有用。

模型只能与同一 AuraDS 实例的其他用户共享。

创建新用户

为了查看这在 AuraDS 上是如何在实践中工作的,我们首先需要 创建另一个用户 来与之共享模型。

# Switch to the "system" database to run the
# "CREATE USER" admin command
gds.set_database("system")

# Create the example user. IF NOT EXISTS makes this a no-op when it already
# exists, and PASSWORD CHANGE NOT REQUIRED lets it log in without first
# changing the password. 'password' is a demo value only.
gds.run_cypher("""
    CREATE USER testUser IF NOT EXISTS
    SET PASSWORD 'password'
    SET PASSWORD CHANGE NOT REQUIRED
""")
// Switch the active database to "system": user administration commands
// such as CREATE USER are run there. In cypher-shell, `:use` changes the
// active database of the current connection; `:connect` opens a new
// connection and takes options, not a bare database name.
:use system

// Create the example user. 'password' is a demo value only.
CREATE USER testUser IF NOT EXISTS
SET PASSWORD 'password'
SET PASSWORD CHANGE NOT REQUIRED
# Driver variant: create the example user.
# Cypher query
create_a_new_user_query = """
    CREATE USER testUser IF NOT EXISTS
    SET PASSWORD 'password'
    SET PASSWORD CHANGE NOT REQUIRED
"""

# Create the driver session using the "system" database, where user
# administration commands are run
with driver.session(database="system") as session:
    # Run query
    result = session.run(create_a_new_user_query).data()

    # Prettify the result
    print(json.dumps(result, indent=2, sort_keys=True))

发布模型

可以使用 gds.alpha.model.publish 过程发布(使其他用户能够访问)模型。发布后,模型名称会通过在其原始名称后附加 _public 来更新。

# Switch back to the default "neo4j" database
# to publish the model
gds.set_database("neo4j")

# Publish the model so other users of this instance can access it.
# `model_public` is the published model, whose name carries the "_public"
# suffix; the cleanup step drops and deletes it.
model_public = gds.alpha.model.publish(model)

print(model_public)
// Switch the active database back to "neo4j" (cypher-shell's `:use`
// changes the active database of the current connection).
:use neo4j

// Publish the model; its name gains the "_public" suffix
CALL gds.alpha.model.publish('example_graph_model_for_graphsage')
# Driver variant: publish the model to the other users of this instance.
# NOTE(review): the variable name says "to_disk", but publishing shares the
# model rather than writing it to disk.
# Cypher query
publish_graph_sage_model_to_disk_query  = """
    CALL gds.alpha.model.publish('example_graph_model_for_graphsage')
"""

# Create the driver session
with driver.session() as session:
    # Run query
    result = session.run(publish_graph_sage_model_to_disk_query).data()

    # Prettify the result (`default` handles Neo4j DateTime values)
    print(json.dumps(result, indent=2, sort_keys=True, default=default))

以不同用户身份查看模型

为了验证已发布的模型是否对我们刚刚创建的用户可见,我们需要创建一个新的客户端(或驱动程序)会话。然后,我们可以在新用户下再次使用它来运行 gds.beta.model.list 过程,并验证模型是否包含在列表中。

# Open a second client authenticated as testUser to check that the
# published model is visible to that user.
test_user_gds = GraphDataScience(
    AURA_CONNECTION_URI,
    auth=("testUser", "password"),
    aura_ds=True
)

# List the model catalog as seen by testUser
results = test_user_gds.beta.model.list()

print(results)
// First, open a new Cypher shell with the following command:
//
// ./cypher-shell -a "$AURA_CONNECTION_URI" -u testUser -p password

// Run as testUser: the published model should appear in this list
CALL gds.beta.model.list()
# Open a second driver authenticated as testUser (closed in the cleanup step)
test_user_driver = GraphDatabase.driver(
    AURA_CONNECTION_URI,
    auth=("testUser", "password")
)

# Create the driver session
with test_user_driver.session() as session:
    # Run query; reuses list_model_catalog_query defined in an earlier
    # driver snippet
    results = session.run(list_model_catalog_query).data()

    # Prettify the results (`default` handles Neo4j DateTime values)
    print(json.dumps(results, indent=2, sort_keys=True, default=default))

清理

内存中的图、Neo4j 数据库中的数据、模型和测试用户现在都可以删除了。

# Delete the example dataset
gds.run_cypher("""
    MATCH (example:ExampleData)
    DETACH DELETE example
""")

# Delete the projected graph from memory
gds.graph.drop(g)

# Drop the model from memory
gds.beta.model.drop(model_public)

# Delete the model from disk
gds.alpha.model.delete(model_public)

# Close the testUser client opened to verify the published model. It is not
# needed any more, and its account is about to be dropped. (The driver-based
# cleanup likewise closes test_user_driver.)
test_user_gds.close()

# Switch to the "system" database to delete the example user
gds.set_database("system")

gds.run_cypher("""
    DROP USER testUser
""")
// Run these statements from the original (admin) Cypher shell, not the
// testUser shell: testUser was created without any additional roles.

// Delete the example dataset from the database
MATCH (example:ExampleData)
DETACH DELETE example;

// Delete the projected graph from memory
CALL gds.graph.drop("example_graph_for_graphsage");

// Drop the model from memory
CALL gds.beta.model.drop("example_graph_model_for_graphsage_public");

// Delete the model from disk
CALL gds.alpha.model.delete("example_graph_model_for_graphsage_public");

// Delete the example user. User administration runs against the "system"
// database, as when the user was created, so switch to it first.
:use system
DROP USER testUser;
# Driver variant of the cleanup: remove the example data, the in-memory
# graph, the published model and the test user, then close both drivers.

# Delete the example dataset from the database
delete_example_graph_query = """
    MATCH (example:ExampleData)
    DETACH DELETE example
"""

# Delete the projected graph from memory
drop_in_memory_graph_query = """
    CALL gds.graph.drop("example_graph_for_graphsage")
"""

# Drop the model from memory
drop_example_models_query = """
    CALL gds.beta.model.drop("example_graph_model_for_graphsage_public")
"""

# Delete the model from disk
delete_example_models_query = """
    CALL gds.alpha.model.delete("example_graph_model_for_graphsage_public")
"""

# Delete the example user
drop_example_user_query = """
    DROP USER testUser
"""

# Create the driver session
with driver.session() as session:
    # Run queries in order: data, in-memory graph, then the model
    # (first from memory, then from disk)
    print(session.run(delete_example_graph_query).data())
    print(session.run(drop_in_memory_graph_query).data())
    print(session.run(drop_example_models_query).data())
    print(session.run(delete_example_models_query).data())

# Create another driver session on the system database
# to drop the test user
with driver.session(database='system') as session:
    print(session.run(drop_example_user_query).data())

# Release both drivers' connections
driver.close()
test_user_driver.close()

关闭连接

当不再需要连接时,应始终关闭它。

尽管 GDS 客户端在对象被删除时会自动关闭连接,但最好显式关闭它。

# Close the client connection explicitly instead of relying on the
# connection being closed automatically when the object is deleted
gds.close()
# Close the driver connection.
# NOTE(review): if the driver-based cleanup above already ran driver.close(),
# this second call is presumably a harmless no-op — confirm for the driver
# version in use.
driver.close()

参考文献

Cypher

  • 了解有关 Cypher 语法的更多信息

  • 您可以使用 Cypher 速查表 作为所有可用 Cypher 功能的参考