"TensorFlow官方本着精心维护、测试且及时更新的原则使用高阶API构建了一系列模型。这些模型性能优异,同时易于阅读。端到端测试则用来保证模型在每个新发布版本上都能保持相同的性能。"

官方地址:

https://github.com/tensorflow/models/tree/master/official

如果我们从零开始自行创建模型表达,最有可能的做法是:按照需求抽象若干Python模块。这些模块可能包含数据增强、图像处理、模型定义、训练、测试与推演等。TensorFlow的做法类似,我们可以借鉴下官方的具体实现。

目前官方提供的模型包括:

  • boosted_trees: A Gradient Boosted Trees model to classify higgs boson process from HIGGS Data Set.
  • mnist: A basic model to classify digits from the MNIST dataset.
  • resnet: A deep residual network that can be used to classify both CIFAR-10 and ImageNet's dataset of 1000 classes.
  • transformer: A transformer model to translate the WMT English to German dataset.
  • wide_deep: A model that combines a wide model and deep network to classify census income data.

这些模型遵循共同的原则:

  • 使用通用的工具函数。
  • 保存最后训练结果为SavedModel。
  • 包含flags与flag-parsing库 (read more here)
  • 提供benchmarks与logs (read more here)

ResNet的实现下有众多文件,分工不同。总体上看包含两大部分:模型实现与模型应用。

模型实现主要依赖两个模块:resnet_model.pyresnet_run_loop.py 。模型应用下有两个例子:cifar10与imagenet。接下来我们从应用入手,抽丝剥茧倒推整体实现。

先以cifar-10为例。按照官方教程,首先执行cifar10_download_and_extract.py,这段脚本的作用是下载训练所需文件。接下来执行是cifar10_main.py,这个文件即为训练与测试的入口文件。

入口文件分析

文件代码共278行,按照注释整体上分为两块:数据处理与模型执行。里边关键的函数加起来不到10个。

# Data processing
def get_filenames(is_training, data_dir):
def parse_record(raw_record, is_training, dtype):
def preprocess_image(image, is_training):
def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...):
def get_synth_input_fn(dtype):

# Running the model
def class Cifar10Model(resnet_model.Model):
def cifar10_model_fn(features, labels, mode, params):
def define_cifar_flags():
def run_cifar(flags_obj):

其中的run_cifar为入口函数。而在该函数内部其实只做了两件事:

  1. 实例化input函数。
  2. 调用resnet_run_loop模块下的resnet_main函数。

resnet_main函数接收多个参数,包括input函数与cifar10_model_fn函数。而input函数的具体内容则取决于flags_obj.use_synthetic_data这个flag。这里的三个函数都在本文件中定义。从代码来看,cifar10采用的应该是TensorFlow estimator的方式。这种方式对外隐藏了Graph与Session,实现的关键在于两部分:输入函数模型函数。接下来我们分别分析。

输入函数

这里的输入函数根据输入的flag对象中的设定来决定返回的输入函数。但是它的语法很有意思:

input_function = (flags_obj.use_synthetic_data and 
                  get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or 
                  input_fn)

使用bool值与函数对象做and运算,True返回该函数,False返回False。使用bool值与函数对象做or运算,True返回True,False返回该函数。这里相当于一种if运算的替代方案。这里的两个输入函数一个输出真实数据,一个输出替代数据。我们关注真实数据。

input_fn内部的实现也非常简单,在获得dataset对象后,直接使用resnet_run_loop模块的process_record_dataset构建具体的输入函数。不过输入的参数之一是parse_record_fn,对应本文件中的parse_record函数。这个函数具体内容如下:

def parse_record(raw_record, is_training, dtype):
  """Parse CIFAR-10 image and label from a raw record."""
  record_vector = tf.decode_raw(raw_record, tf.uint8)
  label = tf.cast(record_vector[0], tf.int32)
  depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
                           [NUM_CHANNELS, HEIGHT, WIDTH])
  image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
  image = preprocess_image(image, is_training)
  image = tf.cast(image, dtype)

  return image, label

输入函数接收raw_record,内部进行数据的解码与格式转换,输出image与label。这正是Estimator对输入函数的要求。

模型函数

cifar10_model_fn

该函数其实只做了一件事,根据cifar10的需要构建resnet_model模型函数。具体的内容包含:

  1. reshape feature,有可能因为cifar数据的存储格式造成。
  2. 使用resnet_run_loop模块下的learning_rate_with_decay定义了一个learning_rate相关的函数。
  3. 定义了weight_decay。
  4. 定义了是否在regularized loss上使用normalization。

最后使用resnet_run_loop模块下的resnet_model_fn构建函数对象并返回。该函数的参数列表里还要一个重要的参数为model_class。在本文件中为Cifar10Model。

Cifar10Model

Cifar10Model继承自resnet_model模块下的Model。做出的修改几乎为零。关键在于自定义参数。

super(Cifar10Model, self).__init__(
    resnet_size=resnet_size,
    bottleneck=False,
    num_classes=num_classes,
    num_filters=16,
    kernel_size=3,
    conv_stride=1,
    first_pool_size=None,
    first_pool_stride=None,
    block_sizes=[num_blocks] * 3,
    block_strides=[1, 2, 2],
    resnet_version=resnet_version,
    data_format=data_format,
    dtype=dtype
)

上方的代码指定了适用于cifar10的ResNet模型的网络规模、版本、分类数、卷积参数和池化参数等。如果要根据自己的需求去定义一个ResNet派生模型,这个函数可供参考。

小结

至此cifar10_main.py分析完成。Cifar10示例应用遵循了TensorFlow estimator的使用方式,通过构建输入函数与模型函数的方式实现了自定义任务的模型训练、评估与导出。其中需要自行实现的代码非常少,大部分工作依赖resnet_run_loop与resnet_model这两个模块。下一篇文章中我们具体对这两个模块进行分析。

如果你感觉自己已经理解了官方模型应用的实现,可以尝试分析下另一个Imagenet的实现:imagenet_main.py


  1. https://www.tensorflow.org/guide/estimators
  2. https://github.com/tensorflow/models/tree/master/official