多层感知机

此功能处于 Alpha 阶段。有关功能阶段的更多信息,请参阅 API 阶段

多层感知机 (MLP) 是一种前馈神经网络。它由多层相互连接的神经元组成。神经元的值是通过对来自前一层的加权输入进行聚合,并应用激活函数计算得出的。对于分类任务,输出层的大小取决于类别的数量。为了优化网络权重,GDS 使用带有交叉熵损失 (Cross Entropy Loss) 的梯度下降法。

调整超参数

为了平衡模型的偏差与方差、训练速度与内存消耗等事项,GDS 提供了若干可调整的超参数。每一项的说明如下。

在基于梯度下降的训练中,我们尝试为模型找到最佳权重。在每个 epoch 中,我们处理所有训练样本以计算损失和权重的梯度。这些梯度随后被用于更新权重。正如 https://arxiv.org/pdf/1412.6980.pdf 中所述,我们在更新时使用 Adam 优化器。

有关训练的统计信息记录在 neo4j 调试日志中。

最大迭代次数 (Max Epochs)

此参数定义训练的最大迭代次数(epoch)。无论模型质量如何,训练都会在达到此次数后终止。请注意,如果损失函数收敛,训练也可能提前停止(参见耐心值 (Patience)容差 (Tolerance))。

设置此参数有助于限制模型的训练时间。限制计算预算可以起到正则化作用并减轻过拟合,当轮次过多时,过拟合会成为一个风险。

最小迭代次数 (Min Epochs)

此参数定义训练的最小迭代次数。无论模型质量如何,训练至少会运行这么多轮。

设置此参数有助于避免提前停止,但也会增加模型的最小训练时间。

耐心值 (Patience)

此参数定义了非生产性连续 epoch 的最大次数。如果一个 epoch 未能将训练损失改进至少 tolerance(容差)比例,则该 epoch 被视为非生产性的。

假设训练运行了 minEpochs 次,此参数定义了训练何时收敛。

设置此参数可以使训练更加稳健,并类似于 minEpochs 避免过早终止。然而,过高的耐心值可能导致运行多于必要的 epoch。

根据我们的经验,patience 的合理取值范围在 13 之间。

容差 (Tolerance)

此参数定义了何时认为一个 epoch 是非生产性的,它与 patience 一起定义了训练的收敛准则。如果一个 epoch 未能将训练损失改进至少 tolerance(容差)比例,则该 epoch 被视为非生产性的。

较低的容差会带来更敏感的训练,更有可能延长训练时间。较高的容差意味着训练不那么敏感,从而导致更多的 epoch 被计为非生产性。

学习率

在更新权重时,我们根据损失函数的梯度,沿着 Adam 优化器指定的方向移动。你可以通过 learningRate(学习率)参数配置每次权重更新的步长。

批大小 (Batch size)

此参数定义了单个批次(batch)中包含多少个训练样本。

梯度是在批次上使用 concurrency 个线程并发计算的。在 epoch 结束时,梯度会被累加并缩放,然后用于更新权重。batchSize 不会影响模型质量,但可用于调整训练速度。较大的 batchSize 会增加计算的内存消耗。

惩罚项 (Penalty)

该参数定义了损失函数中正则化项的影响。虽然正则化可以避免过拟合,但过大的值甚至可能导致欠拟合。最小值为零,此时正则化项不起任何作用。

隐藏层大小 (HiddenLayerSizes)

该参数定义了神经网络的形状。列表中的每个条目代表该层中神经元的数量。列表的长度定义了隐藏层的层数。从理论上讲,更深、更大的网络可以更好地逼近高阶曲面,但代价是需要训练更多的权重(和偏置)。

类别权重 (Class weights)

该参数引入了 T. Lin 等人在《Focal Loss for Dense Object Detection》一文中研究的类别权重概念。它通常被称为平衡交叉熵 (balanced cross entropy)。它为交叉熵损失函数中的每个类别分配一个权重,从而允许模型以不同的重要性对待不同的类别。每个样本的定义如下:

balanced cross entropy

其中 at 表示真实类别的权重。pt 表示真实类别的概率。

对于类别不平衡问题,通常将类别权重设置为类别频率的倒数,以提高模型在少数类别上的归纳偏差 (inductive bias)。

对于链路预测,它必须是一个长度为 2 的列表,其中第一个权重用于负样本(缺失的关系),第二个权重用于正样本(实际存在的关系)。

在节点分类中的用法

对于节点分类,ith 权重对应于按类别值(必须是整数)排序的 ith 类别。例如,如果您的节点分类数据集有三个类别:0, 1, 42。那么类别权重必须为长度 3 的列表。第三个权重将应用于类别 42。

聚焦权重 (Focus weight)

该参数引入了焦点损失 (focal loss) 的概念,同样源自 T. Lin 等人的《Focal Loss for Dense Object Detection》。当 focusWeight 大于零时,损失函数会从标准的交叉熵损失变为焦点损失。每个样本的定义如下:

focal loss

其中 pt 表示真实类别的概率。focusWeight 参数是标记为 g 的指数。

增加 focusWeight 将引导模型尝试拟合那些“难以”分类的错误示例。所谓的难分类示例,是指模型对真实类别的预测概率较低的示例。在上述公式中,对于真实类别概率较低的示例,损失会呈指数级增长,从而促使模型专注于拟合这些示例,代价是可能对“简单”示例的预测信心降低。

在类别不平衡的数据集中,少数类别通常更难正确分类。有关链路预测中类别不平衡的更多信息,请参阅 类别不平衡 (Class Imbalance)