TensorFlow的模型准确率计量方法

从零开始编写训练逻辑,该如何计算模型的准确率?

TensorFlow的模型准确率计量方法
封面图片:Piret Ilver

TensorFlow中的 metrics 模块提供了一系列模型指标评估方法。当使用 Keras 的 Model.fit 函数训练时,可以直接在编译模型时传入该类别的一个实例即可实现自动计算。但是当手动实现训练循环时,需要自行实现评估逻辑,手动更新计量指标。这里以 CategoricalAccuracy(类别准确率)为例,说明具体的使用方法与注意事项。

CategoricalAccuracy 可以用来计算分类模型的准确率。它需要至少两个输入变量:独热标签与预测值。例如训练时某个分类任务的标签为 [0, 0, 1] ,模型的输出为 [0.1, 0,1, 0.9] 。使用以下代码可以获得当前的准确率计量结果。

m = tf.keras.metrics.CategoricalAccuracy()
m.update_state([0, 0, 1], [0.1, 0.1, 0.9])
m.result()

# 输出结果为 
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
计算分类模型当前准确率的代码

即当前模型的准确率为1.0。

需要注意的是准确率的计量结果是累计的。假设第二次评估的标签与预测值分比为[0, 0, 1][0.1, 0.9, 0.1] 时,即模型做出了错误的预测,继续使用如下代码获取最新的评估结果。

m.update_state([0, 0, 1], [0.1, 0.9, 0.1])
m.result()

# 输出结果为
<tf.Tensor: shape=(), dtype=float32, numpy=0.5>
更新模型的累计准确率

综合两次评估可以得出当前模型的准确率为0.5。

在实际训练中,我们希望计量一段训练过程中的模型准确率,并以此为依据决定是否保存该模型。因此计量过程存在开始与结束节点。当通过结束节点之后,需要重置该计量对象。

m.reset_state()
m.result()

# 输出结果为
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
模型准确率重置为0

重置的具体节点则可以灵活设定。如果你的数据集不大,可以将epoch结尾作为重置节点。如果你的数据集太过庞大,则可以在保存模型后重置。只要牢记你设定该计量对象的目的便不难做出抉择。