关系嵌入模型
一个并不少见的场景是,用户已经在 Graph Data Science (GDS) 库之外训练好了知识图谱嵌入 (KGE) 模型,并将模型训练的输出存储在 Neo4j 数据库中。针对此类情况,GDS 支持使用这些 KGE 模型输出和 KGE 评分函数来推断 GDS 图投影中的新关系。目前支持的评分函数包括 TransE 和 DistMult。
下面我们将介绍如何使用这些功能。首先了解相关方法及其签名,然后通过一个小型的玩具图示例进行端到端的演示。
在下面的示例中,我们假设已经实例化了一个名为 gds 的 GraphDataScience 对象。请在 入门指南 中阅读更多相关信息。
1. 创建关系嵌入模型
在 GDS 中使用预训练的 KGE 模型预测新关系的工作流程的第一步,是创建一个关系嵌入模型。
有两种方法可以实现这一点,每种方法对应一种支持的 KGE 评分函数:
-
gds.model.transe.create用于创建使用 TransE 评分函数的模型,以及 -
gds.model.distmult.create用于创建使用 DistMult 评分函数的模型。
这两个方法都会返回一个 SimpleRelEmbeddingModel,我们很快就会介绍其用法。它们也采用相同的参数:
| 名称 | 类型 | |
|---|---|---|
|
|
表示模型所基于的图的对象 |
|
|
存储 KGE 模型嵌入的节点属性名称 |
|
|
关系类型名称到 KGE 模型关系类型嵌入的映射 |
2. 使用关系嵌入模型进行预测
SimpleRelEmbeddingModel 代表一个基于 KGE 模型的关系嵌入模型。它有三种预测新关系的方法。推断新嵌入的计算过程是相同的,但之后处理新关系的方式有所不同。
该类具有三种方法:
-
predict_stream用于流式返回预测出的关系, -
predict_mutate用于将关系添加到投影图中, -
predict_write用于将关系写回 Neo4j 数据库。
由于这些方法中预测部分的计算是相同的,因此它们共享一组参数:
| 名称 | 类型 | |
|---|---|---|
|
|
需要考虑的源节点的定义。可以是节点标签、节点 ID 或节点 ID 列表 |
|
|
需要考虑的源节点的定义。可以是节点标签、节点 ID 或节点 ID 列表 |
|
|
其嵌入将用于计算的关系类型名称 |
|
|
为每个源节点生成多少个关系。对于每个源节点,得分最高的 |
|
|
作为可选关键字参数的通用 GDS 算法配置 |
特别是,支持作为此算法关键字参数的通用算法配置参数包括 concurrency、jobId 和 logProgress。您可以在 GDS 手册中的此处阅读更多相关信息。
现在让我们概述这些预测方法之间的差异。
2.1. 流式传输预测关系
predict_stream 方法返回一个包含三列的 pandas.DataFrame:sourceNodeId、targetNodeId 和 score。它们分别指代源节点 ID、目标节点 ID,以及在节点对和关系类型上运行 KGE 模型评分函数得到的得分。
除了上述概述的参数外,该方法没有额外的参数。
2.2. 使用预测关系更改图投影
predict_mutate 方法通过 mutate_relationship_type 参数指定的新类型,将预测的关系添加到图投影中。每个此类关系都将拥有一个通过 mutateProperty 参数指定的属性,代表在节点对和关系类型上运行 KGE 模型评分函数的输出。该方法返回一个包含计算元数据的 pandas.Series。
除了上述概述的共享参数外,该方法在 top_k 参数之后,按顺序还有两个位置参数:
| 名称 | 类型 | |
|---|---|---|
|
|
预测关系的新关系类型名称 |
|
|
新关系上将存储模型预测得分的属性名称 |
| 名称 | 类型 | |
|---|---|---|
|
|
创建的关系数量 |
|
|
向投影图添加属性的毫秒数 |
|
|
计算百分位数的毫秒数 |
|
|
预处理数据的毫秒数 |
|
|
运行预测算法的毫秒数 |
|
|
运行算法时使用的配置 |
2.3. 将预测关系写回数据库
predict_write 方法通过 write_relationship_type 参数指定的新类型,将预测的关系写回 Neo4j 数据库。每个此类关系都将拥有一个通过 writeProperty 参数指定的属性,代表在节点对和关系类型上运行 KGE 模型评分函数的输出。
除了上述概述的共享参数外,该方法在 top_k 参数之后,按顺序还有两个位置参数:
| 名称 | 类型 | |
|---|---|---|
|
|
预测关系的新关系类型名称 |
|
|
新关系上将存储模型预测得分的属性名称 |
该方法返回一个包含计算元数据的 pandas.Series。
| 名称 | 类型 | |
|---|---|---|
|
|
创建的关系数量 |
|
|
将结果数据写回 Neo4j 数据库的毫秒数 |
|
|
预处理数据的毫秒数 |
|
|
运行预测算法的毫秒数 |
|
|
运行算法时使用的配置 |
3. 检查关系嵌入模型
SimpleRelEmbeddingModel 类中有一些方法可以让我们检查模型。它们不需要任何输入,只需返回有关模型的信息。如下所示:
| 名称 | 返回类型 | 描述 |
|---|---|---|
|
|
返回模型正在使用的评分函数名称 |
|
|
返回模型所基于的图名称 |
|
|
返回图中存储嵌入的节点属性名称 |
|
|
返回模型的关系类型嵌入 |
4. 示例
在本节中,我们将举例说明如何创建和使用基于使用 TransE 评分函数训练的 KGE 模型的关系嵌入模型。这其中一部分是拥有一个包含 KGE 模型嵌入的 Graph 投影。
因此,我们首先引入一个小型的道路网络图以及一些居民。
gds.run_cypher(
"""
CREATE
(a:City {name: "New York City", settled: 1624, emb: [0.52173235, 0.85803989, 0.31678055]}),
(b:City {name: "Philadelphia", settled: 1682, emb: [0.61455845, 0.79957553, 0.83513986]}),
(c:City:Capital {name: "Washington D.C.", settled: 1790, emb: [0.54354943, 0.64039515, 0.23094848]}),
(d:City {name: "Baltimore", settled: 1729, emb: [0.67689553, 0.28851121, 0.43250516]}),
(e:City {name: "Atlantic City", settled: 1854, emb: [0.79804478, 0.81980933, 0.9322812]}),
(f:City {name: "Boston", settled: 1822, emb: [0.15583946, 0.16060805, 0.52078528]}),
(g:Person {name: "Brian", emb: [0.4142066 , 0.18411476, 0.68245374]}),
(h:Person {name: "Olga", emb: [0.61230904, 0.7735076 , 0.09668418]}),
(i:Person {name: "Jacob", emb: [0.87470625, 0.63589938, 0.33536311]}),
(a)-[:ROAD {cost: 50}]->(b),
(a)-[:ROAD {cost: 50}]->(c),
(a)-[:ROAD {cost: 100}]->(d),
(b)-[:ROAD {cost: 40}]->(d),
(c)-[:ROAD {cost: 40}]->(d),
(c)-[:ROAD {cost: 80}]->(e),
(d)-[:ROAD {cost: 30}]->(e),
(d)-[:ROAD {cost: 80}]->(f),
(e)-[:ROAD {cost: 40}]->(f),
(g)-[:LIVES_IN]->(a),
(h)-[:LIVES_IN]->(f),
(i)-[:LIVES_IN]->(e);
"""
)
G, project_result = gds.graph.project(
graph_name="road_graph",
node_spec={"City": {"properties": ["emb"]}, "Person": {"properties": ["emb"]}},
relationship_spec=["ROAD", "LIVES_IN"]
)
# Sanity check
assert G.relationship_count() == 12
此处的 "emb" 节点属性包含了我们将在计算中用于推断新关系的 TransE 节点嵌入。
4.1. 创建并检查我们的模型
使用我们的图 G 和预计算的关系类型嵌入,我们现在可以构建一个 TransE 关系嵌入模型。
transe_model = gds.model.transe.create(
G,
node_embedding_property="emb",
relationship_type_embeddings={
"ROAD": [0.88355126, 0.15116676, 0.24225456],
"LIVES_IN": [0.94185368, 0.60460752, 0.92028837]
}
)
# Sanity check
assert transe_model.scoring_function() == "transe"
模型创建完成后,我们就可以开始预测我们图中的新关系了。
4.2. 进行预测
让我们让模型预测我们感兴趣的三位居民未来可能移动到的位置,并用这些新关系更改以 G 表示的 GDS 投影。
result = transe_model.predict_mutate(
source_node_filter="Person",
target_node_filter="City",
relationship_type="LIVES_IN",
top_k=1,
mutate_relationship_type="MIGHT_MOVE",
mutate_property="likeliness_score"
)
# Let us make sure the number of new relationships makes sense
assert result["relationshipsWritten"] == 3
assert G.relationship_count() == 12 + 3
利用 TransE 嵌入和 GDS 的关系嵌入模型功能,我们能够推断出我们感兴趣的居民未来可能会移动到的位置。我们创建的新的 "MIGHT_MOVE" 关系现在已成为以 G 表示的 GDS 图投影的一部分。