检查TensorFlow SavedModel文件中的模型

如何从SavedModel格式文件中读取运算图,并在TensorBoard中显示。

检查TensorFlow SavedModel文件中的模型
Photo by Paweł Czerwiński

当我们训练模型的时候,使用TensorBoard可以直接查阅训练中的模型graph。当模型训练结束,导出为SavedModel格式文件时,该如何检查模型,获取节点呢?

TensorFlow官方其实提供了一份python工具来实现对冻结网络pb文件的读取。该文件的位置在:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/import_pb_to_tensorboard.py

按照说明,只要提供pb文件路径,以及一个log存储路径即可使用TensorBoard来查阅网络结构。仔细分析该文件,核心的代码其实就几行:

with session.Session(graph=ops.Graph()) as sess:
    with gfile.GFile(model_dir, "rb") as f:
      graph_def = graph_pb2.GraphDef()
      graph_def.ParseFromString(f.read())
      importer.import_graph_def(graph_def)

    pb_visual_writer = summary.FileWriter(log_dir)
    pb_visual_writer.add_graph(sess.graph)
用于读取pb文件中网络模型的核心代码

这段代码首先创建session,然后将pb文件中的graph读取到内存,再使用summary 将graph以TensorBoard可以识别的格式写入到log文件夹。之后再使用TensorBoard查阅即可。

同理,我们可以借鉴这个思路来实现SavedModel格式的模型读取:

with session.Session(graph=ops.Graph()) as sess:
    saved_model.loader.load(sess, ["serve"], model_dir)
    pb_visual_writer = summary.FileWriter(log_dir)
    pb_visual_writer.add_graph(sess.graph)

这样就可以在TensorBoard中查阅导出的模型了。

完整的代码在这里:https://gist.github.com/yinguobing/8a283724cf892f1e6d0937dc0938b99c