逻辑回归
此功能处于 Beta 阶段。有关功能层级的更多信息,请参阅 API 层级。
逻辑回归是一种基本的监督机器学习分类方法。它通过最小化取决于权重矩阵和训练数据的损失函数来训练模型。例如,可以使用梯度下降来最小化损失。在 GDS 中,我们使用 Adam 优化器,这是一种梯度下降类型的算法。
权重以 [c,d]
大小的矩阵 W
和长度为 c
的偏置向量 b
的形式存在,其中 d
是特征维度,c
等于类别数量。损失函数定义为
CE(softmax(Wx + b))
其中 CE
是 交叉熵损失,softmax
是 softmax 函数,x
是长度为 d
的特征向量训练样本。
为了避免过拟合,还可以在损失中添加一个 正则化 项。Neo4j 图数据科学支持 l2
正则化选项,该选项可以使用 penalty
参数进行配置。
调整超参数
为了平衡模型的偏差与方差以及训练速度与内存消耗等因素,GDS 提供了几个可以调整的超参数。每个参数的说明如下。
在基于梯度下降的训练中,我们尝试为模型找到最佳权重。在每个迭代周期中,我们处理所有训练样本以计算损失和权重的梯度。然后使用这些梯度来更新权重。对于更新,我们使用 Adam 优化器,如 https://arxiv.org/pdf/1412.6980.pdf 中所述。
训练统计信息将在 Neo4j 调试日志中报告。
耐心
此参数定义了连续无效迭代的最大数量。如果一个迭代周期没有使训练损失至少提高当前损失的 tolerance
比例,则认为该迭代周期是无效的。
假设训练已运行了 minEpochs
设定的迭代次数,此参数定义了训练何时收敛。
设置此参数可以使训练更稳健,并类似于 minEpochs
避免提前终止。然而,过高的耐心值可能导致运行不必要的过多迭代次数。
根据我们的经验,patience
的合理值在 1
到 3
之间。
容忍度
此参数定义了何时将一个迭代周期视为无效,并与 patience
一起定义了训练的收敛标准。如果一个迭代周期没有使训练损失至少提高当前损失的 tolerance
比例,则认为该迭代周期是无效的。
较低的容忍度会导致更敏感的训练,并有更高的概率进行更长时间的训练。较高的容忍度意味着训练的敏感度较低,因此会有更多的迭代周期被视为无效。
批次大小
此参数定义了单个批次中包含多少个训练样本。
梯度在批次上使用 concurrency
个线程并行计算。在每个迭代周期结束时,梯度在更新权重之前被求和并缩放。batchSize
不影响模型质量,但可用于调整训练速度。较大的 batchSize
会增加计算的内存消耗。
类别权重
此参数引入了类别权重的概念,该概念由 T. Lin 等人在《用于密集对象检测的 Focal Loss》中进行了研究。它通常被称为平衡交叉熵。它为交叉熵损失函数中的每个类别分配一个权重,从而允许模型以不同的重要性处理不同的类别。它为每个样本定义为
其中 at
表示真实类别的类别权重。pt
表示真实类别的概率。
对于类别不平衡问题,类别权重通常设置为类别频率的倒数,以改善模型对少数类别的归纳偏置。
焦点权重
此参数引入了焦点损失(focal loss)的概念,同样由 T. Lin 等人在《用于密集对象检测的 Focal Loss》中进行了研究。当 focusWeight
的值大于零时,损失函数从标准交叉熵损失变为焦点损失。它为每个样本定义为
其中 pt
表示真实类别的概率。focusWeight
参数是表示为 g
的指数。
增加 focusWeight
将引导模型尝试拟合“困难的”错误分类样本。一个困难的错误分类样本是指模型对真实类别的预测概率较低的样本。在上述方程中,对于真实类别概率低的样本,损失将呈指数级增加,从而调整模型以尝试拟合这些样本,但可能会牺牲对“简单”样本的置信度。
在类别不平衡数据集中,少数类别通常更难正确分类。了解更多关于链接预测中的类别不平衡信息,请参阅类别不平衡。