05. 深入探究 SSD 训练:提升性能的 3 个技巧

在之前的教程 04. 在 Pascal VOC 数据集上训练 SSD 模型 中,我们简要介绍了有助于构建 SSD 训练管线的基本 API。

在本文中,我们将深入探讨细节并介绍对于复现当前最佳性能至关重要的技巧。这些是论文和技术报告中通常遗漏的隐藏陷阱。

损失归一化:使用批次归一化代替样本归一化

论文中提到的训练目标是定位损失 (loc) 和置信度损失 (conf) 的加权和。

\[L(x, c, l, g) = \frac{1}{N} (L_{conf}(x, c) + \alpha L_{loc}(x, l, g))\]

但问题是,计算 N 的正确方法是什么?我们应该在整个批次中累加 N,还是改用每个样本的 N

为了说明这一点,请生成一些模拟数据

import mxnet as mx
x = mx.random.uniform(shape=(2, 3, 300, 300))  # use batch-size 2
# suppose image 1 has single object
id1 = mx.nd.array([1])
bbox1 = mx.nd.array([[10, 20, 80, 90]])  # xmin, ymin, xmax, ymax
# suppose image 2 has 4 objects
id2 = mx.nd.array([1, 3, 5, 7])
bbox2 = mx.nd.array([[10, 10, 30, 30], [40, 40, 60, 60], [50, 50, 90, 90], [100, 110, 120, 140]])

然后,通过填充 -1 作为标记值将它们合并为一个批次

gt_ids = mx.nd.ones(shape=(2, 4)) * -1
gt_ids[0, :1] = id1
gt_ids[1, :4] = id2
print('class_ids:', gt_ids)

输出

class_ids:
[[ 1. -1. -1. -1.]
 [ 1.  3.  5.  7.]]
<NDArray 2x4 @cpu(0)>
gt_boxes = mx.nd.ones(shape=(2, 4, 4)) * -1
gt_boxes[0, :1, :] = bbox1
gt_boxes[1, :, :] = bbox2
print('bounding boxes:', gt_boxes)

输出

bounding boxes:
[[[ 10.  20.  80.  90.]
  [ -1.  -1.  -1.  -1.]
  [ -1.  -1.  -1.  -1.]
  [ -1.  -1.  -1.  -1.]]

 [[ 10.  10.  30.  30.]
  [ 40.  40.  60.  60.]
  [ 50.  50.  90.  90.]
  [100. 110. 120. 140.]]]
<NDArray 2x4x4 @cpu(0)>

在本例中,我们使用 vgg16 atrous 300x300 SSD 模型。出于演示目的,我们此处不使用任何预训练权重

from gluoncv import model_zoo
net = model_zoo.get_model('ssd_300_vgg16_atrous_voc', pretrained_base=False, pretrained=False)

训练前的一些准备工作

from mxnet import gluon
net.initialize()
conf_loss = gluon.loss.SoftmaxCrossEntropyLoss()
loc_loss = gluon.loss.HuberLoss()

通过手动计算损失来模拟训练步骤:你可以随时使用 gluoncv.loss.SSDMultiBoxLoss,它实现了此功能。

from mxnet import autograd
from gluoncv.model_zoo.ssd.target import SSDTargetGenerator
target_generator = SSDTargetGenerator()
with autograd.record():
    # 1. forward pass
    cls_preds, box_preds, anchors = net(x)
    # 2. generate training targets
    cls_targets, box_targets, box_masks = target_generator(
        anchors, cls_preds, gt_boxes, gt_ids)
    num_positive = (cls_targets > 0).sum().asscalar()
    cls_mask = (cls_targets >= 0).expand_dims(axis=-1)  # negative targets should be ignored in loss
    # 3 losses, here we have two options, batch-wise or sample wise norm
    # 3.1 batch wise normalization: divide loss by the summation of num positive targets in batch
    batch_conf_loss = conf_loss(cls_preds, cls_targets, cls_mask) / num_positive
    batch_loc_loss = loc_loss(box_preds, box_targets, box_masks) / num_positive
    # 3.2 sample wise normalization: divide by num positive targets in this sample(image)
    sample_num_positive = (cls_targets > 0).sum(axis=0, exclude=True)
    sample_conf_loss = conf_loss(cls_preds, cls_targets, cls_mask) / sample_num_positive
    sample_loc_loss = loc_loss(box_preds, box_targets, box_masks) / sample_num_positive
    # Since ``conf_loss`` and ``loc_loss`` calculate the mean of such loss, we want
    # to rescale it back to loss per image.
    rescale_conf = cls_preds.size / cls_preds.shape[0]
    rescale_loc = box_preds.size / box_preds.shape[0]
    # then call backward and step, to update the weights, etc..
    # L = conf_loss + loc_loss * alpha
    # L.backward()

范数不同,但样本级范数加起来与批次级范数相同

print('batch-wise num_positive:', num_positive)
print('sample-wise num_positive:', sample_num_positive)

输出

batch-wise num_positive: 36.0
sample-wise num_positive:
[13. 23.]
<NDArray 2 @cpu(0)>

注意

每张图像的 num_positive 不再是 1 和 4,因为多个锚框可以匹配到单个对象

比较损失

print('batch-wise norm conf loss:', batch_conf_loss * rescale_conf)
print('sample-wise norm conf loss:', sample_conf_loss * rescale_conf)

输出

batch-wise norm conf loss:
[442.7147 675.863 ]
<NDArray 2 @cpu(0)>
sample-wise norm conf loss:
[1225.9791 1057.8724]
<NDArray 2 @cpu(0)>
print('batch-wise norm loc loss:', batch_loc_loss * rescale_loc)
print('sample-wise norm loc loss:', sample_loc_loss * rescale_loc)

输出

batch-wise norm loc loss:
[2.656074  2.1453514]
<NDArray 2 @cpu(0)>
sample-wise norm loc loss:
[7.3552823 3.3579414]
<NDArray 2 @cpu(0)>

哪种更好?乍一看,很难说哪种在理论上更好,因为批次归一化确保损失通过全局统计量得到良好归一化,而样本归一化则确保在某些极端情况下(一张图像中有数百个对象时)梯度不会爆炸。在这种情况下,同一批次中的其他样本可能会被这种异常大的范数抑制。

在我们的实验中,批次归一化在 Pascal VOC 数据集上始终表现更好,贡献了 1~2% 的 mAP 增益。但是,当你使用新的数据集或新的模型时,一定要尝试这两种方法。

初始化器很重要:不要只使用一种初始化器

虽然 SSD 网络基于预训练的特征提取器(称为 base_network),但我们也会在 base_network 后追加未初始化的卷积层,以扩展特征图的级联。

每个输出特征图后也追加了卷积预测器,用作类别预测器和边界框偏移预测器。

对于这些附加的层,我们必须在训练前初始化它们。

from gluoncv import model_zoo
import mxnet as mx
# don't load pretrained for this demo
net = model_zoo.get_model('ssd_300_vgg16_atrous_voc', pretrained=False, pretrained_base=False)
# random init
net.initialize()
# gluon only infer shape when real input data is used
net(mx.nd.zeros(shape=(1, 3, 300, 300)))
# now we have real shape for each parameter
predictors = [(k, v) for k, v in net.collect_params().items() if 'predictor' in k]
name, pred = predictors[0]
print(name, pred)

输出

ssd3_convpredictor0_conv0_weight Parameter ssd3_convpredictor0_conv0_weight (shape=(84, 512, 3, 3), dtype=<class 'numpy.float32'>)

我们可以使用不同的初始化器来初始化它,例如 NormalXavier

pred.initialize(mx.init.Uniform(), force_reinit=True)
print('param shape:', pred.data().shape, 'peek first 20 elem:', pred.data().reshape((-1))[:20])

输出

param shape: (84, 512, 3, 3) peek first 20 elem:
[-0.04006358  0.04752301 -0.04936712  0.02708755 -0.06145268 -0.0103094
  0.04445995  0.02895925 -0.01508887 -0.04410328 -0.05917829  0.00261795
  0.02758304  0.02611597  0.06757144  0.03305504  0.01971556 -0.05105315
 -0.03926021  0.04332945]
<NDArray 20 @cpu(0)>

仅仅从 Uniform 切换到 Xavier 就可以带来约 1% 的 mAP 增益。

pred.initialize(mx.init.Xavier(rnd_type='gaussian', magnitude=2, factor_type='out'), force_reinit=True)
print('param shape:', pred.data().shape, 'peek first 20 elem:', pred.data().reshape((-1))[:20])

输出

param shape: (84, 512, 3, 3) peek first 20 elem:
[ 0.05409709 -0.02777563 -0.05862886  0.0120097  -0.05354748  0.03673649
 -0.01118423 -0.00505917 -0.07389503 -0.05523501 -0.05710729  0.05084738
 -0.04024388 -0.06320304  0.00896897  0.09223884 -0.05637952 -0.00855709
 -0.11271537 -0.01174088]
<NDArray 20 @cpu(0)>

解释置信度分数:单独处理每个类别

如果我们回顾每个类别的置信度预测,其形状为 (B, A, N+1),其中 B 是批次大小,A 是锚框数量,N 是前景类别数量。

print('class prediction shape:', cls_preds.shape)

输出

class prediction shape: (2, 8732, 21)

有两种处理预测的方法

1. 沿着类别轴取预测值的 argmax。这样,只考虑最可能的类别。

2. 单独处理 N 个前景类别。这样,例如,次可能的类别仍然有机会作为最终预测保留下来。

考虑这个例子

cls_pred = mx.nd.array([-1, -2, 3, 4, 6.5, 6.4])
cls_prob = mx.nd.softmax(cls_pred, axis=-1)
for k, v in zip(['bg', 'apple', 'orange', 'person', 'dog', 'cat'], cls_prob.asnumpy().tolist()):
    print(k, v)

输出

bg 0.00027409225003793836
apple 0.00010083290544571355
orange 0.014964930713176727
person 0.040678903460502625
dog 0.49557045102119446
cat 0.4484107196331024

狗和猫的概率非常接近,如果我们使用方法 1,当猫是正确决定时,我们很可能会失败。

结果表明,通过从方法 1 切换到方法 2,我们在评估中获得了 0.5~0.8 的 mAP 增益。

方法 2 的一个明显缺点是它比方法 1 慢得多。对于 N 个类别,方法 2 的复杂度是 O(N),而方法 1 始终是 O(1)。这可能是一个问题,也可能不是问题,具体取决于用例,但如果你想的话,可以随时在它们之间切换。

提示

分别查看 gluoncv.nn.coder.MultiClassDecoder()gluoncv.nn.coder.MultiPerClassDecoder() 以了解方法 1 和方法 2 的实现。

脚本总运行时间: ( 0 分钟 1.410 秒)

由 Sphinx-Gallery 生成的图库