将TensorFlow的变量格式从NCHW转换为NHWC

将一堆二维张量拼接成三维张量的时候,默认的Chanel维度在首位;然而在TensorFlow中张量的默认Channel维度在末尾。因此有时需要将变量模式从NCHW转换为NHWC以匹配格式。

将TensorFlow的变量格式从NCHW转换为NHWC

将一堆二维张量拼接成三维张量的时候,默认的Chanel维度在首位;然而在TensorFlow中张量的默认Channel维度在末尾。因此有时需要将变量模式从NCHW转换为NHWC以匹配格式。

根据TensorFlow官方文档,N、C、H、W代表意义如下[1]

  • N:一个batch内图片的数量。
  • H:垂直高度方向的像素个数。
  • W:水平宽度方向的像素个数。
  • C:通道数。例如灰度图像为1, 彩色RGB图像为3。

假设张量x如下:

x = [[[  1,   2],
      [  3,   4]],
     [[ 11,  22],
      [ 33,  44]],
     [[111, 222],
      [333, 444]]]

此时x的shape为(3, 2, 2),通道数为3。取第一个维度首元素如下:

>>> x[0, :, :]
<tf.Tensor: id=7, shape=(2, 2), dtype=int32, numpy=
array([[1, 2],
       [3, 4]], dtype=int32)>

取最后一个维度首元素如下:

>>> x[:,:,0]
<tf.Tensor: id=12, shape=(3, 2), dtype=int32, numpy=
array([[  1,   3],
       [ 11,  33],
       [111, 333]], dtype=int32)>

按照StackOverflow的解答[2],要将通道置于维度的尾部,可以使用TensorFlow中的tf.transpose函数实现[3]

y = tf.transpose(x, [1, 2, 0])

其中第二个参数是转换后的张量中,原始张量的维度编号。编号0原本在首位,现在处于末位。

转换后的张量y如下:

<tf.Tensor: id=15, shape=(2, 2, 3), dtype=int32, numpy=
array([[[  1,  11, 111],
        [  2,  22, 222]],

       [[  3,  33, 333],
        [  4,  44, 444]]], dtype=int32)>

取y第一个维度首元素如下:

>>> y[0]
<tf.Tensor: id=20, shape=(2, 3), dtype=int32, numpy=
array([[  1,  11, 111],
       [  2,  22, 222]], dtype=int32)>

取y最后一个维度首元素如下:

>>> y[:,:,0]
<tf.Tensor: id=25, shape=(2, 2), dtype=int32, numpy=
array([[1, 2],
       [3, 4]], dtype=int32)>

可见x[0, :, :] = y[:, :, 0]。张量已经由NCHW转换为NHWC格式。


  1. https://tensorflow.google.cn/performance/performance_guide#use_nchw_imag ↩︎

  2. https://stackoverflow.com/questions/37689423/convert-between-nhwc-and-nchw-in-tensorflow ↩︎

  3. https://tensorflow.google.cn/versions/master/api_docs/python/tf/transpose ↩︎