1. 使用预训练的 Mask RCNN 模型进行预测

本文展示了如何使用预训练的 Mask RCNN 模型。

Mask RCNN 网络是 Faster RCNN 网络的扩展。gluoncv.model_zoo.MaskRCNN 继承自 gluoncv.model_zoo.FasterRCNN。强烈建议先阅读 02. 使用预训练的 Faster RCNN 模型进行预测

首先导入一些必要的库

from matplotlib import pyplot as plt
from gluoncv import model_zoo, data, utils

加载预训练模型

让我们获取一个使用 ResNet-50 主干并在 COCO 数据集上训练的 Mask RCNN 模型。通过指定 pretrained=True,如有必要,它将自动从模型库下载模型。有关更多预训练模型,请参阅 模型库

返回的模型是一个 HybridBlock gluoncv.model_zoo.MaskRCNN,其默认上下文为 cpu(0)

net = model_zoo.get_model('mask_rcnn_resnet50_v1b_coco', pretrained=True)

预处理图像

预处理步骤与 Faster RCNN 相同。

接下来我们下载一张图片,并使用预设的数据转换进行预处理。默认行为是将图像的短边调整大小到 600px。但您可以输入任意大小的图像。

如果您想一起加载多张图片,可以向 gluoncv.data.transforms.presets.rcnn.load_test() 提供一个图像文件名列表,例如 [im_fname1, im_fname2, ...]

此函数返回两个结果。第一个是形状为 (batch_size, RGB_channels, height, width) 的 NDArray。它可以直接输入模型。第二个结果包含 numpy 格式的图像,便于绘制。由于我们只加载了一张图像,x 的第一维是 1。

请注意,orig_img 的短边已调整大小到 600px。

im_fname = utils.download('https://github.com/dmlc/web-data/blob/master/' +
                          'gluoncv/detection/biking.jpg?raw=true',
                          path='biking.jpg')
x, orig_img = data.transforms.presets.rcnn.load_test(im_fname)

输出

Downloading biking.jpg from https://github.com/dmlc/web-data/blob/master/gluoncv/detection/biking.jpg?raw=true...

  0%|          | 0/244 [00:00<?, ?KB/s]
100%|##########| 244/244 [00:00<00:00, 44614.42KB/s]

推理和显示

Mask RCNN 模型返回预测的类别 ID、置信度得分、边界框坐标和分割掩码。它们的形状分别为 (batch_size, num_bboxes, 1)、(batch_size, num_bboxes, 1)、(batch_size, num_bboxes, 4) 和 (batch_size, num_bboxes, mask_size, mask_size)。对于本教程中使用的模型,mask_size 是 14。

目标检测结果

我们可以使用 gluoncv.utils.viz.plot_bbox() 来可视化结果。我们将第一个图像的结果进行切片并输入到 plot_bbox 中。

绘制分割掩码

gluoncv.utils.viz.expand_mask() 将调整分割掩码的大小并填充原始图像中的边界框尺寸。gluoncv.utils.viz.plot_mask() 将修改图像以叠加分割掩码。

ids, scores, bboxes, masks = [xx[0].asnumpy() for xx in net(x)]

# paint segmentation mask on images directly
width, height = orig_img.shape[1], orig_img.shape[0]
masks, _ = utils.viz.expand_mask(masks, bboxes, (width, height), scores)
orig_img = utils.viz.plot_mask(orig_img, masks)

# identical to Faster RCNN object detection
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1)
ax = utils.viz.plot_bbox(orig_img, bboxes, scores, ids,
                         class_names=net.classes, ax=ax)
plt.show()
demo mask rcnn

脚本总运行时间: ( 0 分 9.145 秒)

由 Sphinx-Gallery 生成的图库