10. Skip Fine-tuning by Reusing Part of a Pre-trained Model
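Detection models pre-trained on public datasets usually need fine-tuning before they can be applied to a domain of interest, and fine-tuning is itself a hard problem. This tutorial shows an interesting way to reuse a pre-trained model instead.
In short, you can take a GluonCV pre-trained detection model, reset its classes to a subset of the COCO categories, and use it right away without any fine-tuning.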
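To see why a class subset can work without retraining, consider a minimal sketch of a linear classification head with one weight row per class. This is a conceptual illustration in plain Python, not GluonCV's implementation; the helper names `class_scores` and `keep_classes`, the four-class list, and all numbers are made up for the example.

```python
# Conceptual sketch only: a toy linear classification head, not GluonCV internals.

def class_scores(weights, biases, features):
    """Return one raw score per class for a single feature vector."""
    return [
        sum(w * f for w, f in zip(row, features)) + b
        for row, b in zip(weights, biases)
    ]


def keep_classes(classes, weights, biases, wanted):
    """Build a smaller head by copying the rows of the wanted classes."""
    # list.index raises ValueError if a wanted class is not in the old list.
    idx = [classes.index(name) for name in wanted]
    return [weights[i] for i in idx], [biases[i] for i in idx]


# Hypothetical 4-class head over 3-dimensional features.
classes = ['person', 'bicycle', 'car', 'backpack']
weights = [
    [0.2, -0.1, 0.5],
    [0.7, 0.3, -0.2],
    [-0.4, 0.6, 0.1],
    [0.05, 0.9, -0.3],
]
biases = [0.1, -0.2, 0.0, 0.3]
features = [1.0, 2.0, 3.0]

full = class_scores(weights, biases, features)
sub_w, sub_b = keep_classes(classes, weights, biases, ['bicycle', 'backpack'])
subset = class_scores(sub_w, sub_b, features)
```

In this toy setting the raw scores of the kept classes are exactly the ones the full head produced, because their weight rows are copied unchanged. Any normalization across classes (for example a softmax) would be recomputed over the smaller set.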
First, import the necessary libraries
from matplotlib import pyplot as plt
import gluoncv
from gluoncv import model_zoo, data, utils
Load a pre-trained model
Let's get a Faster RCNN model trained on the COCO dataset with a ResNet-50 backbone.
net = model_zoo.get_model('faster_rcnn_resnet50_v1b_coco', pretrained=True)
Output
Downloading /root/.mxnet/models/faster_rcnn_resnet50_v1b_coco-5b4690fb.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/faster_rcnn_resnet50_v1b_coco-5b4690fb.zip...
Pre-process the image
As in the Faster RCNN inference tutorial, we download and pre-process a sample image
im_fname = utils.download('https://github.com/dmlc/web-data/blob/master/' +
                          'gluoncv/detection/biking.jpg?raw=true',
                          path='biking.jpg')
x, orig_img = data.transforms.presets.rcnn.load_test(im_fname)
Reset the classes to exactly what we want
The original COCO model has 80 classes
print('coco classes: ', net.classes)
net.reset_class(classes=['bicycle', 'backpack'], reuse_weights=['bicycle', 'backpack'])
# now net has 2 classes as desired
print('new classes: ', net.classes)
Output
coco classes: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
new classes: ['bicycle', 'backpack']
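Reuse old weights with a more flexible mapping
reuse_weights also accepts a dict for one-to-one remapping of class weights, where each key is a new class name and each value is the existing class whose weights it reuses. Here a new 'spaceship' class takes over the weights of 'bicycle'.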
net = model_zoo.get_model('faster_rcnn_resnet50_v1b_coco', pretrained=True)
net.reset_class(classes=['spaceship'], reuse_weights={'spaceship':'bicycle'})
box_ids, scores, bboxes = net(x)
ax = utils.viz.plot_bbox(orig_img, bboxes[0], scores[0], box_ids[0], class_names=net.classes)
plt.show()
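The list form and the dict form of reuse_weights can be thought of as two spellings of the same index mapping. The sketch below is a conceptual illustration in plain Python and not GluonCV's code; `build_reuse_map` is a hypothetical helper, and the short `old` list is only the first three COCO class names.

```python
# Conceptual sketch only: a hypothetical helper, not part of GluonCV.

def build_reuse_map(old_classes, new_classes, reuse_weights):
    """Map each new class index to the old class index whose weights it reuses."""
    if isinstance(reuse_weights, dict):
        # Dict form: new class name -> old class name.
        pairs = list(reuse_weights.items())
    else:
        # List form: each name exists in both the old and the new class list.
        pairs = [(name, name) for name in reuse_weights]

    mapping = {}
    for new_name, old_name in pairs:
        if new_name not in new_classes:
            raise ValueError('%s is not one of the new classes' % new_name)
        if old_name not in old_classes:
            raise ValueError('%s is not one of the old classes' % old_name)
        mapping[new_classes.index(new_name)] = old_classes.index(old_name)
    return mapping


old = ['person', 'bicycle', 'car']  # first three COCO classes, for brevity

# Dict form: the new 'spaceship' class borrows the 'bicycle' weights.
dict_map = build_reuse_map(old, ['spaceship'], {'spaceship': 'bicycle'})

# List form: 'bicycle' keeps its own weights under the same name.
list_map = build_reuse_map(old, ['bicycle'], ['bicycle'])
```

Both calls resolve to the same mapping from new index 0 to old index 1. The dict form only changes the label attached to the reused weights, so a 'spaceship' detector built this way would still respond to whatever the 'bicycle' weights respond to.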

The same approach works for other models
We can apply this strategy to SSD, YOLO and Mask RCNN models. Here we use Mask RCNN and reset its classes to detect people only.
net = model_zoo.get_model('mask_rcnn_resnet50_v1b_coco', pretrained=True)
net.reset_class(classes=['person'], reuse_weights=['person'])
ids, scores, bboxes, masks = [xx[0].asnumpy() for xx in net(x)]
# paint segmentation mask on images directly
width, height = orig_img.shape[1], orig_img.shape[0]
masks, _ = utils.viz.expand_mask(masks, bboxes, (width, height), scores)
orig_img = utils.viz.plot_mask(orig_img, masks)
# identical to Faster RCNN object detection
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1)
ax = utils.viz.plot_bbox(orig_img, bboxes, scores, ids,
                         class_names=net.classes, ax=ax)
plt.show()

Output
Downloading /root/.mxnet/models/mask_rcnn_resnet50_v1b_coco-a3527fdc.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/mask_rcnn_resnet50_v1b_coco-a3527fdc.zip...