在介绍如何贡献 Transform 数据增强之前,建议先阅读 MMCV 的 ,本文档只是对教程的简化和总结。
实现一个简单的自定义数据变换
要实现自定义的数据变换,通常需要以下步骤:
定义一个数据增强类,并将其注册到 mmcv
的 TRANSFORMS
注册器中
实现 transform 函数,并让其接受字典类型的输入,返回字典类型的输出
以实现翻数据增强为例:
import random
import mmcv
from mmcv.transforms import BaseTransform, TRANSFORMS
@TRANSFORMS.register_module() # 1. 注册
class MyFlip(BaseTransform): # 2. 继承 BaseTransform
def __init__(self, direction: str):
super().__init__()
self.direction = direction
def transform(self, results: dict) -> dict: # 接受字典类型输入
img = results['img']
results['img'] = mmcv.imflip(img, direction=self.direction) # 接受字典类型输出
return results
至此,一个简单的数据增强就已经完成了
常用的数据增强工具
OpenMMLab 系列算法库实现了非常多的 Transforms,在实现过程中我们发现很多数据变化都存在共性,因此针对这些共性 MMCV 提供了一系列的数据增强工具,方便大家复用代码,快速实现自定义的数据增强。这些工具不仅可以提高数据数据增强的开发效率,也能够通过组合的方式,基于现有的数据增强组合出一些新的数据增强。因此在实现新的数据增强时,我们需要考虑是否有必要,新的需求能否通过组合已有的数据增强来实现。
Compose
pipeline = [
dict(type='LoadImageFromFile', key='gt_img'),
dict(type='Compose', transforms=[
dict(type='RandomCrop', crop_size=(384, 384)),
dict(type='Normalize'),
])
]
虽说看上去 Compose
和在列表中扩展数据增强的效果一样,但是当 Compose
和 TTA 组合时,会有意想不到的效果
KeyMapper
# config
pipeline = [
dict(type='LoadImageFromFile', key='gt_img'),
# 使用 KeyMapper 将外部(原始)字段 'gt_img' 映射到内部字段 'img'
dict(type='KeyMapper',
mapping=dict(img='gt_img'), # 定义输入时的字段映射
auto_remap=True, # 等价于 remapping=dict(img='gt_img')
allow_nonexist_keys=True, # 即允许 results 中不包含 `gt_img` 字段
transforms=[
# Transform 实现中使用标准字段 'img' 即可
dict(type='Crop', crop_size=(384, 384), random_crop=True),
dict(type='Normalize'),
])
]
上面这段配置就能实现将 gt_img
字段映射到 img
,并且在 Crop
和 Normalize
时使用 img
字段,并最后将 img
字段映射回 gt_img
输出到 results
中。
pipeline = [
dict(type='LoadImageFromFile', key='lq'), # low-quality image
dict(type='LoadImageFromFile', key='gt'), # ground-truth image
# 使用 TransformBroadcaster,将多个外部字段 ('lq' 和 'gt')依次映射到内部字段
# 'img',并用 wrapped transforms 依次处理
dict(type='TransformBroadcaster',
mapping=dict(img=['lq', 'gt']), # 情况 1: 来自多个字段
# input_mapping=dict(img='images'), #情况 2: 来自一个包含多个数据的字段
auto_remap=True,
share_random_param=True, # 在处理多个数据字段时,使用一组相同的随机参数
transforms=[
# Transform 实现中使用标准字段 'img' 即可
dict(type='RandomCrop', crop_size=(384, 384)),
dict(type='Normalize'),
])
]
import numpy as np
from mmcv.transforms.utils import cache_randomness
@TRANSFORMS.register_module() # 1. 注册
class MyFlip(BaseTransform): # 2. 继承 BaseTransform
def __init__(self, direction: str):
super().__init__()
self.direction = direction
@cache_randomness
def _should_flip(self):
return np.random.random() > 0.5
def transform(self, results: dict) -> dict: # 接受字典类型输入
if self._should_flip():
img = results['img']
results['img'] = mmcv.imflip(img, direction=self.direction) # 接受字典类型输出
return results
再强调一遍,要保证自定义数据增强在 TransformBroadcaster
中,能够使用相同的随机参数处理不同字段的数据,需要满足两个条件:
使用 cache_randomness
装饰器装饰一个或多个随机方法
RandomApply 和 RandomApply
为了为数据增强引入更强的随机性,MMCV 实现了:
# 使用 RandomApply 在 2 个 sub-pipeline 中随机选择
pipeline = [
...
dict(type='RandomChoice',
transforms=[
[dict(type='RandomHorizontalFlip')], # sub-pipeline 1
[dict(type='RandomRotate')], # sub-pipeline 2
]
),
...
]
# 使用 RandomApply 随机执行一个 sub-pipeline
pipeline = [
...
dict(type='RandomApply',
transforms=dict(type='RandomHorizontalFlip'),
prob=0.3),
...
]
测试时数据增强
dict(type='TestTimeAug',
transforms=[
[dict(type='Resize', scale=(1333, 400), keep_ratio=True),
dict(type='Resize', scale=(1333, 800), keep_ratio=True)],
[dict(type='RandomFlip', prob=1.),
dict(type='RandomFlip', prob=0.)],
[[dict(type='CenterCrop', crop_size=100), dict(type='RandomRotate', crop_size=100)],
[dict(type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape',
'img_shape', 'scale_factor', 'flip',
'flip_direction'))]])
可以对同一张图片进行 8 次增强:
flowchart TD;
ori_image --> Resize1["Resize1(scale=(1333, 400))"];
ori_image --> Resize2["Resize1(scale=(1333, 800))"];
Resize1 --> RandomFlip11["RandomFlip1(prob=0)"];
Resize1 --> RandomFlip12["RandomFlip1(prob=1)"];
Resize2 --> RandomFlip21["RandomFlip1(prob=0)"];
Resize2 --> RandomFlip22["RandomFlip1(prob=1)"];
RandomFlip11 --> CenterCrop111["CenterCrop111(crop_size=100)"];
RandomFlip11 --> CenterCrop112["CenterCrop112(crop_size=100)"];
RandomFlip12 --> CenterCrop121["CenterCrop121(crop_size=100)"];
RandomFlip12 --> CenterCrop122["CenterCrop122(crop_size=100)"];
RandomFlip21 --> CenterCrop211["CenterCrop211(crop_size=100)"];
RandomFlip21 --> CenterCrop212["CenterCrop212(crop_size=100)"];
RandomFlip22 --> CenterCrop221["CenterCrop221(crop_size=100)"];
RandomFlip22 --> CenterCrop222["CenterCrop222(crop_size=100)"];
用户可通过堆叠不同的增强策略以实现指数级别的数据增强
PR 参考