封面图片:Photo by Pontus Wellgraf on Unsplash

TensorFlow提供了多种模型优化方案,包括:量化、剪枝与权重聚类。其中训练后量化最为简单,使用TensorFlow Lite就可以完成。而训练感知量化、剪枝与权重聚类则需要专门的工具——TensorFlow模型优化工具集。

TensorFlow Model Optimization
A suite of tools for optimizing ML models for deployment and execution. Improve performance and efficiency, reduce latency for inference at the edge.

本文以剪枝为例,详细说明了自定义TensorFlow模型实现剪枝的具体方法。


前提条件

剪枝功能对模型以及TensorFlow的版本都有要求,具体为:

  • 仅支持Sequential与Functional方式创建的tf.keras模型
  • TensorFlow版本: TF 1.14+ 与 2.x
  • TensorFlow执行模式: graph模式或者eager模式
  • 分布式训练:仅限graph模式
目前模型优化工具集的版本为0.5.0。本文基于TensorFlow 2.3实现。

模型自定义的复杂度

如果你的模型完全是采用keras内置layer来实现的,那么遵循官方文档,只需要很少的代码就可以实现剪枝。但是实际上,对于复杂的模型架构,我通常采用自定义keras layer的方式来实现。此时在构建layer的时候就需要考虑剪枝的实现。

如何自定义Keras Layer

TensorFlow针对自定义Keras layer提供了详尽的指南文档。但是,这些文档并没有提供与剪枝相关的内容。代码中也没有显式的提供剪枝支持。实际上,针对Keras内置layer的剪枝是由TensorFlow模型优化工具箱来实现的。对于非Keras内置layer,在满足以下条件的情况下也可以支持:

  1. 该layer必须继承自 keras.layers.Layer ,同时是 PrunableLayer
  2. 该layer需要提供 get_prunable_weights 方法。

HRNet剪枝

以HRNet为例,对于HRNetBody这个layer,在声明类时需要同时继承自两个类:

# 省略了部分代码
import tensorflow_model_optimization as tfmot

class HRNetBody(keras.layers.Layer, tfmot.sparsity.keras.PrunableLayer):
HRNet继承,作者:国冰

之后,为该类实现 get_prunable_weights 方法:

def get_prunable_weights(self):
    prunable_weights = list(chain(*[
        self.bottleneck_1.get_prunable_weights(),
        self.bottleneck_2.get_prunable_weights(),
        self.bottleneck_3.get_prunable_weights(),
        self.bottleneck_4.get_prunable_weights(),
        [getattr(self.conv3x3, 'kernel')]
    ]))

    return prunable_weights
HRNet返回剪枝权重,作者:国冰

该方法以list的形式直接返回可供修剪的权重。该示例中,HRNetBody这个layer嵌入了多个子layer:4个bottleneck 与一个卷积layer self.conv3x3 。其中bottleneck 同样为自定义Keras layer,并且已经实现了方法 get_prunable_weights ,可以直接用来获得可修剪权重;卷积layer self.conv3x3 为Keras内置layer,因此需要使用 getattr 来获得可修剪权重。由于嵌入的子layer返回的权重是list,因此使用了 itertools.chain 方法来将它们展开并收纳在一个list中。在实现该方法时要留意有些权重参数可能不适合修剪,否则会造成模型准确度的大幅下降。