GraphSAGE 节点分类训练

GraphSAGE 是一种图神经网络 (GNN) 架构,可用作监督式算法来预测图中节点的类标签。本节提供了如何使用 GraphSAGE 端点在 Snowflake 的 Neo4j 图分析中进行节点分类模型训练的说明。

语法

本节介绍了执行 GraphSAGE 节点分类训练算法所使用的语法。

运行 GraphSAGE 节点分类训练。
CALL graph.gs_nc_train(
  'CPU_X64_XS',                    (1)
  {
    ['defaultTablePrefix': '...',] (2)
    'project': {...},              (3)
    'compute': {...},              (4)
  }
);
1 计算池选择器。
2 表引用的可选前缀。
3 项目配置。
4 计算配置。
表 1. 参数
名称 类型 默认 可选 描述

computePoolSelector

字符串

不适用

用于运行 GraphSAGE 节点分类训练作业的计算池选择器。

配置

Map

{}

用于图项目、算法计算和结果回写的配置。

对于此算法,我们强烈建议使用 GPU 计算池,除非数据集非常小且模型层数较少。

配置映射由以下三个条目组成。

有关下方项目(Project)配置的更多详细信息,请参阅 项目文档
表 2. 项目配置
名称 类型

nodeTables

节点表列表。

relationshipTables

关系类型到关系表的映射。

请注意,为了使 GraphSAGE 能够正确传播节点嵌入(node embeddings)的更新,每种类型的节点必须至少是某种关系类型的目标。`orientation` 参数有助于为仅作为关系源的节点类型添加反向关系(使用 "REVERSE" 或 "UNDIRECTED" 方向)。

表 3. 计算配置
名称 类型 默认 可选 描述

targetLabel

字符串

不适用

用于训练预测的节点标签(即类型)。

targetProperty

字符串

不适用

用于训练预测的节点属性,由指定 'target_label' 的输入节点表中的一列表示。允许使用空值,表示未标记的节点(请参阅下方的 半监督学习)。

modelname

字符串

不适用

要训练的模型的名称(必须唯一)。

numEpochs

整数

不适用

训练模型的轮数 (epochs)。

numSamples

整数列表

不适用

每层要采样的邻居数量。请注意,这也决定了层数。

hiddenChannels

整数

256

模型层输出的节点嵌入维度。

activation

字符串

"relu"

要使用的激活函数。有效值为 "relu" 和 "sigmoid"。

aggregator

字符串

"mean"

要使用的邻域嵌入聚合器。有效值为 "mean" 和 "max"。

learningRate

浮点数

0.001

优化器的学习率。

dropout

浮点数

0.1

每一层的 dropout 概率。必须是 >= 0.0 且 < 1.0 的值。

layerNormalization

布尔值

true

是否在模型层之间应用层归一化。

epochsPerCheckpoint

整数

max(numEpochs / 10, 1)

保存模型检查点之间的轮数。

randomSeed

整数

随机整数

用于初始化计算中所有随机性的数字

splitRatios

Map

{'TRAIN': 0.6, 'TEST': 0.2, 'VALID': 0.2}

将输入图的目标节点拆分为训练集、测试集和验证集的比例映射。键必须是 "TRAIN"、"TEST" 和 "VALID"。值之和必须为 1.0。

epochsPerVal

整数

0

在验证集上评估模型之间的轮数。如果设置为 0,则模型不会在验证集上进行评估。

trainBatchSize

整数

自动推断

每个批次中用于训练的目标节点数量。如果未提供,算法将自动在可用内存约束内推断允许的最大批次大小。

evalBatchSize

整数

训练批次大小

用于评估的批次大小。

classWeights

布尔值或映射 (Map)

false

是否使用类权重来平衡训练数据。如果设置为 true,将根据训练集中目标标签的分布计算类权重。如果设置为映射,则该映射必须包含每个目标类标签的类权重。

半监督学习

目标属性中的空值

targetProperty 列可能包含空值,代表未标记的节点。这支持半监督学习,即模型在部分标记的图上进行训练。

具有空目标值的节点将:

  • 从训练和评估中排除:它们不参与损失计算或指标(准确率、F1 分数)计算。

  • 包含在消息传递中:它们的特征和图连通性仍被 GNN 用于学习节点表示。

  • 在推理过程中被预测:训练好的模型可以为所有节点预测类标签,包括那些在训练期间未标记的节点。

当只有部分节点具有已知的类标签,但整个图结构和节点特征均可用时,此功能非常有用。

示例

在我们的示例中,我们将使用包含演员、导演、电影和流派的 IMDB 数据集。这些节点都关联有关键词,我们将把这些关键词用作节点的特征。它们通过演员参演电影和导演执导电影的关系连接。目标是预测电影的流派。

我们有一个名为 imdb 的数据库,其中包含以下表:

  • actor,具有 nodeidplot_keywords

  • movie,包含 nodeidplot_keywordsgenre

  • director,具有 nodeidplot_keywords

  • acted_in,具有 sourcenodeidtargetnodeid 列,代表 actormovie 节点的 ID

  • directed_in,具有 sourcenodeidtargetnodeid 列,代表 directormovie 节点的 ID

plot_keywords 列包含与节点关联的关键词,编码为浮点数向量。genre 列包含我们要预测的电影节点的目标类标签。

您可以按照 GitHub 上的说明将此数据集上传到您的 Snowflake 账户:neo4j-product-examples/snowflake-graph-analytics

训练查询

在下面的查询中,我们在数据集上训练了一个用于节点分类的 GraphSAGE 模型。我们训练了 10 个轮次,包含两个隐藏层,并使用类权重来平衡类分布。

要运行查询,需要为应用程序、消费者角色和环境进行必要的权限授予设置。请参阅 入门 页面以了解更多信息。

我们还假设应用程序名称为默认的 Neo4j_Graph_Analytics。如果您在安装过程中选择了不同的应用程序名称,请将其替换为该名称。

CALL Neo4j_Graph_Analytics.graph.gs_nc_train('GPU_NV_S', {
    'defaultTablePrefix': 'imdb.gml',
    'project': {
        'nodeTables': ['actor', 'director', 'movie'],
        'relationshipTables': {
            'acted_in': {
                'sourceTable': 'actor',
                'targetTable': 'movie',
                'orientation': 'UNDIRECTED'
            },
            'directed_in': {
                'sourceTable': 'director',
                'targetTable': 'movie',
                'orientation': 'UNDIRECTED'
            }
        }
    },
    'compute': {
        'modelname': 'nc-imdb',
        'numEpochs': 10,
        'numSamples': [20, 20],
        'targetLabel': 'movie',
        'targetProperty': 'genre',
        'classWeights': true
    }
});

上述查询应产生类似于下方的结果。数值结果可能会有所不同。

JOB_ID JOB_STATUS JOB_START JOB_END JOB_RESULT

job_a039eb4b52d0465ba7d22c99e5bc222a

SUCCESS

2025-11-28T14:13:20.503496

2025-11-28T14:14:45.278090

{
  "gs_nc_train": {
    "metrics": {
      "test_acc": 0.7329843044281006,
      "test_f1_macro": 0.7029824256896973,
      "test_f1_micro": 0.7329843640327454,
      "train_acc": 0.9876957535743713,
      "train_f1_macro": 0.9864031672477722,
      "train_f1_micro": 0.9876957535743713
    },
    "train_ms": 76998
  },
  "project": {
    "graphName": "snowgraph",
    "nodeCount": 12772,
    "nodeLabels": {
      "ACTOR": {
        "count": 5841,
        "nodeId": {
          "dataType": "int16"
        },
        "properties": {
          "plot_keywords": {
            "dataType": "ndarray[float32]",
            "dimension": 1256
          }
        },
        "table": "IMDB.GML.ACTOR"
      },
      "DIRECTOR": {
        "count": 2270,
        "nodeId": {
          "dataType": "int16"
        },
        "properties": {
          "plot_keywords": {
            "dataType": "ndarray[float32]",
            "dimension": 1256
          }
        },
        "table": "IMDB.GML.DIRECTOR"
      },
      "MOVIE": {
        "count": 4661,
        "nodeId": {
          "dataType": "int16"
        },
        "properties": {
          "genre": {
            "dataType": "float64",
            "dimension": 1
          },
          "plot_keywords": {
            "dataType": "ndarray[float32]",
            "dimension": 1256
          }
        },
        "table": "IMDB.GML.MOVIE"
      }
    },
    "nodeMillis": 1221,
    "relationshipCount": 18644,
    "relationshipMillis": 178,
    "relationshipTypes": {
      "ACTED_IN": {
        "count": 13983,
        "direction": "UNDIRECTED",
        "sourceTable": "IMDB.GML.ACTOR",
        "table": "IMDB.GML.ACTED_IN",
        "targetTable": "IMDB.GML.MOVIE"
      },
      "DIRECTED_IN": {
        "count": 4661,
        "direction": "UNDIRECTED",
        "sourceTable": "IMDB.GML.DIRECTOR",
        "table": "IMDB.GML.DIRECTED_IN",
        "targetTable": "IMDB.GML.MOVIE"
      }
    },
    "totalMillis": 1399
  }
}