模型目录中的模型对象

GDS 模型目录中的模型在 Python 客户端中表示为 Model 对象,类似于图对象的表示方式。Model 对象通常通过训练流水线 (pipeline)GraphSAGE 模型来构建,在这种情况下,系统会返回一个指向已训练模型的 Model 对象引用。

创建完成后,Model 对象可以作为参数传递给 Python 客户端中的方法,例如模型目录操作。此外,Model 对象还提供了一些便捷方法,允许在不显式调用模型目录的情况下检查所代表的模型。

在下面的示例中,我们假设已经实例化了一个名为 gdsGraphDataScience 对象。请在 入门指南 中阅读更多相关信息。

1. 构建模型对象

构建模型对象的主要方式是训练模型。模型有两种类型:流水线模型和 GraphSAGE 模型。为了训练流水线模型,必须先创建并配置一个流水线。有关如何操作流水线的更多信息,请参阅机器学习流水线,其中包含使用流水线模型的示例。在本节中,我们将演示如何创建和使用 GraphSAGE 模型对象。

首先,我们引入一个小型道路网络图

gds.run_cypher(
  """
  CREATE
    (a:City {name: "New York City", settled: 1624}),
    (b:City {name: "Philadelphia", settled: 1682}),
    (c:City:Capital {name: "Washington D.C.", settled: 1790}),
    (d:City {name: "Baltimore", settled: 1729}),
    (e:City {name: "Atlantic City", settled: 1854}),
    (f:City {name: "Boston", settled: 1822}),

    (a)-[:ROAD {cost: 50}]->(b),
    (a)-[:ROAD {cost: 50}]->(c),
    (a)-[:ROAD {cost: 100}]->(d),
    (b)-[:ROAD {cost: 40}]->(d),
    (c)-[:ROAD {cost: 40}]->(d),
    (c)-[:ROAD {cost: 80}]->(e),
    (d)-[:ROAD {cost: 30}]->(e),
    (d)-[:ROAD {cost: 80}]->(f),
    (e)-[:ROAD {cost: 40}]->(f);
  """
)
G, project_result = gds.graph.project(
    "road_graph",
    {"City": {"properties": ["settled"]}},
    {"ROAD": {"properties": ["cost"]}}
)

assert G.relationship_count() == 9

现在我们可以使用图 G 来训练 GraphSAGE 模型。

model, train_result = gds.beta.graphSage.train(G, modelName="city-representation", featureProperties=["settled"], randomSeed=42)

assert train_result["modelInfo"]["metrics"]["ranEpochs"] == 1

其中 model 是模型对象,res 是一个包含底层过程调用元数据的 pandas Series

同样,我们也可以通过训练机器学习流水线来获得模型对象。

若要获取一个代表已训练且存在于模型目录中的模型对象,可以调用仅存在于客户端的 get 方法并传入名称

model = gds.model.get("city-representation")

assert model.name() == "city-representation"

get 方法不使用任何层级 (tier) 前缀,因为它不与任何层级相关联。它仅存在于客户端,没有对应的 Cypher 过程。

2. 检查模型对象

所有模型对象上都有一些便捷方法,让我们可以提取有关所代表模型的信息。

表 1. 模型对象方法
名称 参数 返回类型 描述

名称 (name)

-

str

模型在模型目录中显示的名称。

type

-

str

模型类型,例如 "graphSage"。

train_config

-

Series

用于训练模型的配置。

graph_schema

-

Series

模型训练所基于的图模式 (schema)。

loaded

-

bool

如果模型已加载到内存模型目录中,则为 True,否则为 False

stored

-

bool

如果模型已存储在磁盘上,则为 True,否则为 False

creation_time

-

neo4j.time.Datetime

模型创建的时间。

shared

-

bool

如果模型在用户间共享,则为 True,否则为 False

exists

-

bool

如果模型存在于 GDS 模型目录中,则为 True,否则为 False

drop

failIfMissing: Optional[bool]

Series

将模型从 GDS 模型目录中移除

例如,要获取上面创建的模型对象 model 的训练配置,我们可以执行以下操作

train_config = model.train_config()

assert train_config["concurrency"] == 4

3. 使用模型对象

使用模型对象的主要方式是进行预测。GraphSAGE 的相关操作如下所述,流水线的相关操作请参阅机器学习流水线页面。

此外,模型对象还可以用作 GDS 模型目录操作的输入。例如,假设我们有上面创建的模型对象 model,我们可以执行

# Store the model on disk (GDS Enterprise Edition)
_ = gds.model.store(model)

gds.model.drop(model)  # same as model.drop()

# Load the model again for further use
gds.model.load(model.name())

3.1. GraphSAGE

如上面在构建模型对象中所述,使用 Python 客户端训练 GraphSAGE 模型类似于其 Cypher 对应过程

训练完成后,除了上述方法外,GraphSAGE 模型对象还将具有以下方法。

表 2. GraphSAGE 模型方法
名称 参数 返回类型 描述

predict_mutate

G: 图对象,
config: **kwargs

Series

预测输入图中节点的嵌入,并将预测结果写入图。.

predict_stream

G: 图对象,
config: **kwargs

DataFrame

预测输入图中节点的嵌入并流式传输结果。.

predict_write

G: 图对象,
config: **kwargs

Series

预测输入图中节点的嵌入并将结果写回数据库。.

metrics

-

Series

返回训练时计算的指标值。

因此,对于我们上面训练的 GraphSAGE 模型 model,我们可以执行以下操作

# Make sure our training actually converged
metrics = model.metrics()
assert metrics["didConverge"]

# Predict on `G` and write embedding node properties back to the database
predict_result = model.predict_write(G, writeProperty="embedding")
assert predict_result["nodePropertiesWritten"] == G.node_count()