逻辑回归

此功能处于 Beta 阶段。有关功能层级的更多信息,请参阅 API 层级

逻辑回归是一种基本的监督机器学习分类方法。它通过最小化取决于权重矩阵和训练数据的损失函数来训练模型。例如,可以使用梯度下降来最小化损失。在 GDS 中,我们使用 Adam 优化器,这是一种梯度下降类型的算法。

权重以 [c,d] 大小的矩阵 W 和长度为 c 的偏置向量 b 的形式存在,其中 d 是特征维度,c 等于类别数量。损失函数定义为

CE(softmax(Wx + b))

其中 CE交叉熵损失softmaxsoftmax 函数x 是长度为 d 的特征向量训练样本。

为了避免过拟合,还可以在损失中添加一个 正则化 项。Neo4j 图数据科学支持 l2 正则化选项,该选项可以使用 penalty 参数进行配置。

调整超参数

为了平衡模型的偏差与方差以及训练速度与内存消耗等因素,GDS 提供了几个可以调整的超参数。每个参数的说明如下。

在基于梯度下降的训练中,我们尝试为模型找到最佳权重。在每个迭代周期中,我们处理所有训练样本以计算损失和权重的梯度。然后使用这些梯度来更新权重。对于更新,我们使用 Adam 优化器,如 https://arxiv.org/pdf/1412.6980.pdf 中所述。

训练统计信息将在 Neo4j 调试日志中报告。

最大迭代次数(Epochs)

此参数定义了训练的最大迭代次数。无论模型质量如何,训练都将在达到这些迭代次数后终止。请注意,如果损失收敛,训练也可以提前停止(参见耐心容忍度)。

设置此参数有助于限制模型的训练时间。限制计算预算可以起到正则化作用,并减轻过拟合,因为当迭代次数过多时,过拟合会成为一个风险。

最小迭代次数(Epochs)

此参数定义了训练的最小迭代次数。无论模型质量如何,训练都将至少运行这么多迭代次数。

设置此参数有助于避免提前停止,但也会增加模型的最小训练时间。

耐心

此参数定义了连续无效迭代的最大数量。如果一个迭代周期没有使训练损失至少提高当前损失的 tolerance 比例,则认为该迭代周期是无效的。

假设训练已运行了 minEpochs 设定的迭代次数,此参数定义了训练何时收敛。

设置此参数可以使训练更稳健,并类似于 minEpochs 避免提前终止。然而,过高的耐心值可能导致运行不必要的过多迭代次数。

根据我们的经验,patience 的合理值在 13 之间。

容忍度

此参数定义了何时将一个迭代周期视为无效,并与 patience 一起定义了训练的收敛标准。如果一个迭代周期没有使训练损失至少提高当前损失的 tolerance 比例,则认为该迭代周期是无效的。

较低的容忍度会导致更敏感的训练,并有更高的概率进行更长时间的训练。较高的容忍度意味着训练的敏感度较低,因此会有更多的迭代周期被视为无效。

学习率

更新权重时,我们根据损失函数的梯度,按照 Adam 优化器指定的方向移动。每次权重更新的移动量可以通过 learningRate 参数进行配置。

批次大小

此参数定义了单个批次中包含多少个训练样本。

梯度在批次上使用 concurrency 个线程并行计算。在每个迭代周期结束时,梯度在更新权重之前被求和并缩放。batchSize 不影响模型质量,但可用于调整训练速度。较大的 batchSize 会增加计算的内存消耗。

惩罚项

此参数定义了损失函数中正则化项的影响。虽然正则化可以避免过拟合,但过高的值甚至可能导致欠拟合。最小值为零时,正则化项完全没有效果。

类别权重

此参数引入了类别权重的概念,该概念由 T. Lin 等人在《用于密集对象检测的 Focal Loss》中进行了研究。它通常被称为平衡交叉熵。它为交叉熵损失函数中的每个类别分配一个权重,从而允许模型以不同的重要性处理不同的类别。它为每个样本定义为

balanced cross entropy

其中 at 表示真实类别的类别权重。pt 表示真实类别的概率。

对于类别不平衡问题,类别权重通常设置为类别频率的倒数,以改善模型对少数类别的归纳偏置。

对于链接预测,它必须是一个长度为 2 的列表,其中第一个权重用于负样本(缺失的关系),第二个权重用于正样本(实际关系)。

在节点分类中的用法

对于节点分类,第 ith 个权重对应第 ith 个类别,按类别值排序(必须是整数)。例如,如果您的节点分类数据集有三个类别:0、1、42。那么类别权重长度必须为 3。第三个权重应用于类别 42。

焦点权重

此参数引入了焦点损失(focal loss)的概念,同样由 T. Lin 等人在《用于密集对象检测的 Focal Loss》中进行了研究。当 focusWeight 的值大于零时,损失函数从标准交叉熵损失变为焦点损失。它为每个样本定义为

focal loss

其中 pt 表示真实类别的概率。focusWeight 参数是表示为 g 的指数。

增加 focusWeight 将引导模型尝试拟合“困难的”错误分类样本。一个困难的错误分类样本是指模型对真实类别的预测概率较低的样本。在上述方程中,对于真实类别概率低的样本,损失将呈指数级增加,从而调整模型以尝试拟合这些样本,但可能会牺牲对“简单”样本的置信度。

在类别不平衡数据集中,少数类别通常更难正确分类。了解更多关于链接预测中的类别不平衡信息,请参阅类别不平衡

© . All rights reserved.