逻辑回归

此功能处于 Beta 测试阶段。有关功能分级的更多信息,请参阅 API 分级

逻辑回归是一种基础的监督机器学习分类方法。它通过最小化取决于权重矩阵和训练数据的损失函数来训练模型。损失可以通过例如梯度下降法进行最小化。在 GDS 中,我们使用 Adam 优化器,这是一种梯度下降类型的算法。

权重以 [c,d] 大小的矩阵 W 和长度为 c 的偏置向量 b 的形式存在,其中 d 是特征维度,c 等于类别的数量。损失函数定义为:

CE(softmax(Wx + b))

其中 CE交叉熵损失softmaxSoftmax 函数x 是长度为 d 的特征向量训练样本。

为了避免过拟合,可以在损失函数中添加一个正则化项。Neo4j 图数据科学支持 l2 正则化选项,可以通过 penalty 参数进行配置。

调整超参数

为了平衡模型的偏差与方差、训练速度与内存消耗等事项,GDS 提供了若干可调整的超参数。每一项的说明如下。

在基于梯度下降的训练中,我们尝试为模型找到最佳权重。在每个 epoch 中,我们处理所有训练样本以计算损失和权重的梯度。这些梯度随后被用于更新权重。正如 https://arxiv.org/pdf/1412.6980.pdf 中所述,我们在更新时使用 Adam 优化器。

有关训练的统计信息记录在 neo4j 调试日志中。

最大迭代次数 (Max Epochs)

此参数定义训练的最大迭代次数(epoch)。无论模型质量如何,训练都会在达到此次数后终止。请注意,如果损失函数收敛,训练也可能提前停止(参见耐心值 (Patience)容差 (Tolerance))。

设置此参数有助于限制模型的训练时间。限制计算预算可以起到正则化作用并减轻过拟合,当轮次过多时,过拟合会成为一个风险。

最小迭代次数 (Min Epochs)

此参数定义训练的最小迭代次数。无论模型质量如何,训练至少会运行这么多轮。

设置此参数有助于避免提前停止,但也会增加模型的最小训练时间。

耐心值 (Patience)

此参数定义了非生产性连续 epoch 的最大次数。如果一个 epoch 未能将训练损失改进至少 tolerance(容差)比例,则该 epoch 被视为非生产性的。

假设训练运行了 minEpochs 次,此参数定义了训练何时收敛。

设置此参数可以使训练更加稳健,并类似于 minEpochs 避免过早终止。然而,过高的耐心值可能导致运行多于必要的 epoch。

根据我们的经验,patience 的合理取值范围在 13 之间。

容差 (Tolerance)

此参数定义了何时认为一个 epoch 是非生产性的,它与 patience 一起定义了训练的收敛准则。如果一个 epoch 未能将训练损失改进至少 tolerance(容差)比例,则该 epoch 被视为非生产性的。

较低的容差会带来更敏感的训练,更有可能延长训练时间。较高的容差意味着训练不那么敏感,从而导致更多的 epoch 被计为非生产性。

学习率

在更新权重时,我们根据损失函数的梯度,沿着 Adam 优化器指定的方向移动。你可以通过 learningRate(学习率)参数配置每次权重更新的步长。

批大小 (Batch size)

此参数定义了单个批次(batch)中包含多少个训练样本。

梯度是在批次上使用 concurrency 个线程并发计算的。在 epoch 结束时,梯度会被累加并缩放,然后用于更新权重。batchSize 不会影响模型质量,但可用于调整训练速度。较大的 batchSize 会增加计算的内存消耗。

惩罚项 (Penalty)

此参数定义了损失函数中正则化项的影响力。虽然正则化可以避免过拟合,但过大的值甚至会导致欠拟合。最小值为零,此时正则化项完全不起作用。

类别权重 (Class weights)

此参数引入了 类别权重 的概念,该概念由 T. Lin 等人在“Focal Loss for Dense Object Detection”一文中研究。它通常被称为 平衡交叉熵 (balanced cross entropy)。它为交叉熵损失函数中的每个类别分配一个权重,从而允许模型以不同的重要性对待不同的类别。对于每个样本,它定义为

balanced cross entropy

其中 at 表示真实类别的权重。pt 表示真实类别的预测概率。

对于类别不平衡的问题,类别权重通常设置为类别频率的倒数,以改善模型在少数类上的归纳偏置。

对于链路预测,它必须是一个长度为 2 的列表,其中第一个权重用于负样本(缺失的关系),第二个权重用于正样本(实际存在的关系)。

在节点分类中的使用

对于节点分类,ith 权重对应于按类别值(必须是整数)排序的 ith 个类别。例如,如果你的节点分类数据集有三个类别:0, 1, 42。那么类别权重必须为长度 3 的列表。第三个权重将应用于类别 42。

聚焦权重 (Focus weight)

此参数引入了 焦点损失 (focal loss) 的概念,同样由 T. Lin 等人在“Focal Loss for Dense Object Detection”中研究。当 focusWeight 大于零时,损失函数会从标准的交叉熵损失变为焦点损失。对于每个样本,它定义为

focal loss

其中 pt 表示真实类别的概率。focusWeight 参数是标注为 g 的指数。

增加 focusWeight 将引导模型尝试拟合那些“难以区分”的误分类样本。难以区分的样本是指模型对真实类别的预测概率较低的样本。在上述方程中,对于真实类别概率较低的样本,损失会呈指数级增大,从而促使模型尝试拟合这些样本,代价是模型对“容易”样本的置信度可能会降低。

在类别不平衡的数据集中,少数类通常更难被正确分类。阅读更多关于链路预测中类别不平衡的信息,请查看 类别不平衡 (Class Imbalance)