构建TensorFlow模型的三种方式

都能建模型,该用哪一个?

构建TensorFlow模型的三种方式

封面图片:Photo by Martin Reisch on Unsplash

TensorFlow 2提供了多种构建神经网络模型的方式,包括:

  • Sequential API
  • Keras functional API
  • Keras model subclassing

这三种方式有什么区别,我该用哪一种?

Sequential API

如果模型结构简单,适合除了一层一层堆叠Layer就没有其他特殊结构的话,这是最简单的做法。不过如果你的模型包含非线性结构结构例如跳过中间层的连接、多个输入输出或者需要在多个模型中共享某一层的话,这种方式就无能为力了。

Functional API与Model Subclassing

大部分人可能会在这两种方式的选择上纠结。在官方博客中,这两种方式也被称为Symbolic API与Impreative API。他们都可以用来构建复杂拓扑结构的模型,也存在一些区别。

功能区别

Functional API特别适合构建DAG,但是无法构建动态模型,例如tree-RNN。而Subclassing可以。

易用性

Functional API本质在构建graph,因此很容易可以在构建模型时就开展兼容性检查,而非在模型执行的时候。这极大的减轻了开发者工作量。而Subclassing下,兼容性检查这些工作需要由开发者自己来负责。例如模型的summary() 方法,Functional API的输出包含非常详细的内容:

Model: "functional"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
img (InputLayer)             [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 26, 26, 16)        160       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 8, 8, 16)          0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 6, 6, 16)          2320      
_________________________________________________________________
global_max_pooling2d (Global (None, 16)                0         
=================================================================
Total params: 2,480
Trainable params: 2,480
Non-trainable params: 0
_________________________________________________________________

但是对于Subclassing来说,模型构建完成后直接summary()会报错。

ValueError: This model has not yet been built. Build the model first by calling `build()` or calling `fit()` with some data, or specify an `input_shape` argument in the first layer(s) for automatic build.

模型需要处理至少一次数据,完成运算图的构建,才能够输出summary信息,例如同一款模型的输出:

Model: "subclass"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_2 (Conv2D)            multiple                  160       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 multiple                  0         
_________________________________________________________________
conv2d_3 (Conv2D)            multiple                  2320      
_________________________________________________________________
global_max_pooling2d_1 (Glob multiple                  0         
=================================================================
Total params: 2,480
Trainable params: 2,480
Non-trainable params: 0
_________________________________________________________________

再利用便捷性

Functional API下定义模型的数据结构很容实现复制、克隆。这种情况下你可以很容易地将模型导出保存。即便没有构建模型的原始代码,你依然可以从模型配置中恢复模型结构。而Subclassing的方式下,构建模型的是类方法中的一段代码,而非对外透明的数据结构。因此,在获取极大灵活性的同时,牺牲了易用性与重复利用性。

总结

对于大部分人来说,依据自己实际情况选择一种就好了。学术研究可能偏爱Functional API,它的使用范式与我们脑中的神经网络模型非常相似,而且代码简单,再利用性好。工业界的程序员可能更加偏爱Subclassing的方式。一方面是OOP代码习惯,而且面对多变的需求,需要给可能的hack操作留下操作空间。

另外多说一句,如果你愿意,这里两种方案其实是可以混在一起使用的。

扩展阅读

这里有一篇TensorFlow官方解析文章,推荐阅读。

What are Symbolic and Imperative APIs in TensorFlow 2.0?
The TensorFlow blog contains regular news from the TensorFlow team and the community, with articles on Python, TensorFlow.js, TF Lite, TFX, and more.

一个彩蛋

Subclassing方式构建的模型也可以在Summary中显示output shape:

Model: "subclass_with_shape"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_2 (Conv2D)            (None, 26, 26, 16)        160       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 8, 8, 16)          0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 6, 6, 16)          2320      
_________________________________________________________________
global_max_pooling2d_1 (Glob (None, 16)                0         
=================================================================
Total params: 2,480
Trainable params: 2,480

向微信公众号发送 model summary 获取具体方法。