探物 AI > 核心导读: 在视觉深度学习中,数据决定了模型上限。对于算力有限或样本稀缺的开发者,数据增强(Data Augmentation)就是性价比最高的“炼丹秘籍”,很多模型准确度的提高就是靠增加数据,增加数据的普适性,因此本文统计数据增强算法,进行汇总,并进行效果演示。
一、 数据增强的四个维度
1. 几何增强和像素增强
这是最经典的策略,旨在告诉模型:物体在不同光照、不同角度下还是同一个东西。
- 几何变换:水平/垂直翻转、旋转、随机裁剪(Crop)、缩放。
# ─── 1. 几何变换 ──── def aug_geometric(img: np.ndarray) -> dict: h, w = img.shape[:2] results = {} # 水平翻转 results["HorizontalFlip"] = cv2.flip(img, 1) # 垂直翻转 results["VerticalFlip"] = cv2.flip(img, 0) # 旋转 45° M = cv2.getRotationMatrix2D((w // 2, h // 2), 45, 1.0) results["Rotate45"] = cv2.warpAffine(img, M, (w, h)) # 随机裁剪(取中心 70%) cy, cx = int(h * 0.15), int(w * 0.15) crop = img[cy:h - cy, cx:w - cx] results["CenterCrop70%"] = cv2.resize(crop, (w, h)) # 缩放(先缩小再 resize 回来) small = cv2.resize(img, (int(w * 0.6), int(h * 0.6))) results["Scale0.6"] = cv2.resize(small, (w, h)) return results
几何变换:水平/垂直翻转、旋转、随机裁剪(Crop)、缩放。
# ─── 1. 几何变换 ────
def aug_geometric(img: np.ndarray) -> dict:
h, w = img.shape[:2]
results = {}
# 水平翻转
results["HorizontalFlip"] = cv2.flip(img, 1)
# 垂直翻转
results["VerticalFlip"] = cv2.flip(img, 0)
# 旋转 45°
M = cv2.getRotationMatrix2D((w // 2, h // 2), 45, 1.0)
results["Rotate45"] = cv2.warpAffine(img, M, (w, h))
# 随机裁剪(取中心 70%)
cy, cx = int(h * 0.15), int(w * 0.15)
crop = img[cy:h - cy, cx:w - cx]
results["CenterCrop70%"] = cv2.resize(crop, (w, h))
# 缩放(先缩小再 resize 回来)
small = cv2.resize(img, (int(w * 0.6), int(h * 0.6)))
results["Scale0.6"] = cv2.resize(small, (w, h))
return results
- 色彩抖动:调整亮度、对比度、饱和度及色相(Hue)。
# ─── 2. 色彩抖动 ────── def aug_color_jitter(img: np.ndarray) -> dict: results = {} img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32) # 亮度 +60 bright = img.astype(np.float32) bright = np.clip(bright + 60, 0, 255).astype(np.uint8) results["Brightness+60"] = bright # 对比度 ×1.8 contrast = np.clip(img.astype(np.float32) * 1.8, 0, 255).astype(np.uint8) results["Contrast×1.8"] = contrast # 饱和度 ×2(HSV 的 S 通道) sat = img_hsv.copy() sat[..., 1] = np.clip(sat[..., 1] * 2, 0, 255) results["Saturation×2"] = cv2.cvtColor(sat.astype(np.uint8), cv2.COLOR_HSV2RGB) # 色相偏移 +30°(HSV 的 H 通道,范围 0-179) hue = img_hsv.copy() hue[..., 0] = (hue[..., 0] + 30) % 180 results["HueShift+30"] = cv2.cvtColor(hue.astype(np.uint8), cv2.COLOR_HSV2RGB) return results
# ─── 2. 色彩抖动 ──────
def aug_color_jitter(img: np.ndarray) -> dict:
results = {}
img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32)
# 亮度 +60
bright = img.astype(np.float32)
bright = np.clip(bright + 60, 0, 255).astype(np.uint8)
results["Brightness+60"] = bright
# 对比度 ×1.8
contrast = np.clip(img.astype(np.float32) * 1.8, 0, 255).astype(np.uint8)
results["Contrast×1.8"] = contrast
# 饱和度 ×2(HSV 的 S 通道)
sat = img_hsv.copy()
sat[..., 1] = np.clip(sat[..., 1] * 2, 0, 255)
results["Saturation×2"] = cv2.cvtColor(sat.astype(np.uint8), cv2.COLOR_HSV2RGB)
# 色相偏移 +30°(HSV 的 H 通道,范围 0-179)
hue = img_hsv.copy()
hue[..., 0] = (hue[..., 0] + 30) % 180
results["HueShift+30"] = cv2.cvtColor(hue.astype(np.uint8), cv2.COLOR_HSV2RGB)
return results
- 噪声注入:高斯噪声、椒盐噪声(模拟传感器在极端环境下的干扰)。
# ─── 3. 噪声注入 ─────
def aug_noise(img: np.ndarray) -> dict:
results = {}
h, w = img.shape[:2]
# 高斯噪声
gauss = np.random.normal(0, 25, img.shape).astype(np.float32)
noisy = np.clip(img.astype(np.float32) + gauss, 0, 255).astype(np.uint8)
results["GaussianNoise"] = noisy
# 椒盐噪声(1% 像素)
sp = img.copy()
n_pixels = i
🔗 原文链接: 点击阅读原文
文章评论