TensorFlow如何冻结网络模型

如何使用TensorFlow冻结网络模型为单一文件,以供C++、移动端与嵌入式设备在推理阶段使用。

TensorFlow如何冻结网络模型
Photo by Annie Spratt 

对于以推理为目的的网络模型,可以将其“冻结”起来供C++等语言调用。本文介绍了TensorFlow官方冻结网络模型的实现。

什么是冻结网络

神经网络模型中存在大量的权重参数,这些参数在模型训练的过程中通过方向传播的方式来更新,不断变化,是变量。冻结网络是将这些变量转换为常量的过程。

为什么要冻结网络

在特定场景下例如嵌入式设备上的网络推理,我们希望模型文件的使用尽可能的简单。TensorFlow的Checkpoint或者SavedModels格式均使用多文件的方式存储,且包含了对推理无用的冗余信息。冻结网络提供了尽可能精简的存储方式。

如何冻结网络

TensorFlow官方提供了实现:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py

该脚本可以将Checkpoint或者SavedModels格式的模型转换为单文件存储的供推理用模型。使用类似命令即可:

# From TensorFlow souce: tensorflow/tensorflow/python/tools
python3 freeze_graph.py --input_graph=some_graph_def.pb \
	--input_checkpoint=model.ckpt-8361242 \
	--output_graph=/tmp/frozen_graph.pb \
    	--output_node_names=softmax
冻结网络示例命令

具体是如何实现的

尽管这个脚本有近500行,但是它的核心在于一个函数:tf.graph_util.convert_variables_to_constants 。该函数将网络中的变量转换为常量,并移除与变量存储加载相关的操作。你可以在官网查阅该函数的API说明