知识图谱嵌入:在 PyG 中训练,在 GDS 中预测
此 Jupyter 笔记本托管在 Neo4j 图数据科学客户端 Github 存储库 中。
此笔记本演示了如何使用 graphdatascience 和 PyTorch Geometric (PyG) Python 库来:
- 将 FB15k-237 数据集直接导入 GDS
- 使用 PyG 训练 TransE 模型
- 使用 GDS 知识图谱嵌入功能对数据库中的数据进行预测
1. 先决条件
要运行此笔记本,您需要一台安装了 GDS 2.5 或更高版本的 Neo4j 服务器。
此外,还需要以下 Python 库
- graphdatascience:查看文档以获取安装说明
- pytorch-geometric(版本 >= 2.5.0):查看 PyG 文档以获取安装说明
2. 设置
我们将从导入依赖项并建立与数据库的 GDS 客户端连接开始。
%pip install graphdatascience torch torch_geometric
import os
from graphdatascience import GraphDataScience
import torch
import torch.optim as optim
from torch_geometric.data import Data, download_url
from torch_geometric.nn import TransE
import collections
from tqdm import tqdm
import pandas as pd
# Connection settings are read from the environment, falling back to a local
# Bolt endpoint without authentication and the "neo4j" database.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_AUTH = None
NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")
# Credentials are only passed on when both user and password are set.
if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"):
    NEO4J_AUTH = (
        os.environ.get("NEO4J_USER"),
        os.environ.get("NEO4J_PASSWORD"),
    )
gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB)

# This notebook requires GDS 2.5.0 or later.
# The version is compared numerically: as plain strings, "2.10.0" would sort
# before "2.5.0" and a newer server would be rejected.
gds_version = gds.version()
assert tuple(int(part) for part in gds_version.split(".")[:2]) >= (2, 5), (
    f"GDS 2.5.0 or later is required, found {gds_version}"
)
3. 在数据库中下载和存储 FB15k-237 数据集
下载 FB15k-237 数据集,并提取所需文件:train.txt、valid.txt 和 test.txt。
import os
import zipfile
# Source archive of the FB15k-237 dataset and the local download folder.
url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip"
raw_dir = "./data_from_zip"
download_url(f"{url}", raw_dir)

# Only the three split files are extracted from the archive.
raw_file_names = ["train.txt", "valid.txt", "test.txt"]
archive_path = raw_dir + "/" + os.path.basename(url)
with zipfile.ZipFile(archive_path, "r") as archive:
    for split_file in raw_file_names:
        archive.extract(f"Release/{split_file}", path=raw_dir)

# The split files are extracted below the archive's "Release" folder.
data_dir = raw_dir + "/" + "Release"
设置唯一 ID 条目的约束以加快数据上传速度。
# Uniqueness constraint on Entity.id, created before the bulk import.
# IF NOT EXISTS keeps the cell re-runnable: without it a second run fails
# because the constraint is already present.
gds.run_cypher("CREATE CONSTRAINT entity_id IF NOT EXISTS FOR (e:Entity) REQUIRE e.id IS UNIQUE")
创建实体节点:使用标签 Entity 创建节点。每个节点应具有属性 id 和 text。
- 语法:(:Entity {id: int, text: str})
为 PyG 训练创建关系:根据训练阶段,创建类型为 TRAIN、TEST 或 VALID 的关系。每个关系都应该有一个 rel_id 属性。
- 示例语法:[:TRAIN {rel_id: int}]
为 GDS 预测创建关系:对于预测阶段,创建特定类型的关系,表示为 REL_i。每个关系都应该具有 rel_id 和 text 属性。
- 示例语法:[:REL_7 {rel_id: int, text: str}]
# Split file name -> relationship type written for that split.
_split_files = ("train.txt", "valid.txt", "test.txt")
_split_rel_types = ("TRAIN", "VALID", "TEST")
rel_types = dict(zip(_split_files, _split_rel_types))

# Lookup tables filled while the split files are processed:
# relation text -> integer relation id
rel_dict = dict()
# integer relation id -> relation text
rel_id_to_text_dict = dict()
# integer relation id -> list of {"source": ..., "target": ...} entries
rel_type_dict = collections.defaultdict(list)
def process():
    """Load the FB15k-237 split files into the database.

    For every file in ``raw_file_names`` the tab-separated
    ``source, relation, target`` triples are read from ``data_dir``. Entity and
    relation names are mapped to consecutive integer ids in order of first
    appearance, and the triples are written as ``TRAIN`` / ``VALID`` / ``TEST``
    relationships between ``Entity`` nodes. Afterwards every triple is written
    a second time, grouped by relation, as a ``REL_<rel_id>`` relationship.

    Side effects: fills the module-level ``rel_dict``, ``rel_id_to_text_dict``
    and ``rel_type_dict`` and runs write queries through ``gds``.
    """
    node_dict_ = {}
    for file_name in raw_file_names:
        file_name_path = data_dir + "/" + file_name
        with open(file_name_path, "r") as f:
            # Skip blank lines explicitly. Dropping the last split element
            # unconditionally loses the final triple whenever the file does
            # not end with a newline.
            data = [line.rstrip("\n").split("\t") for line in f if line.strip()]

        list_of_dicts = []
        for src, rel, dst in data:
            # Ids are handed out in order of first appearance, so they form a
            # dense 0..n-1 range.
            if src not in node_dict_:
                node_dict_[src] = len(node_dict_)
            if dst not in node_dict_:
                node_dict_[dst] = len(node_dict_)
            if rel not in rel_dict:
                rel_dict[rel] = len(rel_dict)
                rel_id_to_text_dict[rel_dict[rel]] = rel

            source = node_dict_[src]
            target = node_dict_[dst]
            edge_type = rel_dict[rel]

            # Grouped by relation id for the REL_<id> pass below.
            rel_type_dict[edge_type].append(
                {
                    "source": source,
                    "target": target,
                }
            )
            list_of_dicts.append(
                {
                    "source": source,
                    "source_text": src,
                    "target": target,
                    "target_text": dst,
                    "rel_id": edge_type,
                }
            )

        rel_type = rel_types[file_name]
        print(f"Writing {len(list_of_dicts)} entities of {rel_type}")
        # The relationship type is interpolated into the query text, while
        # the rows themselves are passed as a query parameter.
        gds.run_cypher(
            f"""
            UNWIND $ll as l
            MERGE (n:Entity {{id:l.source, text:l.source_text}})
            MERGE (m:Entity {{id:l.target, text:l.target_text}})
            MERGE (n)-[:{rel_type} {{rel_id:l.rel_id}}]->(m)
            """,
            params={"ll": list_of_dicts},
        )

    print("Writing relationships as different relationship types")
    for rel_id, rels in tqdm(rel_type_dict.items()):
        rel_type_name = f"REL_{rel_id}"
        gds.run_cypher(
            f"""
            UNWIND $ll AS l MATCH (n:Entity {{id:l.source}}), (m:Entity {{id:l.target}})
            MERGE (n)-[:{rel_type_name} {{rel_id:$rel_id, text:$text}}]->(m)
            """,
            params={"ll": rels, "rel_id": rel_id, "text": rel_id_to_text_dict[rel_id]},
        )


process()
将图形中的所有数据进行投影,以获取 id
与数据库中的内部 nodeId
字段之间的映射。
# Project the Entity nodes (with their "id" property) and the three split
# relationship types (with their "rel_id" property) into one GDS graph.
node_projection = {"Entity": {"properties": "id"}}
relationship_projection = [
    {split_type: {"orientation": "NATURAL", "properties": "rel_id"}}
    for split_type in ("TRAIN", "TEST", "VALID")
]
ttv_G, result = gds.graph.project(
    "fb15k-graph-ttv",
    node_projection,
    relationship_projection,
)

# Stream the "id" property and build lookups in both directions between the
# streamed nodeId column and the Entity id.
node_properties = gds.graph.nodeProperties.stream(
    ttv_G,
    ["id"],
    separate_property_columns=True,
)
nodeId_to_id = dict(zip(node_properties["nodeId"], node_properties["id"]))
id_to_nodeId = dict(zip(node_properties["id"], node_properties["nodeId"]))
4. 使用 PyG 训练 TransE 模型
从数据库中检索数据,将其转换为 torch 张量,并将其格式化为适合使用 PyG 进行训练的 Data
结构。
def create_data_from_graph(relationship_type):
    """Build a PyG ``Data`` object from one relationship type of ``ttv_G``.

    Streams the ``rel_id`` property of the ``relationship_type`` relationships
    from the projected graph, translates the streamed source and target node
    ids through ``nodeId_to_id``, and stores the result as ``edge_index`` and
    ``edge_type`` long tensors. ``num_nodes`` is set to the size of
    ``nodeId_to_id``. The object is displayed and then returned.
    """
    streamed = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type)

    def to_entity_id(node_id):
        return nodeId_to_id[node_id]

    heads = streamed.sourceNodeId.map(to_entity_id)
    tails = streamed.targetNodeId.map(to_entity_id)
    edge_index = torch.tensor([heads, tails], dtype=torch.long)
    edge_type = torch.tensor(streamed.propertyValue.astype(int), dtype=torch.long)

    data = Data(edge_index=edge_index, edge_type=edge_type)
    data.num_nodes = len(nodeId_to_id)
    display(data)
    return data
# One Data object per split, built in the order TRAIN, TEST, VALID.
train_tensor_data, test_tensor_data, val_tensor_data = [
    create_data_from_graph(split_type) for split_type in ("TRAIN", "TEST", "VALID")
]
删除投影的图形以节省内存。
# Drop the "fb15k-graph-ttv" projection; ttv_G is not referenced after this point.
gds.graph.drop(ttv_G)
TransE 模型的训练过程遵循相应的 PyG 示例。
def train_model_with_pyg():
    """Train a PyG ``TransE`` model on the training split and evaluate it.

    The model is sized from ``train_tensor_data`` (node and relation counts),
    uses 50 hidden channels and is optimised with Adam (lr=0.01) on batches of
    1000 triples. After training, the model object is saved to
    ``./model_{epoch_count}.pt`` and mean rank, MRR and Hits@10 on
    ``test_tensor_data`` are printed.

    Returns:
        The trained ``TransE`` model.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = TransE(
        num_nodes=train_tensor_data.num_nodes,
        num_relations=train_tensor_data.num_edge_types,
        hidden_channels=50,
    ).to(device)

    # The index tensors have to live on the same device as the model,
    # otherwise training breaks as soon as CUDA is available.
    loader = model.loader(
        head_index=train_tensor_data.edge_index[0].to(device),
        rel_type=train_tensor_data.edge_type.to(device),
        tail_index=train_tensor_data.edge_index[1].to(device),
        batch_size=1000,
        shuffle=True,
    )
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    def train():
        """Run one training epoch and return the example-weighted mean loss."""
        model.train()
        total_loss = total_examples = 0
        for head_index, rel_type, tail_index in loader:
            optimizer.zero_grad()
            loss = model.loss(head_index, rel_type, tail_index)
            loss.backward()
            optimizer.step()
            total_loss += float(loss) * head_index.numel()
            total_examples += head_index.numel()
        return total_loss / total_examples

    @torch.no_grad()
    def test(data):
        """Evaluate on ``data``; returns (mean rank, MRR, Hits@10)."""
        model.eval()
        return model.test(
            head_index=data.edge_index[0].to(device),
            rel_type=data.edge_type.to(device),
            tail_index=data.edge_index[1].to(device),
            batch_size=1000,
            k=10,
        )

    # Consider increasing the number of epochs
    epoch_count = 5
    # Validation is expensive, so it only runs every `eval_every` epochs.
    eval_every = 75
    # Inclusive upper bound so that exactly `epoch_count` epochs are trained.
    for epoch in range(1, epoch_count + 1):
        loss = train()
        print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")
        if epoch % eval_every == 0:
            # test() yields three values (see the final evaluation below);
            # unpacking only two raised a ValueError here.
            val_rank, _val_mrr, val_hits = test(val_tensor_data)
            print(f"Epoch: {epoch:03d}, Val Mean Rank: {val_rank:.2f}, " f"Val Hits@10: {val_hits:.4f}")

    # Saves the whole model object, not just its state dict.
    torch.save(model, f"./model_{epoch_count}.pt")

    mean_rank, mrr, hits_at_k = test(test_tensor_data)
    print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}")
    return model
# Train a new model; train_model_with_pyg() returns the trained TransE model.
model = train_model_with_pyg()
# A model saved by an earlier run can be loaded instead of retraining.
# Adjust the file name to the one that run actually wrote, for example:
# model = torch.load("./model_501.pt")
从训练好的模型中提取节点嵌入,并将它们放入数据库中。
# Write the node embeddings back in batched UNWIND queries instead of one
# query round trip per node. Row i of the embedding matrix belongs to the
# Entity whose id is i, because the ids were assigned as consecutive integers
# during the import.
node_embeddings = model.node_emb.weight.detach().cpu().tolist()
embedding_rows = [
    {"id": entity_id, "emb": node_embeddings[entity_id]}
    for entity_id in range(len(nodeId_to_id))
]
write_batch_size = 1000
for start in tqdm(range(0, len(embedding_rows), write_batch_size)):
    gds.run_cypher(
        "UNWIND $rows AS row MATCH (n:Entity {id: row.id}) SET n.emb = row.emb",
        params={"rows": embedding_rows[start : start + write_batch_size]},
    )
5. 使用 GDS 知识图谱边缘嵌入功能进行预测
选择要进行预测的关系类型。
# Relation to predict, given by its text as it appears in the dataset files.
relationship_to_predict = "/film/film/genre"
# rel_dict maps the relation text to the integer id assigned during import.
rel_id_to_predict = rel_dict[relationship_to_predict]
# Matches the REL_<rel_id> relationship type written during import.
rel_label_to_predict = f"REL_{rel_id_to_predict}"
使用所有节点和所选类型的现有关系将图形进行投影。
# Project all Entity nodes with their "id" and "emb" properties, together with
# only the relationships of the selected REL_<rel_id> type.
G_test, result = gds.graph.project(
"graph_to_predict_",
{"Entity": {"properties": ["id", "emb"]}},
rel_label_to_predict,
)
def print_graph_info(G):
    """Print node count, node labels, relationship types and relationship count of ``G``.

    Each line is prefixed with the graph's name. Returns ``None``.
    """

    def summary_lines():
        # Generated lazily so each line is printed before the next value is
        # requested from the graph object.
        yield f"Graph '{G.name()}' node count: {G.node_count()}"
        yield f"Graph '{G.name()}' node labels: {G.node_labels()}"
        yield f"Graph '{G.name()}' relationship types: {G.relationship_types()}"
        yield f"Graph '{G.name()}' relationship count: {G.relationship_count()}"

    for line in summary_lines():
        print(line)
# Show the size and schema of the projection used for prediction.
print_graph_info(G_test)
从 PyG 模型中检索所选关系的嵌入。然后,使用图形、节点嵌入属性和要预测的关系的嵌入来创建一个 GDS TransE 模型。
# The vector registered for the relationship type must be the embedding of the
# relation itself, so it is read from the relation embedding table (rel_emb).
# Indexing node_emb with a relation id returns the embedding of an unrelated
# entity.
target_emb = model.rel_emb.weight[rel_id_to_predict].tolist()
transe_model = gds.model.transe.create(G_test, "emb", {rel_label_to_predict: target_emb})

# Resolve the text identifiers of the source entities to database node ids.
source_node_list = ["/m/07l450", "/m/0ds2l81", "/m/0jvt9"]
source_ids_df = gds.run_cypher(
    "UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId",
    params={"node_text_list": source_node_list},
)
现在,我们可以使用该模型进行预测。
# Stream predictions for the selected relationship type: sources are limited
# to the node ids resolved above, targets are filtered by the "Entity" label,
# and top_k is set to 3.
result = transe_model.predict_stream(
source_node_filter=source_ids_df.nodeId,
target_node_filter="Entity",
relationship_type=rel_label_to_predict,
top_k=3,
concurrency=4,
)
print(result)
使用节点标识符及其文本值来扩展预测结果。
# Every node id that occurs in the predictions, as source or as target.
ids_in_result = pd.unique(pd.concat([result["sourceNodeId"], result["targetNodeId"]]))
ids_to_text = gds.run_cypher(
    "UNWIND $ids AS id MATCH (n:Entity) WHERE id(n)=id RETURN id(n) AS nodeId, n.text AS tag, n.id AS id",
    params={"ids": ids_in_result},
)
nodeId_to_text_res = dict(zip(ids_to_text["nodeId"], ids_to_text["tag"]))
nodeId_to_id_res = dict(zip(ids_to_text["nodeId"], ids_to_text["id"]))

# (insert position, new column, node id column, lookup table), applied in this
# order. The positions already account for the columns inserted before them.
_extra_columns = (
    (1, "sourceTag", "sourceNodeId", nodeId_to_text_res),
    (2, "sourceId", "sourceNodeId", nodeId_to_id_res),
    (4, "targetTag", "targetNodeId", nodeId_to_text_res),
    (5, "targetId", "targetNodeId", nodeId_to_id_res),
)
for _position, _column, _node_id_column, _lookup in _extra_columns:
    result.insert(_position, _column, result[_node_id_column].map(_lookup.__getitem__))
print(result)
6. 使用写入模式
写入模式允许您将结果直接写入数据库作为新的关系类型。这种方法有助于避免从 nodeId
到 id
的映射。
# Predictions are written under a separate type, PREDICTED_REL_<rel_id>.
write_relationship_type = "PREDICTED_" + rel_label_to_predict
# Same filters and top_k as the streaming call; the result is written as
# relationships of the new type, with the property name "transe_score".
result_write = transe_model.predict_write(
source_node_filter=source_ids_df.nodeId,
target_node_filter="Entity",
relationship_type=rel_label_to_predict,
write_relationship_type=write_relationship_type,
write_property="transe_score",
top_k=3,
concurrency=4,
)
从数据库中提取结果。
# Read the written predictions back, with id and text for both endpoints.
# The relationship type is concatenated into the query text.
# NOTE(review): the returned value is not assigned or printed here; it is only
# shown if this call is the last expression of a notebook cell.
gds.run_cypher(
"MATCH (n)-[r:"
+ write_relationship_type
+ "]->(m) RETURN n.id AS sourceId, n.text AS sourceTag, m.id AS targetId, m.text AS targetTag, r.transe_score AS score"
)
# Drop the projection that was used for prediction.
gds.graph.drop(G_test)