python - MxNet:label_shapes 与label_names 指定的名称不匹配

标签 python machine-learning computer-vision deep-learning mxnet

我编写了一个脚本,使用我用 MxNet 训练的模型对单个输入图像进行分类。为了对传入图像进行分类,我通过网络将它们前馈。

简而言之,这就是我正在做的事情:

symbol, arg_params, aux_params = mx.model.load_checkpoint('model-prefix', 42)
model = mx.mod.Module(symbol=symbol, context=mx.cpu())
model.bind(data_shapes=[('data', (1, 3, 224, 244))], for_training=False)
model.set_params(arg_params, aux_params)

# ... loading the image & resizing ...
# img is the image to classify as numpy array of shape (3, 244, 244)

Batch = namedtuple('Batch', ['data'])
self._model.forward(Batch(data=[mx.nd.array(img)]))
probabilities = self._model.get_outputs()[0].asnumpy()

print(str(probabilities))

这工作正常,除了我收到以下警告

UserWarning: Data provided by label_shapes don't match names specified by label_names ([] vs. ['softmax_label'])

我应该更改什么以避免收到此警告?我不清楚 label_shapeslabel_names 参数的含义,以及我期望用什么来填充它们。

注意:我找到了一些关于它们的线程,但没有一个能让我解决问题。同样,MxNet 文档没有提供有关这些参数是什么以及如何填充它们的详细信息。

最佳答案

设置label_names=Noneallow_missing=True。这应该消除警告。

model = mx.mod.Module(symbol=symbol, context=mx.cpu(), label_names=None)
...
model.set_params(arg_params, aux_params, allow_missing=True)

如果您好奇为什么首先打印警告,

每个模块都有关联的标签。训练此模型时,使用 softmax_label 作为标签(很可能是因为输出层是名为“softmax”的 softmax 层)。从文件加载模型时,创建的模块将 softmax_label 作为模块的标签。

>>>print(model.label_names)
['softmax_label']
然后调用

model.bind 而不提供 label_shapes。

model.bind(data_shapes=[('data', (1, 3, 224, 244))], for_training=False)

MXNet 发现模块中包含一个在绑定(bind)期间未提供的标签,并对此进行了提示 - 这是您看到的警告消息。

我认为如果使用for_training=False调用bind,MXNet不应该提示缺少标签。我创建了这个问题:https://github.com/dmlc/mxnet/issues/6958

但是,对于我们从磁盘加载模型的特殊情况,我们可以使用 None 作为标签来加载它,这样 MXNet 以后就不会在绑定(bind)不提供标签时提示 - 这这就是建议的修复的作用。

关于python - MxNet:label_shapes 与label_names 指定的名称不匹配,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44947104/

相关文章:

c++ - 快速访问 jpeg 图像的像素值

python - 对象提取与构建

python - Pandas 按年份汇总值,但保留原始 TimeSeries 索引

python - Lightfm : handling user and item cold-start

python - 如何修改 .htaccess 中的 sys.path 以允许 mod_python 看到 Django?

python - 为什么 pygame 在退出之前需要 time.sleep 来播放声音?

computer-vision - Caffe:具有不同数量标签的多标签图像

machine-learning - 在 Keras 2.0 上使用合并层(lambda/函数)?

machine-learning - 设置 .eval() 时,我的模型表现比设置 .train() 时差

python - 虚线图像中的OpenCV跟踪线