逻辑回归

此功能处于 beta 级别。有关功能级别的更多信息,请参见 API 级别.

逻辑回归是一种基本的监督机器学习分类方法。它通过最小化取决于权重矩阵和训练数据的损失函数来训练模型。例如,可以使用梯度下降法来最小化损失。在 GDS 中,我们使用 Adam 优化器,这是一种梯度下降类型的算法。

权重采用 [c,d] 大小的矩阵 W 和长度为 c 的偏差向量 b 的形式,其中 d 是特征维数,c 等于类的数量。然后损失函数定义为

CE(softmax(Wx + b))

其中 CE交叉熵损失softmaxsoftmax 函数x 是长度为 d 的特征向量训练样本。

为了避免过拟合,还可以将 正则化 项添加到损失中。Neo4j 图数据科学支持 l2 正则化的选项,可以使用 penalty 参数进行配置。

调整超参数

为了平衡模型的偏差与方差、训练的速度与内存消耗等问题,GDS 公开了几个可以调整的超参数。下面将对每个参数进行说明。

在基于梯度下降法的训练中,我们试图找到模型的最佳权重。在每个 epoch 中,我们处理所有训练样本以计算损失和权重的梯度。然后使用这些梯度来更新权重。对于更新,我们使用 Adam 优化器,如 https://arxiv.org/pdf/1412.6980.pdf 中所述。

有关训练的统计信息会在 neo4j 调试日志中报告。

最大 Epochs

此参数定义训练的最大 Epochs 数。无论模型的质量如何,训练将在这些 Epochs 之后终止。请注意,如果损失收敛,训练也可能提前停止(参见 PatienceTolerance)。

设置此参数对于限制模型的训练时间很有用。限制计算预算可以起到正则化的作用,并减轻过拟合的风险,过拟合在 Epochs 数过多时会成为一个风险。

最小 Epochs

此参数定义训练的最小 Epochs 数。无论模型的质量如何,训练至少会运行这些 Epochs。

设置此参数对于避免过早停止很有用,但也会增加模型的最小训练时间。

耐心

此参数定义无生产力的连续 Epochs 的最大数量。如果 Epoch 没有将训练损失提高当前损失的至少 tolerance 分数,则该 Epoch 就算无生产力。

假设训练运行了 minEpochs 次,此参数定义了训练何时收敛。

设置此参数可以使训练更加稳健,并避免与 minEpochs 相似的过早终止。但是,较高的 patience 值可能会导致运行比必要更多的 epoch。

根据我们的经验,patience 的合理值范围为 13

容忍度

此参数定义了何时将一个 epoch 视为非生产性,并与 patience 一起定义了训练的收敛标准。如果一个 epoch 未能将训练损失提高当前损失的至少 tolerance 倍,则它是非生产性的。

较低的容忍度会导致更敏感的训练,训练时间更长的概率更高。较高的容忍度意味着训练不太敏感,因此会导致更多 epoch 被视为非生产性的。

学习率

更新权重时,我们根据损失函数的梯度,朝着由 Adam 优化器指示的方向移动。您可以通过 learningRate 参数配置每次权重更新的移动量。

批次大小

此参数定义了在一个批次中分组的训练示例数量。

使用 concurrency 个线程同时计算批次上的梯度。在一个 epoch 结束时,梯度在更新权重之前被求和并缩放。batchSize 不会影响模型质量,但可用于调整训练速度。较大的 batchSize 会增加计算的内存消耗。

惩罚

此参数定义了正则化项在损失函数中的影响。虽然正则化可以避免过拟合,但较高的值甚至会导致欠拟合。最小值为零,正则化项没有任何影响。

类别权重

此参数引入了类别权重的概念,在 T. Lin 等人的“Focal Loss for Dense Object Detection”中进行了研究。它通常被称为平衡交叉熵。它在交叉熵损失函数中为每个类别分配一个权重,从而允许模型以不同的重要性对待不同的类别。它为每个示例定义为

balanced cross entropy

其中 at 表示真实类别的类别权重。pt 表示真实类别的概率。

对于类别不平衡问题,类别权重通常设置为类别频率的倒数,以提高模型对少数类别的归纳偏差。

对于链接预测,它必须是一个长度为 2 的列表,第一个权重用于负面示例(缺少关系),第二个权重用于正面示例(实际关系)。

在节点分类中的使用

对于节点分类,ith 权重用于 ith 类别,按类别值排序(必须是整数)。例如,如果您的节点分类数据集有三个类别:0、1、42。那么类别权重必须长度为 3。第三个权重应用于类别 42。

焦点权重

此参数引入了焦点损失的概念,同样是在 T. Lin 等人的“Focal Loss for Dense Object Detection”中进行了研究。当 focusWeight 的值为大于零时,损失函数从标准交叉熵损失变为焦点损失。它为每个示例定义为

focal loss

其中 pt 表示真实类别的概率。focusWeight 参数是指数,表示为 g

增加 focusWeight 将引导模型尝试拟合“困难”的错误分类示例。一个困难的错误分类示例是指模型对真实类别具有低预测概率的示例。在上面的等式中,对于低真实类别概率示例,损失将呈指数级增长,从而调整模型以尝试拟合它们,但代价可能是对“容易”示例的置信度降低。

在类别不平衡的数据集中,少数类别通常难以正确分类。有关链接预测的类别不平衡的更多信息,请参阅 类别不平衡