PyTorch自定义数据增强
兼容TorchVision预处理管道
TorchVision的 transform
模块预置了大量的图像数据增强功能,例如缩放、随机裁切、随机翻转等。对于某些特殊的数据集,可以使用尽可能少的代码实现数据增强。
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies]
imgs = [
[augmenter(orig_img) for _ in range(4)]
for augmenter in augmenters
]
row_title = [str(policy).split('.')[-1] for policy in policies]
plot(imgs, row_title=row_title)
生成的图像如下:
在某些特殊的情况下,我们需要按照项目需要来自定义数据增强过程,以下以随机降采样为例,介绍了自定义数据增强的实现方法。
实现自定义数据增强,可以通过继承 torch.nn.Module
类来实现。
import torch
import torchvision.transforms.functional as TF
class RandomDownSample(torch.nn.Module):
"""Randomly down-sample the images.
Args:
min_height: 随机下采样时的最小尺寸。
max_height: 随机下采样时的最大尺寸。
"""
然后,为该类实现两个方法 __init__
与 __call__
。
__init__
方法可以用来初始化我们的降采样参数,例如指定降采样时允许的最小与最大尺寸。
def __init__(self, min_height, max_height):
self.min_height = min_height
self.max_height = max_height
__call__
方法则负责具体的计算过程。这里分两步:第一步使用随机数生成器随机生成图像高度,然后通过调用TorchVision的 resize
函数将目标函数缩放为指定大小。
def __call__(self, x):
height = int(torch.randint(self.min_height, self.max_height, (1, )))
return TF.resize(x, height)
最终的实现看起来是这个样子:
class RandomDownSample(torch.nn.Module):
"""Randomly down-sample the images.
Args:
min_height: 随机下采样时的最小尺寸。
max_height: 随机下采样时的最大尺寸。
"""
def __init__(self, min_height, max_height):
self.min_height = min_height
self.max_height = max_height
def __call__(self, x):
height = int(torch.randint(self.min_height, self.max_height, (1, )))
return TF.resize(x, height)
实现完成后,即可以通过 compose
来将预处理串联使用了。
from torchvision import transforms
transform_train = transforms.Compose([
RandomDownSample(16, 64),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=[0.5], std=[0.5])
])
Comments ()