TensorFlow官方ResNet模型实现分析
TensoFlow官方是如何构建DNN模型的。
上一篇文章中我们分析了TensorFlow官方如何解决Cifar10模型问题。我们从具体入口函数入手,分析了cifar10_main.py
如何定义输入函数与模型函数。发现resnet_run_loop
与resnet_model
这两个非常重要的模块在发挥着重要的作用。本次我们将对这两个模块进行分析。
模型实现
ResNet模型的实现体现在 resnet_model.py
文件中的共计546行代码中。按照注释可以将整体分割成两大部分:辅助构建ResNet模型的函数以及ResNet block定义函数。
# Convenience functions for building the ResNet model.
def batch_norm(inputs, training, data_format):
def fixed_padding(inputs, kernel_size, data_format):
def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format):
# ResNet block definitions.
def _building_block_v1(inputs, filters, training, projection_shortcut, strides, ...):
def _building_block_v2(inputs, filters, training, projection_shortcut, strides, ...)
def _bottleneck_block_v1(inputs, filters, training, projection_shortcut, ...)
def _bottleneck_block_v2(inputs, filters, training, projection_shortcut, ...)
def block_layer(inputs, filters, bottleneck, block_fn, blocks, strides, ...)
class Model(object):
如果对ResNet不熟悉的话可以先参考这篇post。ResNet是由一系列block堆叠而成的。官方根据文献实现了两种不同的block,对应代码中的_building_block
与_bottleneck_block
。block_layer
则将blocks组合在一起,它的角色类似TensorFlow原生的 tf.layers
,都是为了简化最终的网络构建。最终模型的成型体现在class Model()
中。
对于Model的定义基本上占据了文件代码的1/3。不过最核心的部分在__call__
函数中。这个函数清晰的展示了Model是如何加工输入input的,具体包括:
- 为了加速GPU运算,将输入由NHWC转换成NCHW。
- 首次卷积运算。
- 根据ResNet版本判断是否要做batch norm。
- 首次pooling。
- 堆叠block。
- 最终的pooling(代码中用表现更好的reduce_mean替代)。
- 最终的全连接层。
回忆cifar10的模型是继承自这个模型,除非要对ResNet模型本身做出大的修改,这里的实现是一个非常成熟、可开箱即用的结果。
模型运行
模型实现的重点在于“描述模型长什么样子”,当知道模型的具体体貌后,我们还需要将模型放置在具体的环境中,实现数据与误差在模型中的流动,进而利用梯度下降法更新模型参数。这部分工作是有模块 resnet_run_loop.py
实现的。
resnet_run_loop
模块的代码行数达到596行,是目前我们遇到的最庞大的一个模块。同样依赖文件注释可以将所有函数分为两大块:输入处理与运行循环(训练、验证、测试)。
# Functions for input processing.
def process_record_dataset(dataset, ...)
def get_synth_input_fn(height, width, num_channels, num_classes, ...)
def image_bytes_serving_input_fn(image_shape, dtype=tf.float32)
def override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)
# Functions for running training/eval/validation loops for the model.
def learning_rate_with_decay(...)
def resnet_model_fn(features, labels, mode, model_class, ...)
def resnet_main(...)
def define_resnet_flags(resnet_size_choices=None)
输入处理
先看输入部分。如果对上一篇文章还有印象的话,当时的输入函数调用了这里的process_record_dataset()
函数,并将自定义的parse_record_fn
传递了进来。仔细看这里的函数实现:
dataset = dataset.apply(
tf.contrib.data.map_and_batch(
lambda value: parse_record_fn(value, is_training, dtype),
batch_size=batch_size,
num_parallel_batches=num_parallel_batches,
drop_remainder=False))
我们在cifar10中自定义的解析record的函数在这里通过lamda的方式作用在了每一份data上。除此之外,函数中还根据是否是训练过程对数据的随机化、批次、预读取等细节做出了设定。通过预留解析数据的接口函数parse_record_fn
,实现了不同数据集使用相同函数构建input的方式,例如自带的Cifar10与ImageNet例子。
运行循环
数据输入已经确定,模型也已经定义完成,但是我们离模型训练还差关键的一步:loss的定义。文件中的函数resnet_model_fn
则刚好包含了这一实现。具体的实现为:
# 获取模型输出
logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)
# 计算交叉熵
cross_entropy = tf.losses.sparse_softmax_cross_entropy(
logits=logits, labels=labels)
# 计算l2 loss
l2_loss = weight_decay * tf.add_n(
[tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()
if loss_filter_fn(v.name)])
# 获得最终loss
loss = cross_entropy + l2_loss
函数resnet_model_fn
的作用其实是构建EstimatorSpec,loss其实是EstimatorSpec的一部分。而estimator的实例化则在函数resnet_main
内部,这也是真正的训练与验证过程的所在。在Cifar10的例子中,也是直接调用了resnet_main
。
# Cifar10
result = resnet_run_loop.resnet_main(
flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
shape=[HEIGHT, WIDTH, NUM_CHANNELS])
在resnet_main
函数内可以找到训练、验证与导出的调用代码:
# 训练
classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
hooks=train_hooks, max_steps=flags_obj.max_train_steps)
# 验证
eval_results = classifier.evaluate(input_fn=input_fn_eval,
steps=flags_obj.max_train_steps)
# 导出
classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
strip_default_attrs=True)
总结
TensorFlow官方ResNet模型通过resnet_model
与resnet_run_loop
来实现。resnet_model
专注于ResNet的网络构建;resnet_run_loop
则承担了输入数据处理、Estimator相关组件初始化以及训练、验证与导出的流程控制。在此基础上,可以通过添加模块的方式来实现自定义的模型训练,例如Cifar10。
如果你需要使用官方的ResNet模型来实现特定的功能,可以分以下几步来做:
- 模仿Cifar10,构建定制化的模型结构与输入函数,形成定制化模块。
- 如果需要,修改
resnet_run_loop
模块中的相关部分,例如loss函数。 - 如果需要,修改
resnet_model
模块。
Comments ()