python - Preserving numpy strides in a Python pickle file

Tags: python numpy

Is there a way to tell numpy to preserve non-standard strides when writing an array to a Python pickle file?

>>> # Create an array with non-standard striding
>>> x = numpy.arange(2*3*4, dtype='uint8').reshape((2,3,4)).transpose(0,2,1)

>>> x.strides
(12, 1, 4)

>>> # The pickling process converts it to a c-contiguous array.
>>> # Often, this is a good thing, but for some applications, the
>>> # non-standard striding is intentional and important to preserve.
>>> pickled = cPickle.dumps(x, protocol=cPickle.HIGHEST_PROTOCOL)
>>> cPickle.loads(pickled).strides
(12, 3, 1)

>>> # This is indeed happening during serialization, not deserialization
>>> pickletools.dis(pickled)
...
151: S        STRING     '\x00\x04\x08\x01\x05\t\x02\x06\n\x03\x07\x0b\x0c\x10\x14\r\x11\x15\x0e\x12\x16\x0f\x13\x17'
...
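As a quick check (assuming Python 3 and NumPy rather than the Python 2 session above), the STRING payload in that disassembly is just the array's elements written out in logical C (row-major) order. That is what `x.tobytes()` returns by default. The order in which the bytes actually sit in memory can be read with `ravel(order='K')`:

```python
import numpy as np

# Same non-standard-strided array as in the question: strides (12, 1, 4).
x = np.arange(2 * 3 * 4, dtype='uint8').reshape((2, 3, 4)).transpose(0, 2, 1)

# tobytes() defaults to order='C': it walks x in logical row-major order,
# which is the byte sequence seen in the pickled STRING above.
c_order_bytes = x.tobytes()

# ravel(order='K') reads elements in the order they occur in memory,
# which for this array is still 0, 1, 2, ..., 23.
memory_order_bytes = x.ravel(order='K').tobytes()

print(list(c_order_bytes[:6]))       # [0, 4, 8, 1, 5, 9]
print(list(memory_order_bytes[:6]))  # [0, 1, 2, 3, 4, 5]
```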

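Note: numpy is smart enough to preserve C-contiguous or Fortran-contiguous layouts, but it does not preserve arbitrary non-standard stride patterns across pickling and unpickling.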
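The behaviour described in the note can be reproduced with a small round-trip test. This is a sketch assuming Python 3's `pickle` and a current NumPy; `roundtrip` is a hypothetical helper, not part of either library:

```python
import pickle
import numpy as np

def roundtrip(a):
    """Pickle and unpickle an array with the highest available protocol."""
    return pickle.loads(pickle.dumps(a, protocol=pickle.HIGHEST_PROTOCOL))

base = np.arange(2 * 3 * 4, dtype='uint8').reshape((2, 3, 4))

f = np.asfortranarray(base)   # Fortran-contiguous: strides (1, 2, 6)
t = base.transpose(0, 2, 1)   # neither C- nor F-contiguous: strides (12, 1, 4)

f2 = roundtrip(f)
t2 = roundtrip(t)

print(f.strides, '->', f2.strides)  # (1, 2, 6) -> (1, 2, 6): Fortran layout survives
print(t.strides, '->', t2.strides)  # (12, 1, 4) -> (12, 3, 1): falls back to C order
```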

Best answer

The only way I can think of is to do it yourself:

# ---------------------------------------------
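import pickle
import numpy
from numpy.lib.stride_tricks import as_strided

def dumps(x, protocol=pickle.HIGHEST_PROTOCOL):
    # The flattening below is only valid when x is an axis permutation of a
    # C-contiguous array (non-negative strides, no gaps); refuse anything else.
    order = tuple(numpy.argsort(x.strides)[::-1])
    if not x.transpose(order).flags.c_contiguous:
        raise ValueError('strides are not a permutation of a contiguous layout')
    # Flatten into a 1-D view that keeps the real data (memory) order.
    y = as_strided(x, shape=(x.size,), strides=(x.itemsize,))
    return pickle.dumps([y, x.shape, x.strides], protocol=protocol)

def loads(pickled):
    y, shape, strides = pickle.loads(pickled)
    # Re-apply the original byte strides to the unpickled 1-D buffer.
    return as_strided(y, shape=shape, strides=strides)

if __name__ == '__main__':
    x = numpy.arange(2*3*4, dtype='uint8').reshape((2, 3, 4)).transpose(0, 2, 1)

    pickled = dumps(x)
    y = loads(pickled)

    print('x strides =', x.strides)
    print('y strides =', y.strides)
    print('x==y:', (x == y).all())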
# ---------------------------------------------

Output:

x strides = (12, 1, 4)
y strides = (12, 1, 4)
x==y: True
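The answer's `dumps`/`loads` pair works on raw byte strides, so it only makes sense when `x` is an axis permutation of one contiguous block. A different approach (not from the original answer; `dumps_axis_order` and `loads_axis_order` are hypothetical names) records only the axis order implied by the strides, pickles a C-contiguous copy in that order, and transposes back on load. This avoids `as_strided` entirely. It is a sketch assuming Python 3 and NumPy:

```python
import pickle
import numpy as np

def dumps_axis_order(x, protocol=pickle.HIGHEST_PROTOCOL):
    # Axes ordered from largest to smallest |stride|, i.e. outermost to
    # innermost in memory. A stable sort keeps ties (e.g. size-1 axes) in place.
    order = tuple(int(i) for i in np.argsort([-abs(s) for s in x.strides], kind='stable'))
    # Permute the axes so that memory order becomes C order, then make a
    # C-contiguous copy (no copy is made if that view is already contiguous).
    c = np.ascontiguousarray(x.transpose(order))
    return pickle.dumps((c, order), protocol=protocol)

def loads_axis_order(data):
    c, order = pickle.loads(data)
    # Undo the permutation; the result is a view on c with the original
    # axis-to-memory ordering.
    inverse = tuple(int(i) for i in np.argsort(order))
    return c.transpose(inverse)

x = np.arange(2 * 3 * 4, dtype='uint8').reshape((2, 3, 4)).transpose(0, 2, 1)
y = loads_axis_order(dumps_axis_order(x))
print('x strides =', x.strides)  # (12, 1, 4)
print('y strides =', y.strides)  # (12, 1, 4)
```

For arrays with gaps (for example `x[:, ::2]`) or negative strides, this sketch still restores the values and the relative axis ordering, but the byte strides come back compacted and non-negative. It reproduces the exact strides only when the original layout was a permuted contiguous block, which is the same case the answer's version handles.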

A similar question, "python - Preserving numpy strides in a Python pickle file", can be found on Stack Overflow: https://stackoverflow.com/questions/10439003/
