TensorFlow如何冻结网络模型
如何使用TensorFlow冻结网络模型为单一文件,以供C++、移动端与嵌入式设备在推理阶段使用。
对于以推理为目的的网络模型,可以将其“冻结”起来供C++等语言调用。本文介绍了TensorFlow官方冻结网络模型的实现。
什么是冻结网络
神经网络模型中存在大量的权重参数,这些参数在模型训练的过程中通过方向传播的方式来更新,不断变化,是变量。冻结网络是将这些变量转换为常量的过程。
为什么要冻结网络
在特定场景下例如嵌入式设备上的网络推理,我们希望模型文件的使用尽可能的简单。TensorFlow的Checkpoint或者SavedModels格式均使用多文件的方式存储,且包含了对推理无用的冗余信息。冻结网络提供了尽可能精简的存储方式。
如何冻结网络
TensorFlow官方提供了实现:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py
该脚本可以将Checkpoint或者SavedModels格式的模型转换为单文件存储的供推理用模型。使用类似命令即可:
具体是如何实现的
尽管这个脚本有近500行,但是它的核心在于一个函数:tf.graph_util.convert_variables_to_constants
。该函数将网络中的变量转换为常量,并移除与变量存储加载相关的操作。你可以在官网查阅该函数的API说明。
Comments ()