11. 使用预训练的 CenterNet 模型进行预测

本文展示了如何仅使用几行代码来使用预训练的 CenterNet 模型。

首先,让我们导入一些必要的库

from gluoncv import model_zoo, data, utils
from matplotlib import pyplot as plt

加载预训练模型

让我们获取一个在 Pascal VOC 数据集上训练的 CenterNet 模型,使用 resnet18_v1b 作为基础模型。通过指定 pretrained=True,如有必要,它会自动从模型库下载模型。有关更多预训练模型,请参阅模型库

net = model_zoo.get_model('center_net_resnet18_v1b_voc', pretrained=True)

输出

Downloading /root/.mxnet/models/center_net_resnet18_v1b_voc-38c509d4.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/center_net_resnet18_v1b_voc-38c509d4.zip...

  0%|          | 0/51562 [00:00<?, ?KB/s]
  0%|          | 102/51562 [00:00<01:02, 821.12KB/s]
  1%|          | 510/51562 [00:00<00:22, 2287.24KB/s]
  4%|4         | 2187/51562 [00:00<00:06, 7431.59KB/s]
 15%|#4        | 7732/51562 [00:00<00:01, 24013.65KB/s]
 26%|##6       | 13564/51562 [00:00<00:01, 35346.01KB/s]
 41%|####1     | 21266/51562 [00:00<00:00, 47145.40KB/s]
 56%|#####5    | 28661/51562 [00:00<00:00, 44141.51KB/s]
 69%|######9   | 35709/51562 [00:00<00:00, 50966.07KB/s]
 80%|#######9  | 41091/51562 [00:01<00:00, 48957.85KB/s]
 94%|#########3| 48361/51562 [00:01<00:00, 55308.03KB/s]
51563KB [00:01, 41304.09KB/s]

预处理图像

接下来,我们下载一张图像,并使用预设的数据转换进行预处理。这里我们指定将图像的短边调整大小为 512 像素。您可以输入任意大小的图像,但是,由于模型是使用 512x512 图像训练的,因此在特定输入分辨率下可能会表现更好。

如果您想一起加载多张图像,可以向 gluoncv.data.transforms.presets.yolo.load_test() 提供一个图像文件名列表,例如 [im_fname1, im_fname2, ...]

此函数返回两个结果。第一个是形状为 (batch_size, RGB_channels, height, width) 的 NDArray。它可以直接馈送到模型中。第二个是 numpy 格式的图像,便于绘制。由于我们只加载了一张图像,因此 x 的第一个维度为 1。

im_fname = utils.download('https://raw.githubusercontent.com/zhreshold/' +
                          'mxnet-ssd/master/data/demo/dog.jpg',
                          path='dog.jpg')
x, img = data.transforms.presets.center_net.load_test(im_fname, short=512)
print('Shape of pre-processed image:', x.shape)

输出

Shape of pre-processed image: (1, 3, 512, 683)

推理并显示

forward 函数将返回所有检测到的边界框、对应的预测类别 ID 和置信度得分。它们的形状分别为 (batch_size, num_bboxes, 1)(batch_size, num_bboxes, 1)(batch_size, num_bboxes, 4)

我们可以使用 gluoncv.utils.viz.plot_bbox() 来可视化结果。我们对第一个图像的结果进行切片并将其馈送到 plot_bbox

demo center net

脚本总运行时间: ( 0 分钟 2.948 秒)

由 Sphinx-Gallery 生成的图库