构建TensorFlow模型的三种方式
都能建模型,该用哪一个?
封面图片:Photo by Martin Reisch on Unsplash
TensorFlow 2提供了多种构建神经网络模型的方式,包括:
- Sequential API
- Keras functional API
- Keras model subclassing
这三种方式有什么区别,我该用哪一种?
Sequential API
如果模型结构简单,适合除了一层一层堆叠Layer就没有其他特殊结构的话,这是最简单的做法。不过如果你的模型包含非线性结构结构例如跳过中间层的连接、多个输入输出或者需要在多个模型中共享某一层的话,这种方式就无能为力了。
Functional API与Model Subclassing
大部分人可能会在这两种方式的选择上纠结。在官方博客中,这两种方式也被称为Symbolic API与Impreative API。他们都可以用来构建复杂拓扑结构的模型,也存在一些区别。
功能区别
Functional API特别适合构建DAG,但是无法构建动态模型,例如tree-RNN。而Subclassing可以。
易用性
Functional API本质在构建graph,因此很容易可以在构建模型时就开展兼容性检查,而非在模型执行的时候。这极大的减轻了开发者工作量。而Subclassing下,兼容性检查这些工作需要由开发者自己来负责。例如模型的summary()
方法,Functional API的输出包含非常详细的内容:
Model: "functional"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
img (InputLayer) [(None, 28, 28, 1)] 0
_________________________________________________________________
conv2d (Conv2D) (None, 26, 26, 16) 160
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 8, 8, 16) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 6, 6, 16) 2320
_________________________________________________________________
global_max_pooling2d (Global (None, 16) 0
=================================================================
Total params: 2,480
Trainable params: 2,480
Non-trainable params: 0
_________________________________________________________________
但是对于Subclassing来说,模型构建完成后直接summary()
会报错。
ValueError: This model has not yet been built. Build the model first by calling `build()` or calling `fit()` with some data, or specify an `input_shape` argument in the first layer(s) for automatic build.
模型需要处理至少一次数据,完成运算图的构建,才能够输出summary信息,例如同一款模型的输出:
Model: "subclass"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_2 (Conv2D) multiple 160
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 multiple 0
_________________________________________________________________
conv2d_3 (Conv2D) multiple 2320
_________________________________________________________________
global_max_pooling2d_1 (Glob multiple 0
=================================================================
Total params: 2,480
Trainable params: 2,480
Non-trainable params: 0
_________________________________________________________________
再利用便捷性
Functional API下定义模型的数据结构很容实现复制、克隆。这种情况下你可以很容易地将模型导出保存。即便没有构建模型的原始代码,你依然可以从模型配置中恢复模型结构。而Subclassing的方式下,构建模型的是类方法中的一段代码,而非对外透明的数据结构。因此,在获取极大灵活性的同时,牺牲了易用性与重复利用性。
总结
对于大部分人来说,依据自己实际情况选择一种就好了。学术研究可能偏爱Functional API,它的使用范式与我们脑中的神经网络模型非常相似,而且代码简单,再利用性好。工业界的程序员可能更加偏爱Subclassing的方式。一方面是OOP代码习惯,而且面对多变的需求,需要给可能的hack操作留下操作空间。
另外多说一句,如果你愿意,这里两种方案其实是可以混在一起使用的。
扩展阅读
这里有一篇TensorFlow官方解析文章,推荐阅读。
一个彩蛋
Subclassing方式构建的模型也可以在Summary中显示output shape:
Model: "subclass_with_shape"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_2 (Conv2D) (None, 26, 26, 16) 160
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 8, 8, 16) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 6, 6, 16) 2320
_________________________________________________________________
global_max_pooling2d_1 (Glob (None, 16) 0
=================================================================
Total params: 2,480
Trainable params: 2,480
向微信公众号发送 model summary
获取具体方法。
Comments ()