逻辑回归
此功能处于 beta 级别。有关功能级别的更多信息,请参见 API 级别.
逻辑回归是一种基本的监督机器学习分类方法。它通过最小化取决于权重矩阵和训练数据的损失函数来训练模型。例如,可以使用梯度下降法来最小化损失。在 GDS 中,我们使用 Adam 优化器,这是一种梯度下降类型的算法。
权重采用 [c,d]
大小的矩阵 W
和长度为 c
的偏差向量 b
的形式,其中 d
是特征维数,c
等于类的数量。然后损失函数定义为
CE(softmax(Wx + b))
其中 CE
是 交叉熵损失,softmax
是 softmax 函数,x
是长度为 d
的特征向量训练样本。
为了避免过拟合,还可以将 正则化 项添加到损失中。Neo4j 图数据科学支持 l2
正则化的选项,可以使用 penalty
参数进行配置。
调整超参数
为了平衡模型的偏差与方差、训练的速度与内存消耗等问题,GDS 公开了几个可以调整的超参数。下面将对每个参数进行说明。
在基于梯度下降法的训练中,我们试图找到模型的最佳权重。在每个 epoch 中,我们处理所有训练样本以计算损失和权重的梯度。然后使用这些梯度来更新权重。对于更新,我们使用 Adam 优化器,如 https://arxiv.org/pdf/1412.6980.pdf 中所述。
有关训练的统计信息会在 neo4j 调试日志中报告。
耐心
此参数定义无生产力的连续 Epochs 的最大数量。如果 Epoch 没有将训练损失提高当前损失的至少 tolerance
分数,则该 Epoch 就算无生产力。
假设训练运行了 minEpochs
次,此参数定义了训练何时收敛。
设置此参数可以使训练更加稳健,并避免与 minEpochs
相似的过早终止。但是,较高的 patience 值可能会导致运行比必要更多的 epoch。
根据我们的经验,patience
的合理值范围为 1
到 3
。
容忍度
此参数定义了何时将一个 epoch 视为非生产性,并与 patience
一起定义了训练的收敛标准。如果一个 epoch 未能将训练损失提高当前损失的至少 tolerance
倍,则它是非生产性的。
较低的容忍度会导致更敏感的训练,训练时间更长的概率更高。较高的容忍度意味着训练不太敏感,因此会导致更多 epoch 被视为非生产性的。
批次大小
此参数定义了在一个批次中分组的训练示例数量。
使用 concurrency
个线程同时计算批次上的梯度。在一个 epoch 结束时,梯度在更新权重之前被求和并缩放。batchSize
不会影响模型质量,但可用于调整训练速度。较大的 batchSize
会增加计算的内存消耗。
类别权重
此参数引入了类别权重的概念,在 T. Lin 等人的“Focal Loss for Dense Object Detection”中进行了研究。它通常被称为平衡交叉熵。它在交叉熵损失函数中为每个类别分配一个权重,从而允许模型以不同的重要性对待不同的类别。它为每个示例定义为
其中 at
表示真实类别的类别权重。pt
表示真实类别的概率。
对于类别不平衡问题,类别权重通常设置为类别频率的倒数,以提高模型对少数类别的归纳偏差。
焦点权重
此参数引入了焦点损失的概念,同样是在 T. Lin 等人的“Focal Loss for Dense Object Detection”中进行了研究。当 focusWeight
的值为大于零时,损失函数从标准交叉熵损失变为焦点损失。它为每个示例定义为
其中 pt
表示真实类别的概率。focusWeight
参数是指数,表示为 g
。
增加 focusWeight
将引导模型尝试拟合“困难”的错误分类示例。一个困难的错误分类示例是指模型对真实类别具有低预测概率的示例。在上面的等式中,对于低真实类别概率示例,损失将呈指数级增长,从而调整模型以尝试拟合它们,但代价可能是对“容易”示例的置信度降低。
在类别不平衡的数据集中,少数类别通常难以正确分类。有关链接预测的类别不平衡的更多信息,请参阅 类别不平衡。