注意
点击 here 下载完整的示例代码
1. 使用预训练的 Mask RCNN 模型进行预测¶
本文展示了如何使用预训练的 Mask RCNN 模型。
Mask RCNN 网络是 Faster RCNN 网络的扩展。gluoncv.model_zoo.MaskRCNN
继承自 gluoncv.model_zoo.FasterRCNN
。强烈建议先阅读 02. 使用预训练的 Faster RCNN 模型进行预测。
首先导入一些必要的库
from matplotlib import pyplot as plt
from gluoncv import model_zoo, data, utils
加载预训练模型¶
让我们获取一个使用 ResNet-50 主干并在 COCO 数据集上训练的 Mask RCNN 模型。通过指定 pretrained=True
,如有必要,它将自动从模型库下载模型。有关更多预训练模型,请参阅 模型库。
返回的模型是一个 HybridBlock gluoncv.model_zoo.MaskRCNN
,其默认上下文为 cpu(0)。
net = model_zoo.get_model('mask_rcnn_resnet50_v1b_coco', pretrained=True)
预处理图像¶
预处理步骤与 Faster RCNN 相同。
接下来我们下载一张图片,并使用预设的数据转换进行预处理。默认行为是将图像的短边调整大小到 600px。但您可以输入任意大小的图像。
如果您想一起加载多张图片,可以向 gluoncv.data.transforms.presets.rcnn.load_test()
提供一个图像文件名列表,例如 [im_fname1, im_fname2, ...]
。
此函数返回两个结果。第一个是形状为 (batch_size, RGB_channels, height, width) 的 NDArray。它可以直接输入模型。第二个结果包含 numpy 格式的图像,便于绘制。由于我们只加载了一张图像,x 的第一维是 1。
请注意,orig_img 的短边已调整大小到 600px。
im_fname = utils.download('https://github.com/dmlc/web-data/blob/master/' +
'gluoncv/detection/biking.jpg?raw=true',
path='biking.jpg')
x, orig_img = data.transforms.presets.rcnn.load_test(im_fname)
输出
Downloading biking.jpg from https://github.com/dmlc/web-data/blob/master/gluoncv/detection/biking.jpg?raw=true...
0%| | 0/244 [00:00<?, ?KB/s]
100%|##########| 244/244 [00:00<00:00, 44614.42KB/s]
推理和显示¶
Mask RCNN 模型返回预测的类别 ID、置信度得分、边界框坐标和分割掩码。它们的形状分别为 (batch_size, num_bboxes, 1)、(batch_size, num_bboxes, 1)、(batch_size, num_bboxes, 4) 和 (batch_size, num_bboxes, mask_size, mask_size)。对于本教程中使用的模型,mask_size 是 14。
目标检测结果
我们可以使用 gluoncv.utils.viz.plot_bbox()
来可视化结果。我们将第一个图像的结果进行切片并输入到 plot_bbox 中。
绘制分割掩码
gluoncv.utils.viz.expand_mask()
将调整分割掩码的大小并填充原始图像中的边界框尺寸。gluoncv.utils.viz.plot_mask()
将修改图像以叠加分割掩码。
ids, scores, bboxes, masks = [xx[0].asnumpy() for xx in net(x)]
# paint segmentation mask on images directly
width, height = orig_img.shape[1], orig_img.shape[0]
masks, _ = utils.viz.expand_mask(masks, bboxes, (width, height), scores)
orig_img = utils.viz.plot_mask(orig_img, masks)
# identical to Faster RCNN object detection
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1)
ax = utils.viz.plot_bbox(orig_img, bboxes, scores, ids,
class_names=net.classes, ax=ax)
plt.show()

脚本总运行时间: ( 0 分 9.145 秒)