准备您的数据集为 ImageRecord 格式

原始图像是计算机视觉任务的自然数据格式。然而,当从图像文件加载数据进行训练时,磁盘 IO 可能会成为瓶颈。

例如,在 AWS p3.16xlarge 实例上使用 ImageNet 训练 ResNet50 模型时,8 块 GPU 的并行训练速度非常快,即使从 ramdisk 读取图像也无法跟上。

为了在顶级配置平台上提高性能,我们建议用户使用 MXNet 的 ImageRecord 格式进行训练。

准备工作

只需几行代码即可为您的图像创建 ImageRecord 文件。

假设我们有一个文件夹 ./example,其中图像按类别放置在不同的子文件夹中

./example/class_A/1.jpg
./example/class_A/2.jpg
./example/class_A/3.jpg
./example/class_B/4.jpg
./example/class_B/5.jpg
./example/class_B/6.jpg
./example/class_C/100.jpg
./example/class_C/1024.jpg
./example/class_D/65535.jpg
./example/class_D/0.jpg
...

首先,我们需要生成一个 .lst 文件,即包含这些图像的标签和文件名信息的列表。

python im2rec.py ./example_rec ./example/ --recursive --list --num-thread 8

执行后,您可能会发现生成了一个 ./example_rec.lst 文件。有了这个文件,下一步是

python im2rec.py ./example_rec ./example/ --recursive --pass-through --pack-label --num-thread 8

这会生成另外两个文件:example_rec.idxexample_rec.rec。现在,您可以使用它们进行训练了!

对于验证集,我们通常不会打乱图像的顺序,因此相应的命令是

python im2rec.py ./example_rec_val ./example_val --recursive --list --num-thread 8
python im2rec.py ./example_rec_val ./example_val --recursive --pass-through --pack-label --no-shuffle --num-thread 8

ImageNet 的 ImageRecord 文件

如前所述,ImageNet 训练可以从 ImageRecord 格式改进的 IO 速度中受益。

我们在教程 “准备 ImageNet 数据集” 中使用了相同的脚本,但参数不同。请仔细阅读并提前下载 imagenet 文件。

首先,请下载辅助脚本 imagenet.py 和验证图像信息 imagenet_val_maps.pklz。确保将它们放在同一目录中。

假设 tar 文件保存在文件夹 ~/ILSVRC2012 中。我们可以使用以下命令自动准备数据集。

python imagenet.py --download-dir ~/ILSVRC2012 --with-rec

注意

提取图像可能需要一段时间。例如,在带有 EBS 的 AWS EC2 实例上大约需要 30 分钟。

默认情况下,imagenet.py 会将图像提取到 ~/.mxnet/datasets/imagenet。您可以通过设置 --target-dir 指定不同的目标文件夹。

使用 ImageRecordIter 读取

准备好的数据集可以直接使用工具类 mxnet.io.ImageRecordIter 加载。这是一个示例,每次随机读取 128 张图像并执行随机缩放和裁剪。

import os
from mxnet import nd
from mxnet.io import ImageRecordIter

rec_path = os.path.expanduser('~/.mxnet/datasets/imagenet/rec/')

# You need to specify ``root`` for ImageNet if you extracted the images into
# a different folder
train_data = ImageRecordIter(
    path_imgrec = os.path.join(rec_path, 'train.rec'),
    path_imgidx = os.path.join(rec_path, 'train.idx'),
    data_shape  = (3, 224, 224),
    batch_size  = 32,
    shuffle     = True
)
for batch in train_data:
    print(batch.data[0].shape, batch.label[0].shape)
    break

输出

(32, 3, 224, 224) (32,)

绘制一些验证图像

from gluoncv.utils import viz
val_data = ImageRecordIter(
    path_imgrec = os.path.join(rec_path, 'val.rec'),
    path_imgidx = os.path.join(rec_path, 'val.idx'),
    data_shape  = (3, 224, 224),
    batch_size  = 32,
    shuffle     = False
)
for batch in val_data:
    viz.plot_image(nd.transpose(batch.data[0][12], (1, 2, 0)))
    viz.plot_image(nd.transpose(batch.data[0][21], (1, 2, 0)))
    break
  • recordio
  • recordio

脚本总运行时间: ( 0 分 4.546 秒)

由 Sphinx-Gallery 生成的画廊