PyTorch自定义数据增强

兼容TorchVision预处理管道

PyTorch自定义数据增强

TorchVision的 transform 模块预置了大量的图像数据增强功能,例如缩放、随机裁切、随机翻转等。对于某些特殊的数据集,可以使用尽可能少的代码实现数据增强。

policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies]
imgs = [
    [augmenter(orig_img) for _ in range(4)]
    for augmenter in augmenters
]
row_title = [str(policy).split('.')[-1] for policy in policies]
plot(imgs, row_title=row_title)

生成的图像如下:

TorchVision的数据增强

在某些特殊的情况下,我们需要按照项目需要来自定义数据增强过程,以下以随机降采样为例,介绍了自定义数据增强的实现方法。