我想在 s390x 架构上使用预训练的 MXNet 模型,但它似乎不起作用。这是因为预训练的模型是小端的,而 s390x 是大端的。所以,我正在尝试使用 https://numpy.org/devdocs/reference/generated/numpy.lib.format.html它适用于小端和大端。
解决这个问题的一种方法是我发现在 x86 机器上加载模型参数,调用 asnumpy,通过 numpy 保存然后使用 numpy 在 s390x 机器上加载参数并将它们转换为 MXNet。但我不确定如何编码。任何人都可以帮助我吗?
更新
这个问题似乎不清楚。因此,我添加了一个示例,它可以通过 3 个步骤更好地解释我想要做什么 -
net = mx.gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=mx.cpu())
gluon.contrib.utils.export(net, path="./my_model")
net = gluon.contrib.utils.import(symbol_file="my_model-symbol.json",
param_file="my_model-0000.params",
ctx = 'cpu')
我想使用 numpy 加载我们在第 2 步中创建的 .npy 文件,而不是使用 MXNet API 加载。加载 .npy 文件后,我们需要将其转换为 MXNet。所以,我终于可以在 MXNet 中使用模型了。
最佳答案
从另一个问题中发布的代码片段开始,Save/Load MXNet model parameters using NumPy :
似乎 mxnet 可以选择将数据在内部存储为 numpy 数组:
mx.npx.set_np(True, True)
不幸的是,这个选项没有达到我希望的效果(我的 IPython session 崩溃了)。参数是
dict
的 mxnet.gluon.parameter.Parameter
实例,每个实例都包含其他特殊数据类型的属性。解决这个问题以便您可以将其存储为大量纯 numpy 数组(或它们的集合在 .npz
文件中)是一项无望的任务。幸运的是,python 有
pickle
将复杂的数据结构转换成或多或少可移植的东西:# (mxnet/resnet setup skipped)
parameters = resnet.collect_params()
import pickle
with open('foo.pkl', 'wb') as f:
pickle.dump(parameters, f)
恢复参数:with open('foo.pkl', 'rb') as f:
parameters_loaded = pickle.load(f)
本质上,它看起来像 resnet.save_parameters()
如 mxnet/gluon/block.py
中所定义获取参数(使用 _collect_parameters_with_prefix()
)并使用自定义写入函数将它们写入文件,该函数似乎是从 C 编译的(我没有检查细节)。您可以使用
pickle
保存参数反而。用于加载,
load_parameters
(也在 util.py
中)包含此代码(删除了健全性检查):for name in loaded:
params[name]._load_init(loaded[name], ctx, cast_dtype=cast_dtype, dtype_source=dtype_source)
在这里,loaded
是从文件加载的字典。通过检查代码,我没有完全掌握正在加载的内容 - params
似乎是函数中不再使用的局部变量。但值得一试,从这里开始,编写 load_parameters
的替代品。功能。您可以通过在类外定义一个函数,将函数“猴子补丁”到现有类中,如下所示:def my_load_parameters(self, ...):
... (put your modified implementation here)
mx.gluon.Block.load_parameters = my_load_parameters
免责声明/警告:pickle
保存/加载要在单个大端系统上工作,不能保证在不同端系统之间工作。 pickle 协议(protocol)本身是 endian-neutral,但是如果浮点值(在 mxnet.gluon.parameter.Parameter
深处)被存储为机器端约定中的原始数据缓冲区,那么 pickle 不会神奇地猜测缓冲区需要反转。我认为 numpy 数组在腌制时是字节序安全的。关于python - 使用 numpy 进行 MXNet 参数序列化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62796927/