使用Keras导入Estimator生成的SavedModel模型
在TensorFlow 2中导入TensorFlow 1生成的模型文件。
背景
head-pose-estimation是使用深度学习以及其它CV算法来估算人脸面部朝向的一个小项目,已经断断续续更新了三年。目前在Github上有500+星,150个Fork。考虑到TensorFlow已经更新到了2.2版,现在是时候将代码迁移到TensorFlow 2了。
工作拆解
好久没碰代码,稍微花了些时间复习下。得益于项目模块化的构建方式,涉及到TensorFlow的模块仅限于mark_detector.py
。在TensorFlow 1中,执行模型推演之前需要定义graph
,然后再通过Session
加载并调用。在TensorFlow 2中,即时执行(eager execution)已经是默认选项,无需再定义Session
。同时,对于推演模型来说,使用keras
高阶API无疑是更加快捷可靠的做法。所以理论上需要做的工作可以归结为两点:
- 删除所有涉及到
Session
的旧代码。 - 使用
keras
加载并执行模型推演。
动手操作
根据TensorFlow官方文档,需要修改的内容为:
在模块中增加导入 keras
模块:
from tensorflow import keras
在类 MarkDetector
的初始化函数中,删除构建 Session
的部分:
# Get a TensorFlow session ready to do landmark detection # Load a Tensorflow saved model into memory. self.graph = tf.Graph() config = tf.ConfigProto() config.gpu_options.allow_growth = True self.sess = tf.Session(graph=self.graph, config=config)
使用keras
API来替换原有的载入模型部分。原有的是:
tf.saved_model.loader.load(self.sess, ["serve"], saved_model)
新的是:
self.model = keras.models.load_model(saved_model)
其中saved_model
是模型的存储路径。
模型加载完成后,需要在具体的检测函数中实现调用。具体为detect_marks
函数中,将原有的代码:
logits_tensor = self.graph.get_tensor_by_name( 'layer6/final_dense:0') predictions = self.sess.run( logits_tensor, feed_dict={'image_tensor:0': image_np})
替换为:
predictions = self.model.predict(image_np)
没有了Session
,代码整体上看上去更加的简洁。果断保存,执行——哎,出错了?
错误分析
错误信息为:
Traceback (most recent call last): File "estimate_head_pose.py", line 162, in <module> main() File "estimate_head_pose.py", line 114, in main marks = mark_detector.detect_marks([face_img]) File "/Users/Robin/Developer/head-pose-estimation/mark_detector.py", line 157, in detect_marks predictions = self.model.predict(image_np) AttributeError: 'AutoTrackable' object has no attribute 'predict'
不应该呀!我之前已经把训练模型的代码更新到keras
版了,而且仔细查阅keras
的文档,这个调用方式应当是没有问题的。所以问题很可能出现在模型上。
打开生成模型的项目仔细检查一遍,发现一个问题:我之前的确将训练代码更新到了keras
版,但是,我没有实现模型导出代码。所以,这里的模型文件应当是基于旧代码中的estimator
模式,而不是keras
模式导出的。所以使用 keras
加载模型会报错。
解决问题
TensorFlow官方文档提供了estimator
模式下模型的载入方法,在示例代码中可以清晰的看到导入后的模型在调用时要指定一个 Signature key
:
于是进一步查询tf.saved_model.load
API,发现以下说明:
Signatures associated with the SavedModel are available as functions:
并有示例代码:
imported = tf.saved_model.load(path)
f = imported.signatures["serving_default"]
print(f(x=tf.constant([[1.]])))
于是使用命令打印出当前模型包含的 signatures
:
self.model = keras.models.load_model(saved_model)
print(list(self.model.signatures))
得到结果:
['predict', 'serving_default']
然后照猫画虎,更改模型的调用语句为:
predictions = self.model.signatures['predict'](image_np)
保存,执行——哎,又报错了!
ValueError: All inputs to `ConcreteFunction`s must be Tensors; on invocation of pruned, the 0-th input ([array([[[63, 53, 41],
[66, 55, 40],
[68, 55, 41],
...,
[76, 62, 46],
[73, 60, 44],
[67, 54, 39]]], dtype=uint8)]) was not a Tensor.
原来是输入类型错误,小问题。将Numpy的 array
转换为TensorFlow的Tensor
即可:
predictions = self.model.signatures["predict"](
tf.constant(image_np, dtype=tf.uint8))
保存,执行——哎,叒报错了!
Traceback (most recent call last):
File "estimate_head_pose.py", line 162, in <module>
main()
File "estimate_head_pose.py", line 114, in main
marks = mark_detector.detect_marks([face_img])
File "/Users/Robin/Developer/head-pose-estimation/mark_detector.py", line 162, in detect_marks
marks = np.reshape(marks, (-1, 2))
File "/Users/Robin/Library/Python/3.6/lib/python/site-packages/numpy/core/fromnumeric.py", line 292, in reshape
return _wrapfunc(a, 'reshape', newshape, order=order)
File "/Users/Robin/Library/Python/3.6/lib/python/site-packages/numpy/core/fromnumeric.py", line 56, in _wrapfunc
return getattr(obj, method)(*args, **kwds)
ValueError: cannot reshape array of size 1 into shape (2)
矩阵变形错误。模型本身的输出尺寸应当是固定的,于是将模型输出结果打印出来看下:
{'output': <tf.Tensor: shape=(1, 136), dtype=float32, numpy=
array([[0.12975551, 0.32224602, 0.14154439, 0.42376196, 0.15958622,
...
0.55187416, 0.6883621 , 0.5191848 , 0.69217545, 0.48493028,
0.68999064]], dtype=float32)>}
原来模型的输出结果是一个Python的dict
。果断通过 output
提取出 Tensor
并转换为Numpy矩阵:
marks = np.array(predictions['output']).flatten()[:136]
保存,执行——终于OK了!😂
总结
推荐在训练与推演时使用一致的方式,如Keras或者Estimator。不过在特殊情况下,也可以使用Keras加载、并通过指定 signature
的方式执行Estimator生成的模型。
Comments ()