python - pickle numpy数组的子类时保留自定义属性

标签 python arrays numpy pickle python-multiprocessing

我在 the numpy documentation 之后创建了一个 numpy ndarray 的子类.特别是,我有 added a custom attribute通过修改提供的代码。

我正在使用 Python multiprocessing 在并行循环中操作此类的实例。据我了解,范围本质上“复制”到多个线程的方式是使用 pickle

我现在遇到的问题与 numpy 数组的 pickle 方式有关。我找不到任何关于此的综合文档,但有一些 discussions between the dill developers建议我应该关注 __reduce__ 方法,该方法在 pickle 时被调用。

任何人都可以对此有所了解吗?最小的工作示例实际上只是我在上面链接到的 numpy 示例代码,为了完整起见,复制到这里:

import numpy as np

class RealisticInfoArray(np.ndarray):

    def __new__(cls, input_array, info=None):
        # Input array is an already formed ndarray instance
        # We first cast to be our class type
        obj = np.asarray(input_array).view(cls)
        # add the new attribute to the created instance
        obj.info = info
        # Finally, we must return the newly created object:
        return obj

    def __array_finalize__(self, obj):
        # see InfoArray.__array_finalize__ for comments
        if obj is None: return
        self.info = getattr(obj, 'info', None)

现在问题来了:

import pickle

obj = RealisticInfoArray([1, 2, 3], info='foo')
print obj.info  # 'foo'

pickle_str = pickle.dumps(obj)
new_obj = pickle.loads(pickle_str)
print new_obj.info  #  raises AttributeError

谢谢。

最佳答案

np.ndarray使用 __reduce__ pickle 自己。我们可以看看当您调用该函数时它实际返回的内容,以了解发生了什么:

>>> obj = RealisticInfoArray([1, 2, 3], info='foo')
>>> obj.__reduce__()
(<built-in function _reconstruct>, (<class 'pick.RealisticInfoArray'>, (0,), 'b'), (1, (3,), dtype('int64'), False, '\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00'))

所以,我们得到了一个 3 元组。 __reduce__ 的文档描述每个元素在做什么:

When a tuple is returned, it must be between two and five elements long. Optional elements can either be omitted, or None can be provided as their value. The contents of this tuple are pickled as normal and used to reconstruct the object at unpickling time. The semantics of each element are:

  • A callable object that will be called to create the initial version of the object. The next element of the tuple will provide arguments for this callable, and later elements provide additional state information that will subsequently be used to fully reconstruct the pickled data.

    In the unpickling environment this object must be either a class, a callable registered as a “safe constructor” (see below), or it must have an attribute __safe_for_unpickling__ with a true value. Otherwise, an UnpicklingError will be raised in the unpickling environment. Note that as usual, the callable itself is pickled by name.

  • A tuple of arguments for the callable object.

  • Optionally, the object’s state, which will be passed to the object’s __setstate__() method as described in section Pickling and unpickling normal class instances. If the object has no __setstate__() method, then, as above, the value must be a dictionary and it will be added to the object’s __dict__.

所以,_reconstruct是调用来重建对象的函数,(<class 'pick.RealisticInfoArray'>, (0,), 'b')是传递给该函数的参数,(1, (3,), dtype('int64'), False, '\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00'))传递给类(class)' __setstate__ .这给了我们一个机会;我们可以覆盖 __reduce__并将我们自己的元组提供给 __setstate__ , 然后另外覆盖 __setstate__ , 在我们解压时设置我们的自定义属性。我们只需要确保我们保留了父类需要的所有数据,然后调用父类的__setstate__ ,也:

class RealisticInfoArray(np.ndarray):
    def __new__(cls, input_array, info=None):
        obj = np.asarray(input_array).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        if obj is None: return
        self.info = getattr(obj, 'info', None)

    def __reduce__(self):
        # Get the parent's __reduce__ tuple
        pickled_state = super(RealisticInfoArray, self).__reduce__()
        # Create our own tuple to pass to __setstate__
        new_state = pickled_state[2] + (self.info,)
        # Return a tuple that replaces the parent's __setstate__ tuple with our own
        return (pickled_state[0], pickled_state[1], new_state)

    def __setstate__(self, state):
        self.info = state[-1]  # Set the info attribute
        # Call the parent's __setstate__ with the other tuple elements.
        super(RealisticInfoArray, self).__setstate__(state[0:-1])

用法:

>>> obj = pick.RealisticInfoArray([1, 2, 3], info='foo')
>>> pickle_str = pickle.dumps(obj)
>>> pickle_str
"cnumpy.core.multiarray\n_reconstruct\np0\n(cpick\nRealisticInfoArray\np1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I3\ntp6\ncnumpy\ndtype\np7\n(S'i8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'<'\np11\nNNNI-1\nI-1\nI0\ntp12\nbI00\nS'\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x02\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x03\\x00\\x00\\x00\\x00\\x00\\x00\\x00'\np13\nS'foo'\np14\ntp15\nb."
>>> new_obj = pickle.loads(pickle_str)
>>> new_obj.info
'foo'

关于python - pickle numpy数组的子类时保留自定义属性,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/26598109/

相关文章:

python - np.array 到 PIL 图像 --> Typerror : Cannot handle this data type: (1, 1, 12), |u1

Python数组相乘

c++ - C++中的二维数组加法

python - 为仅返回一组特定值的函数键入提示

arrays - 数组的有序笛卡尔积

php - 如何从 PHP 数组中回显一定数量的元素

python - 为什么我的函数的输出是二进制的?

python - 比较python中的opencv lineartoPolar()转换

python - 按 map 中的值进行 firebase 查询

python - 从 python 脚本调用 gcc 给我 'Undefined symbols: "_main"