检查TensorFlow SavedModel文件中的模型
如何从SavedModel格式文件中读取运算图,并在TensorBoard中显示。
当我们训练模型的时候,使用TensorBoard可以直接查阅训练中的模型graph。当模型训练结束,导出为SavedModel格式文件时,该如何检查模型,获取节点呢?
TensorFlow官方其实提供了一份python工具来实现对冻结网络pb文件的读取。该文件的位置在:
按照说明,只要提供pb文件路径,以及一个log存储路径即可使用TensorBoard来查阅网络结构。仔细分析该文件,核心的代码其实就几行:
这段代码首先创建session,然后将pb文件中的graph读取到内存,再使用summary
将graph以TensorBoard可以识别的格式写入到log文件夹。之后再使用TensorBoard查阅即可。
同理,我们可以借鉴这个思路来实现SavedModel格式的模型读取:
with session.Session(graph=ops.Graph()) as sess:
saved_model.loader.load(sess, ["serve"], model_dir)
pb_visual_writer = summary.FileWriter(log_dir)
pb_visual_writer.add_graph(sess.graph)
这样就可以在TensorBoard中查阅导出的模型了。
完整的代码在这里:https://gist.github.com/yinguobing/8a283724cf892f1e6d0937dc0938b99c
Comments ()