PyTorch自定义数据增强

兼容TorchVision预处理管道

PyTorch自定义数据增强

TorchVision的 transform 模块预置了大量的图像数据增强功能,例如缩放、随机裁切、随机翻转等。对于某些特殊的数据集,可以使用尽可能少的代码实现数据增强。

policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies]
imgs = [
    [augmenter(orig_img) for _ in range(4)]
    for augmenter in augmenters
]
row_title = [str(policy).split('.')[-1] for policy in policies]
plot(imgs, row_title=row_title)

生成的图像如下:

TorchVision的数据增强

在某些特殊的情况下,我们需要按照项目需要来自定义数据增强过程,以下以随机降采样为例,介绍了自定义数据增强的实现方法。

实现自定义数据增强,可以通过继承 torch.nn.Module 类来实现。

import torch
import torchvision.transforms.functional as TF

class RandomDownSample(torch.nn.Module):
    """Randomly down-sample the images.

    Args:
        min_height: 随机下采样时的最小尺寸。
        max_height: 随机下采样时的最大尺寸。
    """

然后,为该类实现两个方法 __init____call__

__init__ 方法可以用来初始化我们的降采样参数,例如指定降采样时允许的最小与最大尺寸。

    def __init__(self, min_height, max_height):
        self.min_height = min_height
        self.max_height = max_height

__call__ 方法则负责具体的计算过程。这里分两步:第一步使用随机数生成器随机生成图像高度,然后通过调用TorchVision的 resize 函数将目标函数缩放为指定大小。

    def __call__(self, x):
        height = int(torch.randint(self.min_height, self.max_height, (1, )))
        return TF.resize(x, height)

最终的实现看起来是这个样子:

class RandomDownSample(torch.nn.Module):
    """Randomly down-sample the images.

    Args:
        min_height: 随机下采样时的最小尺寸。
        max_height: 随机下采样时的最大尺寸。
    """

    def __init__(self, min_height, max_height):
        self.min_height = min_height
        self.max_height = max_height

    def __call__(self, x):
        height = int(torch.randint(self.min_height, self.max_height, (1, )))
        return TF.resize(x, height)

实现完成后,即可以通过 compose 来将预处理串联使用了。

from torchvision import transforms
transform_train = transforms.Compose([
        RandomDownSample(16, 64),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])