TensorFlow官方ResNet模型实现分析

TensoFlow官方是如何构建DNN模型的。

TensorFlow官方ResNet模型实现分析
Cover photo by Jimmy Chang

上一篇文章中我们分析了TensorFlow官方如何解决Cifar10模型问题。我们从具体入口函数入手,分析了cifar10_main.py如何定义输入函数与模型函数。发现resnet_run_loopresnet_model这两个非常重要的模块在发挥着重要的作用。本次我们将对这两个模块进行分析。

模型实现

ResNet模型的实现体现在 resnet_model.py 文件中的共计546行代码中。按照注释可以将整体分割成两大部分:辅助构建ResNet模型的函数以及ResNet block定义函数。

# Convenience functions for building the ResNet model.
def batch_norm(inputs, training, data_format):
def fixed_padding(inputs, kernel_size, data_format):
def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format):

# ResNet block definitions.
def _building_block_v1(inputs, filters, training, projection_shortcut, strides, ...):
def _building_block_v2(inputs, filters, training, projection_shortcut, strides, ...)
def _bottleneck_block_v1(inputs, filters, training, projection_shortcut, ...)
def _bottleneck_block_v2(inputs, filters, training, projection_shortcut, ...)
def block_layer(inputs, filters, bottleneck, block_fn, blocks, strides, ...)
class Model(object):

如果对ResNet不熟悉的话可以先参考这篇post。ResNet是由一系列block堆叠而成的。官方根据文献实现了两种不同的block,对应代码中的_building_block_bottleneck_blockblock_layer则将blocks组合在一起,它的角色类似TensorFlow原生的 tf.layers,都是为了简化最终的网络构建。最终模型的成型体现在class Model()中。

对于Model的定义基本上占据了文件代码的1/3。不过最核心的部分在__call__函数中。这个函数清晰的展示了Model是如何加工输入input的,具体包括:

  1. 为了加速GPU运算,将输入由NHWC转换成NCHW。
  2. 首次卷积运算。
  3. 根据ResNet版本判断是否要做batch norm。
  4. 首次pooling。
  5. 堆叠block。
  6. 最终的pooling(代码中用表现更好的reduce_mean替代)。
  7. 最终的全连接层。

回忆cifar10的模型是继承自这个模型,除非要对ResNet模型本身做出大的修改,这里的实现是一个非常成熟、可开箱即用的结果。

模型运行

模型实现的重点在于“描述模型长什么样子”,当知道模型的具体体貌后,我们还需要将模型放置在具体的环境中,实现数据与误差在模型中的流动,进而利用梯度下降法更新模型参数。这部分工作是有模块 resnet_run_loop.py 实现的。

resnet_run_loop模块的代码行数达到596行,是目前我们遇到的最庞大的一个模块。同样依赖文件注释可以将所有函数分为两大块:输入处理与运行循环(训练、验证、测试)。

# Functions for input processing.
def process_record_dataset(dataset, ...)
def get_synth_input_fn(height, width, num_channels, num_classes, ...)
def image_bytes_serving_input_fn(image_shape, dtype=tf.float32)
def override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

# Functions for running training/eval/validation loops for the model.
def learning_rate_with_decay(...)
def resnet_model_fn(features, labels, mode, model_class, ...)
def resnet_main(...)
def define_resnet_flags(resnet_size_choices=None)

输入处理

先看输入部分。如果对上一篇文章还有印象的话,当时的输入函数调用了这里的process_record_dataset()函数,并将自定义的parse_record_fn传递了进来。仔细看这里的函数实现:

dataset = dataset.apply(
    tf.contrib.data.map_and_batch(
        lambda value: parse_record_fn(value, is_training, dtype),
        batch_size=batch_size,
        num_parallel_batches=num_parallel_batches,
        drop_remainder=False))

我们在cifar10中自定义的解析record的函数在这里通过lamda的方式作用在了每一份data上。除此之外,函数中还根据是否是训练过程对数据的随机化、批次、预读取等细节做出了设定。通过预留解析数据的接口函数parse_record_fn,实现了不同数据集使用相同函数构建input的方式,例如自带的Cifar10与ImageNet例子。

运行循环

数据输入已经确定,模型也已经定义完成,但是我们离模型训练还差关键的一步:loss的定义。文件中的函数resnet_model_fn则刚好包含了这一实现。具体的实现为:

# 获取模型输出
logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)

# 计算交叉熵
cross_entropy = tf.losses.sparse_softmax_cross_entropy(
      logits=logits, labels=labels)
      
# 计算l2 loss
l2_loss = weight_decay * tf.add_n(
      [tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()
       if loss_filter_fn(v.name)])
       
# 获得最终loss
loss = cross_entropy + l2_loss

函数resnet_model_fn的作用其实是构建EstimatorSpec,loss其实是EstimatorSpec的一部分。而estimator的实例化则在函数resnet_main内部,这也是真正的训练与验证过程的所在。在Cifar10的例子中,也是直接调用了resnet_main

# Cifar10 
result = resnet_run_loop.resnet_main(
    flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
    shape=[HEIGHT, WIDTH, NUM_CHANNELS])

resnet_main函数内可以找到训练、验证与导出的调用代码:

# 训练
classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                       hooks=train_hooks, max_steps=flags_obj.max_train_steps)

# 验证
eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

# 导出
classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)

总结

TensorFlow官方ResNet模型通过resnet_modelresnet_run_loop来实现。resnet_model专注于ResNet的网络构建;resnet_run_loop则承担了输入数据处理、Estimator相关组件初始化以及训练、验证与导出的流程控制。在此基础上,可以通过添加模块的方式来实现自定义的模型训练,例如Cifar10。

如果你需要使用官方的ResNet模型来实现特定的功能,可以分以下几步来做:

  1. 模仿Cifar10,构建定制化的模型结构与输入函数,形成定制化模块。
  2. 如果需要,修改resnet_run_loop模块中的相关部分,例如loss函数。
  3. 如果需要,修改resnet_model模块。

  1. https://towardsdatascience.com/an-overview-of-resnet-and-its-variants-5281e2f56035