知识图谱嵌入:在 PyG 中训练,在 GDS 中预测

Open In Colab

此 Jupyter 笔记本托管在 Neo4j 图数据科学客户端 Github 存储库 中。

此笔记本演示了如何使用 graphdatascience 和 PyTorch Geometric (PyG) Python 库来完成以下任务:

  1. 将 FB15k-237 数据集直接导入 GDS

  2. 使用 PyG 训练 TransE 模型

  3. 使用 GDS 知识图谱嵌入功能对数据库中的数据进行预测

1. 先决条件

要运行此笔记本,您需要一台安装了较新 GDS 版本 (2.5 或更高版本) 的 Neo4j 服务器。

此外,还需要以下 Python 库:graphdatascience、torch 和 torch_geometric (PyG)。

2. 设置

我们将从导入依赖项并建立与数据库的 GDS 客户端连接开始。

%pip install graphdatascience torch torch_geometric
import os
from graphdatascience import GraphDataScience
import torch
import torch.optim as optim
from torch_geometric.data import Data, download_url
from torch_geometric.nn import TransE
import collections
from tqdm import tqdm
import pandas as pd
# Connection settings are read from the environment, with local defaults.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_AUTH = None
NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")
# Credentials are only used when both the user name and the password are set.
if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"):
    NEO4J_AUTH = (
        os.environ.get("NEO4J_USER"),
        os.environ.get("NEO4J_PASSWORD"),
    )
gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB)
# This notebook requires GDS 2.5.0 or later.
# Compare numeric components instead of raw strings: string comparison is
# lexicographic, so "2.10.0" >= "2.5.0" would wrongly evaluate to False.
# Assumes the version string starts with "<major>.<minor>" -- TODO confirm
# for pre-release builds.
_gds_major_minor = tuple(int(part) for part in gds.version().split(".")[:2])
assert _gds_major_minor >= (2, 5), f"GDS 2.5.0 or later is required, found {gds.version()}"

3. 在数据库中下载和存储 FB15k-237 数据集

下载 FB15k-237 数据集,并提取所需文件:train.txt、valid.txt 和 test.txt。

import os
import zipfile

# Location of the FB15k-237 archive and the local directory it is stored in.
url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip"
raw_dir = "./data_from_zip"
# The three split files that are needed from the archive.
raw_file_names = ["train.txt", "valid.txt", "test.txt"]

download_url(url, raw_dir)

# The downloaded archive keeps the last path component of the URL as its name.
archive_path = f"{raw_dir}/{os.path.basename(url)}"
with zipfile.ZipFile(archive_path, "r") as archive:
    for split_file in raw_file_names:
        archive.extract(f"Release/{split_file}", path=raw_dir)

# The split files are extracted below the archive's "Release" folder.
data_dir = f"{raw_dir}/Release"

设置唯一 ID 条目的约束以加快数据上传速度。

# IF NOT EXISTS keeps this cell re-runnable: without it the statement fails
# once the constraint has already been created by an earlier run.
gds.run_cypher("CREATE CONSTRAINT entity_id IF NOT EXISTS FOR (e:Entity) REQUIRE e.id IS UNIQUE")

创建实体节点:使用标签 Entity 创建节点。每个节点应具有属性 id 和 text。- 语法:(:Entity {id: int, text: str})

为 PyG 训练创建关系:根据数据所属的阶段,创建类型为 TRAIN、TEST 或 VALID 的关系。这些关系中的每一个都应该有一个 rel_id 属性。- 示例语法:[:TRAIN {rel_id: int}]

为 GDS 预测创建关系:对于预测阶段,创建特定类型的关系,表示为 REL_i。这些关系中的每一个都应该具有 rel_id 和 text 属性。- 示例语法:[:REL_7 {rel_id: int, text: str}]

# Split file name -> relationship type written for the triples of that file.
rel_types = {f"{split}.txt": split.upper() for split in ("train", "valid", "test")}
# Relation text -> integer relation id (filled while the files are processed).
rel_dict = dict()
# Integer relation id -> relation text (inverse of rel_dict).
rel_id_to_text_dict = dict()
# Integer relation id -> list of {"source", "target"} endpoint dicts.
rel_type_dict = collections.defaultdict(list)


def process():
    """Load the downloaded triples into the database.

    Each file listed in ``raw_file_names`` is read from ``data_dir``. Every
    non-empty line is expected to hold a tab-separated
    ``source<TAB>relation<TAB>target`` triple. Entities and relations receive
    consecutive integer ids in order of first appearance.

    For each file the triples are written as ``Entity`` nodes joined by a
    relationship whose type is taken from ``rel_types``. Afterwards every
    triple is written once more as a ``REL_<rel_id>`` relationship carrying
    the relation id and its text.

    Side effects:
        Fills the module-level ``rel_dict``, ``rel_id_to_text_dict`` and
        ``rel_type_dict`` and runs Cypher write queries through ``gds``.

    Raises:
        ValueError: If a non-empty line does not consist of exactly three
            tab-separated fields.
    """
    node_dict_ = {}
    for file_name in raw_file_names:
        file_name_path = data_dir + "/" + file_name

        list_of_dicts = []
        with open(file_name_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.rstrip("\n")
                # Skip blank lines so that neither a missing nor an extra
                # trailing newline loses or breaks a triple.
                if not line.strip():
                    continue
                src, rel, dst = line.split("\t")

                # setdefault assigns the next free id on first sight only.
                source = node_dict_.setdefault(src, len(node_dict_))
                target = node_dict_.setdefault(dst, len(node_dict_))
                if rel not in rel_dict:
                    rel_dict[rel] = len(rel_dict)
                    rel_id_to_text_dict[rel_dict[rel]] = rel
                edge_type = rel_dict[rel]

                rel_type_dict[edge_type].append(
                    {
                        "source": source,
                        "target": target,
                    }
                )
                list_of_dicts.append(
                    {
                        "source": source,
                        "source_text": src,
                        "target": target,
                        "target_text": dst,
                        "rel_id": edge_type,
                    }
                )

        rel_type = rel_types[file_name]
        print(f"Writing {len(list_of_dicts)} relationships of type {rel_type}")
        # The relationship type cannot be a query parameter, so it is
        # interpolated; it only ever comes from the fixed rel_types mapping.
        gds.run_cypher(
            f"""
            UNWIND $ll as l
            MERGE (n:Entity {{id:l.source, text:l.source_text}})
            MERGE (m:Entity {{id:l.target, text:l.target_text}})
            MERGE (n)-[:{rel_type} {{rel_id:l.rel_id}}]->(m)
            """,
            params={"ll": list_of_dicts},
        )

    print("Writing relationships as different relationship types")
    for rel_id, rels in tqdm(rel_type_dict.items()):
        REL_TYPE = f"REL_{rel_id}"
        gds.run_cypher(
            f"""
            UNWIND $ll AS l MATCH (n:Entity {{id:l.source}}), (m:Entity {{id:l.target}})
            MERGE (n)-[:{REL_TYPE} {{rel_id:$rel_id, text:$text}}]->(m)
            """,
            params={"ll": rels, "rel_id": rel_id, "text": rel_id_to_text_dict[rel_id]},
        )


# Parse the split files and write the nodes and relationships to the database.
process()

将图形中的所有数据进行投影,以获取 id 与数据库中的内部 nodeId 字段之间的映射。

# Project every Entity node together with its dataset "id" property.
node_projection = {"Entity": {"properties": "id"}}
# One single-entry projection map per split relationship type, each keeping
# the natural direction and the "rel_id" property.
relationship_projection = [
    {split_type: {"orientation": "NATURAL", "properties": "rel_id"}}
    for split_type in ("TRAIN", "TEST", "VALID")
]

ttv_G, result = gds.graph.project("fb15k-graph-ttv", node_projection, relationship_projection)

node_properties = gds.graph.nodeProperties.stream(ttv_G, ["id"], separate_property_columns=True)

# Two-way lookup between the streamed nodeId column and the dataset id column.
_streamed_node_ids = node_properties["nodeId"]
_dataset_ids = node_properties["id"]
nodeId_to_id = dict(zip(_streamed_node_ids, _dataset_ids))
id_to_nodeId = dict(zip(_dataset_ids, _streamed_node_ids))

4. 使用 PyG 训练 TransE 模型

从数据库中检索数据,将其转换为 torch 张量,并将其格式化为适合使用 PyG 进行训练的 Data 结构。

def create_data_from_graph(relationship_type):
    """Build a PyG ``Data`` object from one relationship type of ``ttv_G``.

    Streams the ``rel_id`` property of ``relationship_type`` from the
    projected graph, translates both endpoint columns through
    ``nodeId_to_id`` and packs them into long tensors. The object is displayed
    and returned.

    Args:
        relationship_type: Relationship type to stream from the projection.

    Returns:
        A ``Data`` object with ``edge_index``, ``edge_type`` and ``num_nodes``
        (the size of ``nodeId_to_id``).
    """
    streamed = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type)

    # Row 0 holds the translated source ids, row 1 the translated target ids.
    endpoints = [
        streamed[column].map(lambda node_id: nodeId_to_id[node_id])
        for column in ("sourceNodeId", "targetNodeId")
    ]
    relation_ids = streamed.propertyValue.astype(int)

    graph_data = Data(
        edge_index=torch.tensor(endpoints, dtype=torch.long),
        edge_type=torch.tensor(relation_ids, dtype=torch.long),
    )
    graph_data.num_nodes = len(nodeId_to_id)
    display(graph_data)
    return graph_data


# Build one Data object per split, in the order TRAIN, TEST, VALID.
train_tensor_data, test_tensor_data, val_tensor_data = [
    create_data_from_graph(split_type) for split_type in ("TRAIN", "TEST", "VALID")
]

删除投影的图形以节省内存。

# The split data now lives in the tensors above, so the projection can go.
gds.graph.drop(ttv_G)

TransE 模型的训练过程遵循相应的 PyG 示例

def train_model_with_pyg(epoch_count=5, eval_every=75):
    """Train a PyG ``TransE`` model on the training split and return it.

    The model is trained on ``train_tensor_data``, periodically evaluated on
    ``val_tensor_data`` and finally evaluated on ``test_tensor_data``.

    Args:
        epoch_count: Number of training epochs to run. Consider increasing it
            for better results.
        eval_every: Evaluate on the validation split every this many epochs;
            a falsy value disables the periodic validation.

    Returns:
        The trained model, placed on "cuda" when available, otherwise "cpu".

    Side effects:
        Prints progress and metrics and saves the whole model to
        ``./model_{epoch_count}.pt``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = TransE(
        num_nodes=train_tensor_data.num_nodes,
        num_relations=train_tensor_data.num_edge_types,
        hidden_channels=50,
    ).to(device)

    def triples_on_device(data):
        # The model lives on `device`, so its index tensors must be there too;
        # otherwise training fails with a device mismatch when CUDA is used.
        return (
            data.edge_index[0].to(device),
            data.edge_type.to(device),
            data.edge_index[1].to(device),
        )

    train_head, train_rel, train_tail = triples_on_device(train_tensor_data)
    loader = model.loader(
        head_index=train_head,
        rel_type=train_rel,
        tail_index=train_tail,
        batch_size=1000,
        shuffle=True,
    )

    optimizer = optim.Adam(model.parameters(), lr=0.01)

    def train():
        model.train()
        total_loss = total_examples = 0
        for head_index, rel_type, tail_index in loader:
            optimizer.zero_grad()
            loss = model.loss(head_index, rel_type, tail_index)
            loss.backward()
            optimizer.step()
            # Weight the batch loss by the batch size to get a per-example mean.
            total_loss += float(loss) * head_index.numel()
            total_examples += head_index.numel()
        return total_loss / total_examples

    @torch.no_grad()
    def test(data):
        model.eval()
        head_index, rel_type, tail_index = triples_on_device(data)
        return model.test(
            head_index=head_index,
            rel_type=rel_type,
            tail_index=tail_index,
            batch_size=1000,
            k=10,
        )

    # range is end-exclusive, hence the +1 to really run epoch_count epochs.
    for epoch in range(1, epoch_count + 1):
        loss = train()
        print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")
        if eval_every and epoch % eval_every == 0:
            # test() yields three metrics, as in the final evaluation below.
            rank, _, hits = test(val_tensor_data)
            print(f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, Val Hits@10: {hits:.4f}")

    torch.save(model, f"./model_{epoch_count}.pt")

    mean_rank, mrr, hits_at_k = test(test_tensor_data)
    print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}")

    return model
model = train_model_with_pyg()
# A previously trained model can be loaded instead of training again. The file
# name carries the epoch_count used for training ("./model_5.pt" for the value
# of 5 set in train_model_with_pyg).
# NOTE(review): recent PyTorch releases may need weights_only=False to load a
# fully pickled model -- confirm against the installed version.
# model = torch.load("./model_5.pt")

从训练好的模型中提取节点嵌入,并将它们放入数据库中。

# Write the embeddings in batches: a single UNWIND query per batch replaces
# one database round trip per node.
_node_embeddings = model.node_emb.weight.detach().cpu().tolist()
_EMBEDDING_BATCH_SIZE = 1000
_node_total = len(nodeId_to_id)
for _batch_start in tqdm(range(0, _node_total, _EMBEDDING_BATCH_SIZE)):
    _batch_stop = min(_batch_start + _EMBEDDING_BATCH_SIZE, _node_total)
    # Row i of the embedding matrix belongs to the Entity whose id is i.
    _rows = [{"id": i, "emb": _node_embeddings[i]} for i in range(_batch_start, _batch_stop)]
    gds.run_cypher(
        "UNWIND $rows AS row MATCH (n:Entity {id: row.id}) SET n.emb = row.emb",
        params={"rows": _rows},
    )

5. 使用 GDS 知识图谱嵌入功能进行预测

选择要进行预测的关系类型。

# Relation (by its dataset text) for which links are predicted.
relationship_to_predict = "/film/film/genre"
# Integer id given to that relation while the data was loaded.
rel_id_to_predict = rel_dict[relationship_to_predict]
# Matching per-relation relationship type in the database.
rel_label_to_predict = "REL_{}".format(rel_id_to_predict)

使用所有节点和所选类型的现有关系将图形进行投影。

# All Entity nodes with their dataset id and stored embedding, plus the
# existing relationships of the selected type.
_entity_projection = {"Entity": {"properties": ["id", "emb"]}}
G_test, result = gds.graph.project("graph_to_predict_", _entity_projection, rel_label_to_predict)


def print_graph_info(G):
    """Print the node/relationship counts, labels and types of graph ``G``.

    Each of the four output lines is prefixed with the graph's name.
    """
    graph_name = G.name()
    facts = (
        ("node count", G.node_count),
        ("node labels", G.node_labels),
        ("relationship types", G.relationship_types),
        ("relationship count", G.relationship_count),
    )
    for caption, read_value in facts:
        print(f"Graph '{graph_name}' {caption}: {read_value()}")


# Summarize the projection that the predictions will run on.
print_graph_info(G_test)

从 PyG 模型中检索所选关系的嵌入。然后,使用图形、节点嵌入属性和要预测的关系的嵌入来创建一个 GDS TransE 模型。

# The relation vector comes from the model's relation embedding table.
# Indexing node_emb with a relation id would return an unrelated entity's
# vector instead.
target_emb = model.rel_emb.weight[rel_id_to_predict].tolist()
transe_model = gds.model.transe.create(G_test, "emb", {rel_label_to_predict: target_emb})
# Entities (by dataset text) used as the sources of the predictions.
source_node_list = ["/m/07l450", "/m/0ds2l81", "/m/0jvt9"]
# Resolve the texts to database node ids.
source_ids_df = gds.run_cypher(
    "UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId",
    params={"node_text_list": source_node_list},
)

现在,我们可以使用该模型进行预测。

# Settings for streaming the three best-scoring targets per source node.
_stream_settings = {
    "source_node_filter": source_ids_df.nodeId,
    "target_node_filter": "Entity",
    "relationship_type": rel_label_to_predict,
    "top_k": 3,
    "concurrency": 4,
}
result = transe_model.predict_stream(**_stream_settings)
print(result)

使用节点标识符及其文本值来扩展预测结果。

# Every node id that occurs as a source or a target of a prediction.
_endpoint_ids = pd.concat([result.sourceNodeId, result.targetNodeId])
ids_in_result = pd.unique(_endpoint_ids)

# Fetch the dataset text and id for each of those nodes.
ids_to_text = gds.run_cypher(
    "UNWIND $ids AS id MATCH (n:Entity) WHERE id(n)=id RETURN id(n) AS nodeId, n.text AS tag, n.id AS id",
    params={"ids": ids_in_result},
)

nodeId_to_text_res = dict(zip(ids_to_text["nodeId"], ids_to_text["tag"]))
nodeId_to_id_res = dict(zip(ids_to_text["nodeId"], ids_to_text["id"]))

# (insert position, new column, id column to translate, lookup table);
# the positions account for the columns inserted before them.
_extra_columns = (
    (1, "sourceTag", "sourceNodeId", nodeId_to_text_res),
    (2, "sourceId", "sourceNodeId", nodeId_to_id_res),
    (4, "targetTag", "targetNodeId", nodeId_to_text_res),
    (5, "targetId", "targetNodeId", nodeId_to_id_res),
)
for _position, _new_column, _id_column, _lookup in _extra_columns:
    # Bind the lookup as a default so each lambda keeps its own table.
    _translated = result[_id_column].map(lambda node_id, table=_lookup: table[node_id])
    result.insert(_position, _new_column, _translated)

print(result)

6. 使用写入模式

写入模式允许您将结果作为新的关系类型直接写入数据库。这种方法有助于避免从 nodeId 到 id 的映射。

# Predictions are written as relationships of this new type.
write_relationship_type = f"PREDICTED_{rel_label_to_predict}"
# Same filters and limits as the streamed prediction; each score is stored in
# the "transe_score" property of the written relationship.
_write_settings = {
    "source_node_filter": source_ids_df.nodeId,
    "target_node_filter": "Entity",
    "relationship_type": rel_label_to_predict,
    "write_relationship_type": write_relationship_type,
    "write_property": "transe_score",
    "top_k": 3,
    "concurrency": 4,
}
result_write = transe_model.predict_write(**_write_settings)

从数据库中提取结果。

# Read the written predictions back, with the dataset ids and texts of both
# endpoints and the stored score.
_written_predictions_query = (
    f"MATCH (n)-[r:{write_relationship_type}]->(m) "
    "RETURN n.id AS sourceId, n.text AS sourceTag, m.id AS targetId, m.text AS targetTag, r.transe_score AS score"
)
gds.run_cypher(_written_predictions_query)
# Remove the prediction projection now that it is no longer needed.
gds.graph.drop(G_test)