python - 在运行时获取 HybridBlock 层形状

标签 python python-3.x mxnet

我正在尝试构建一个自定义池层(适用于 ndarray 和 Symbol)并且我需要知道运行时的输入形状。根据文档,HybridBlock具有“infer_shape”功能,但我无法使其工作。有什么指示可以指出我做错了什么吗?

mxnet版本

1.0.0 ,从 conda、python3 构建。

最小可重现示例

例如:

import mxnet as mx
import mxnet.ndarray as nd
from mxnet.gluon import HybridBlock

class runtime_shape(HybridBlock):


    def __init__(self,  **kwards):
        HybridBlock.__init__(self,**kwards)


    def hybrid_forward(self,F,_input):

        print (self.infer_shape(_input))

        return _input

xx = nd.random_uniform(shape=[5,5,16,16])

mynet = runtime_shape()
mynet.hybrid_forward(nd,xx)

错误消息:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-41-3f539a940958> in <module>()
----> 1 mynet.hybrid_forward(nd,xx)

<ipython-input-38-afc9785b716d> in hybrid_forward(self, F, _input)
     17     def hybrid_forward(self,F,_input):
     18 
---> 19         print (self.infer_shape(_input))
     20 
     21         return _input

 /home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in infer_shape(self, *args)
    460     def infer_shape(self, *args):
    461         """Infers shape of Parameters from inputs."""
--> 462         self._infer_attrs('infer_shape', 'shape', *args)
    463 
    464     def infer_type(self, *args):

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in _infer_attrs(self, infer_fn, attr, *args)
    448     def _infer_attrs(self, infer_fn, attr, *args):
    449         """Generic infer attributes."""
--> 450         inputs, out = self._get_graph(*args)
    451         args, _ = _flatten(args)
    452         arg_attrs, _, aux_attrs = getattr(out, infer_fn)(

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in _get_graph(self, *args)
    369             params = {i: j.var() for i, j in self._reg_params.items()}
    370             with self.name_scope():
--> 371                 out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
    372             out, self._out_format = _flatten(out)
    373 

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in __exit__(self, ptype, value, trace)
     78         if self._block._empty_prefix:
     79             return
---> 80         self._name_scope.__exit__(ptype, value, trace)
     81         self._name_scope = None
     82         _BlockScope._current = self._old_scope

AttributeError: 'NoneType' object has no attribute '__exit__'

最佳答案

HybridBlock 的想法是让命令式世界中的调试变得容易,您可以简单地放置一个断点或 print 语句,并查看哪些数据正在流经您的网络。当您确信网络正在执行您想要的操作时,您可以调用 .hybridize() 并享受速度的提高。

在开发网络并使用命令式模式时,您可以简单地打印: 打印('形状',_input.shape)

并在使用网络的混合版本时删除此行,因为这仅适用于 NDArray。

如果这不能回答您的问题,您能否通过获取输入数据的形状来精确说明您想要实现的目标是什么?

关于python - 在运行时获取 HybridBlock 层形状,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48070604/

相关文章:

python - 仅使用 ‘Sequential()’ 或 ‘HybridSequential()’ 作为容器是否有任何副作用?

python - Graphviz 不绘制图表

python - 检查 command 在 Python 中是否有问题?

python - 如何验证 pip 的 --extra-index-url?

python - QCombobox finData 方法始终返回 -1 与 numpy 数组

python - 在 Python3 中使用 for 循环为 vigenere 密码创建 2D 列表

machine-learning - mxnet 训练没有进展

python - Scikit 学习错误地输入值

python - 使用 one-hot 编码结果按日期、类别和客户对客户订单进行分组

python - 如何在代表相同标签的每个标记上方设置标题