将TensorFlow的变量格式从NCHW转换为NHWC
将一堆二维张量拼接成三维张量的时候,默认的Chanel维度在首位;然而在TensorFlow中张量的默认Channel维度在末尾。因此有时需要将变量模式从NCHW转换为NHWC以匹配格式。
将一堆二维张量拼接成三维张量的时候,默认的Chanel维度在首位;然而在TensorFlow中张量的默认Channel维度在末尾。因此有时需要将变量模式从NCHW转换为NHWC以匹配格式。
根据TensorFlow官方文档,N、C、H、W代表意义如下[1]:
- N:一个batch内图片的数量。
- H:垂直高度方向的像素个数。
- W:水平宽度方向的像素个数。
- C:通道数。例如灰度图像为1, 彩色RGB图像为3。
假设张量x如下:
x = [[[ 1, 2],
[ 3, 4]],
[[ 11, 22],
[ 33, 44]],
[[111, 222],
[333, 444]]]
此时x的shape为(3, 2, 2),通道数为3。取第一个维度首元素如下:
>>> x[0, :, :]
<tf.Tensor: id=7, shape=(2, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4]], dtype=int32)>
取最后一个维度首元素如下:
>>> x[:,:,0]
<tf.Tensor: id=12, shape=(3, 2), dtype=int32, numpy=
array([[ 1, 3],
[ 11, 33],
[111, 333]], dtype=int32)>
按照StackOverflow的解答[2],要将通道置于维度的尾部,可以使用TensorFlow中的tf.transpose
函数实现[3]。
y = tf.transpose(x, [1, 2, 0])
其中第二个参数是转换后的张量中,原始张量的维度编号。编号0原本在首位,现在处于末位。
转换后的张量y如下:
<tf.Tensor: id=15, shape=(2, 2, 3), dtype=int32, numpy=
array([[[ 1, 11, 111],
[ 2, 22, 222]],
[[ 3, 33, 333],
[ 4, 44, 444]]], dtype=int32)>
取y第一个维度首元素如下:
>>> y[0]
<tf.Tensor: id=20, shape=(2, 3), dtype=int32, numpy=
array([[ 1, 11, 111],
[ 2, 22, 222]], dtype=int32)>
取y最后一个维度首元素如下:
>>> y[:,:,0]
<tf.Tensor: id=25, shape=(2, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4]], dtype=int32)>
可见x[0, :, :] = y[:, :, 0]。张量已经由NCHW转换为NHWC格式。
Comments ()