模型目录中的模型对象
GDS 模型目录中的模型在 Python 客户端中表示为 Model 对象,类似于图对象的表示方式。Model 对象通常通过训练流水线 (pipeline) 或 GraphSAGE 模型来构建,在这种情况下,系统会返回一个指向已训练模型的 Model 对象引用。
创建完成后,Model 对象可以作为参数传递给 Python 客户端中的方法,例如模型目录操作。此外,Model 对象还提供了一些便捷方法,允许在不显式调用模型目录的情况下检查所代表的模型。
在下面的示例中,我们假设已经实例化了一个名为 gds 的 GraphDataScience 对象。请在 入门指南 中阅读更多相关信息。
1. 构建模型对象
构建模型对象的主要方式是训练模型。模型有两种类型:流水线模型和 GraphSAGE 模型。为了训练流水线模型,必须先创建并配置一个流水线。有关如何操作流水线的更多信息,请参阅机器学习流水线,其中包含使用流水线模型的示例。在本节中,我们将演示如何创建和使用 GraphSAGE 模型对象。
首先,我们引入一个小型道路网络图
gds.run_cypher(
"""
CREATE
(a:City {name: "New York City", settled: 1624}),
(b:City {name: "Philadelphia", settled: 1682}),
(c:City:Capital {name: "Washington D.C.", settled: 1790}),
(d:City {name: "Baltimore", settled: 1729}),
(e:City {name: "Atlantic City", settled: 1854}),
(f:City {name: "Boston", settled: 1822}),
(a)-[:ROAD {cost: 50}]->(b),
(a)-[:ROAD {cost: 50}]->(c),
(a)-[:ROAD {cost: 100}]->(d),
(b)-[:ROAD {cost: 40}]->(d),
(c)-[:ROAD {cost: 40}]->(d),
(c)-[:ROAD {cost: 80}]->(e),
(d)-[:ROAD {cost: 30}]->(e),
(d)-[:ROAD {cost: 80}]->(f),
(e)-[:ROAD {cost: 40}]->(f);
"""
)
G, project_result = gds.graph.project(
"road_graph",
{"City": {"properties": ["settled"]}},
{"ROAD": {"properties": ["cost"]}}
)
assert G.relationship_count() == 9
现在我们可以使用图 G 来训练 GraphSAGE 模型。
model, train_result = gds.beta.graphSage.train(G, modelName="city-representation", featureProperties=["settled"], randomSeed=42)
assert train_result["modelInfo"]["metrics"]["ranEpochs"] == 1
其中 model 是模型对象,res 是一个包含底层过程调用元数据的 pandas Series。
同样,我们也可以通过训练机器学习流水线来获得模型对象。
若要获取一个代表已训练且存在于模型目录中的模型对象,可以调用仅存在于客户端的 get 方法并传入名称
model = gds.model.get("city-representation")
assert model.name() == "city-representation"
|
|
2. 检查模型对象
所有模型对象上都有一些便捷方法,让我们可以提取有关所代表模型的信息。
| 名称 | 参数 | 返回类型 | 描述 |
|---|---|---|---|
|
|
|
模型在模型目录中显示的名称。 |
|
|
|
模型类型,例如 "graphSage"。 |
|
|
|
用于训练模型的配置。 |
|
|
|
模型训练所基于的图模式 (schema)。 |
|
|
|
如果模型已加载到内存模型目录中,则为 |
|
|
|
如果模型已存储在磁盘上,则为 |
|
|
|
模型创建的时间。 |
|
|
|
如果模型在用户间共享,则为 |
|
|
|
如果模型存在于 GDS 模型目录中,则为 |
|
|
|
例如,要获取上面创建的模型对象 model 的训练配置,我们可以执行以下操作
train_config = model.train_config()
assert train_config["concurrency"] == 4
3. 使用模型对象
此外,模型对象还可以用作 GDS 模型目录操作的输入。例如,假设我们有上面创建的模型对象 model,我们可以执行
# Store the model on disk (GDS Enterprise Edition)
_ = gds.model.store(model)
gds.model.drop(model) # same as model.drop()
# Load the model again for further use
gds.model.load(model.name())
3.1. GraphSAGE
如上面在构建模型对象中所述,使用 Python 客户端训练 GraphSAGE 模型类似于其 Cypher 对应过程。
训练完成后,除了上述方法外,GraphSAGE 模型对象还将具有以下方法。
| 名称 | 参数 | 返回类型 | 描述 |
|---|---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
返回训练时计算的指标值。 |
因此,对于我们上面训练的 GraphSAGE 模型 model,我们可以执行以下操作
# Make sure our training actually converged
metrics = model.metrics()
assert metrics["didConverge"]
# Predict on `G` and write embedding node properties back to the database
predict_result = model.predict_write(G, writeProperty="embedding")
assert predict_result["nodePropertiesWritten"] == G.node_count()