TensorFlow的模型准确率计量方法
从零开始编写训练逻辑,该如何计算模型的准确率?
TensorFlow中的 metrics
模块提供了一系列模型指标评估方法。当使用 Keras 的 Model.fit
函数训练时,可以直接在编译模型时传入该类别的一个实例即可实现自动计算。但是当手动实现训练循环时,需要自行实现评估逻辑,手动更新计量指标。这里以 CategoricalAccuracy
(类别准确率)为例,说明具体的使用方法与注意事项。
CategoricalAccuracy
可以用来计算分类模型的准确率。它需要至少两个输入变量:独热标签与预测值。例如训练时某个分类任务的标签为 [0, 0, 1]
,模型的输出为 [0.1, 0,1, 0.9]
。使用以下代码可以获得当前的准确率计量结果。
即当前模型的准确率为1.0。
需要注意的是准确率的计量结果是累计的。假设第二次评估的标签与预测值分比为[0, 0, 1]
与 [0.1, 0.9, 0.1]
时,即模型做出了错误的预测,继续使用如下代码获取最新的评估结果。
综合两次评估可以得出当前模型的准确率为0.5。
在实际训练中,我们希望计量一段训练过程中的模型准确率,并以此为依据决定是否保存该模型。因此计量过程存在开始与结束节点。当通过结束节点之后,需要重置该计量对象。
重置的具体节点则可以灵活设定。如果你的数据集不大,可以将epoch结尾作为重置节点。如果你的数据集太过庞大,则可以在保存模型后重置。只要牢记你设定该计量对象的目的便不难做出抉择。
Comments ()