常用的数据增广方法
1. 对图片进行按比例缩放
2. 对图片进行随机位置的截取
3. 对图片进行随机的水平和竖直翻转
4. 对图片进行随机角度的旋转
5. 对图片进行亮度、对比度和颜色的随机变化
下面使用torchvision演示一下这些数据增强方法。
1 2 3
| import sys from PIL import Image from torchvision import transforms
|
1 2 3
| img = Image.open('img.jpg') img
|

随机比例缩放
随机比例缩放使用的是torchvision.transforms.Resize()
函数,函数有两个参数,第一个参数为缩放大小,如果为一个值则会按比例缩放,否则按传入的值缩放;第二个参数表示缩放图片使用的方法,默认的是双线性差值。
1 2 3 4 5 6 7 8
| print('缩放前尺寸为:{}'.format(img.size)) new_img = transforms.Resize(224)(img) print('缩放后尺寸为:{}'.format(new_img.size)) new_img
缩放前尺寸为:(134, 43) 缩放后尺寸为:(698, 224)
|

1 2
| new_img = transforms.RandomCrop(224, padding=8)(new_img) new_img
|

1 2 3 4 5
| new_img = transforms.Resize((224, 224))(img) print('缩放后尺寸为:{}'.format(new_img.size)) new_img
缩放后尺寸为:(224, 224)
|

随机位置截取
随机位置截取能够提取图片中的局部信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果,在torchvision中主要有以下两种方式,一个是torchvision.transforms.RandomCrop()
,传入的参数是截取出图片的长和宽,在图片的随机位置进行截取;第二个是torchvision.transforms.CenterCrop()
,同样传入图片的长和宽,会在图片的中心进行截取。
1 2 3
| random_img = transforms.RandomCrop(100)(new_img) random_img
|

1 2 3
| center_img = transforms.CenterCrop(100)(new_img) center_img
|

1 2 3
| random_img2 = transforms.RandomCrop(224, padding=8)(new_img) random_img2
|

随机水平翻转和竖直翻转
torchvision.transforms.RandomHorizontalFlip()
和torchvision.transforms.RandomVerticalFlip()
1 2 3
| h_flip = transforms.RandomHorizontalFlip()(new_img) h_flip
|

1 2 3
| v_flip = transforms.RandomVerticalFlip()(new_img) v_flip
|

随机角度旋转
torchvision.transforms.RandomRotation()
1 2
| rot_im = transforms.RandomRotation(30)(new_img) rot_im
|

亮度、对比度和颜色变化
torchvision.transforms.ColorJitter()
函数有四个参数,第一个参数为亮度,第二个参数为对比度,第三个参数为饱和度,第四个参数为颜色
1 2 3
| bright_img = transforms.ColorJitter(brightness=1)(new_img) bright_img
|

1 2 3
| contrast_img = transforms.ColorJitter(contrast=1)(new_img) contrast_img
|

1 2 3
| saturation_img = transforms.ColorJitter(saturation=1)(new_img) saturation_img
|

1 2 3
| color_img = transforms.ColorJitter(hue=0.5)(new_img) color_img
|

1 2 3
| compose_img = transforms.ColorJitter(0.5, 0.5, 0.5)(new_img) compose_img new_img
|

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| import matplotlib.pyplot as plt %matplotlib inline
img_transform = transforms.Compose([ transforms.Resize(232), transforms.RandomCrop(224), transforms.ColorJitter(0.15, 0.15, 0.15) ])
nrows = 5 ncols = 5 figsize = (10, 10)
_, figs = plt.subplots(nrows, ncols, figsize = figsize) for i in range(nrows): for j in range(ncols): figs[i][j].imshow(img_transform(new_img)) figs[i][j].axes.get_xaxis().set_visible(False) figs[i][j].axes.get_yaxis().set_visible(False) plt.show()
|
