python - 如何从左侧实现对象与 Numpy 数组的乘法?

标签 python numpy

如果想要将左侧的任意对象与 np.ndarray 相乘,就会遇到问题。

问题是,如果右侧是迄今为止未知的类型,则 numpy.ndarray.__mul__ 会按元素调用右侧的 __rmul__

要查看问题,您可以复制粘贴以下代码并从命令行或 SDK 运行它。您可以通过它单步执行调试器。评论可以帮助您并指导您完成整个过程...

import numpy as np


class Zora(object):

    def __init__(self, array):
        self._array = array
        self._values_field_name = array.dtype.names[-1]

    @property
    def a(self):
        return self._array

    def __repr__(self):
        return repr(self._array)

    def __mul__(self, other):
        result = self.copy()
        result *= other
        return result

    def __imul__(self, other):
        self._array[self._values_field_name] *= other
        return self

    def __rmul__(self, other):
        return self * other

    def copy(self):
        return self.__class__(self.a.copy())

    @classmethod
    def create(cls, fields, codes, data):
        array = np.array([(*c, d) for c, d in zip(codes, data)], dtype=fields)
        return cls(array)


if __name__ == '__main__':

    # Let's create a Zora dataset with  scenarios
    scenarios = 7

    fields = np.dtype([('LegalEntity', np.unicode_, 32), ('Division', np.unicode_, 32),
                       ('Scenarios', np.float64, (scenarios, ))])

    legal_entities = ['A', 'A', 'B', 'B', 'C']
    divisions = ['a', 'b', 'a', 'b', 'b']

    codes = list(zip(legal_entities, divisions))

    data = np.random.uniform(0., 1., (len(codes), scenarios))

    zora = Zora.create(fields, codes, data)

    # The dataset looks like the following
    print(zora)

    # We can multiply it from the left with scalars ...
    z = zora * 2
    print(zora * 2)

    # ... and with column vectors, for example ...
    # ... for this we generate a columns vector with some weights ...
    numrows = zora.a.shape[0]

    weights = np.expand_dims(np.array(list(range(numrows))), 1)
    # ... the weights ...
    print(weights)

    # ... left side multiplication works fine too with this
    print(zora * weights)

    # Let's show inplace multiplication ...
    # Which we apply on a copy, so that we can still compare ...
    z = zora.copy()
    z *= 2

    # Is pretty fine too, ...
    print(zora)
    print(z)

    # Now it becomes a bit special ...
    # ... when multiplying from the left.
    # It works fine with a scalar..
    z = 2 * zora
    print(z)

    # But becomes special with np.ndarrays ...
    print('-------------------------------------')
    print('-------------------------------------')
    print('The following result ...')
    z = weights * zora
    print(z)

    # which is not the same, but should, as ...
    print('-------------------------------------')
    print('... should be the same as this one ...')

    z = zora * weights
    print(z)

    # We got a list of arrays, where for each i-th array
    # the corresponding i-th weight has been used for
    # multiplication
    #
    # This has to do with numpy's implementation of calling
    # __rmul__ from the right hand side within the __mul__
    # from the np.ndarray ...
    #
    # The same is true for all other __r{...}__ methods

最佳答案

使用 Numpy Hook __array_ufunc__

Numpy 允许您通过在对象上正确定义 __array_ufunc__ 方法来规避此问题。

将以下方法添加到上面的 Zora 类中即可完成这项工作。很容易将其推广到所有二进制 uf 函数,包括非交换运算。但我只展示这一点,并不是为了让事情变得复杂。

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    lhs, rhs = inputs
    return rhs * lhs

因此,以下将显示预期结果...

import numpy as np


class Zora(object):

    def __init__(self, array):
        self._array = array
        self._values_field_name = array.dtype.names[-1]

    @property
    def a(self):
        return self._array

    def __repr__(self):
        return repr(self._array)

    def __mul__(self, other):
        result = self.copy()
        result *= other
        return result

    def __imul__(self, other):
        self._array[self._values_field_name] *= other
        return self

    def __rmul__(self, other):
        return self * other

    def copy(self):
        return self.__class__(self.a.copy())

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        lhs, rhs = inputs
        return rhs * lhs

    @classmethod
    def create(cls, fields, codes, data):
        array = np.array([(*c, d) for c, d in zip(codes, data)], dtype=fields)
        return cls(array)


if __name__ == '__main__':

    # Let's create a Zora dataset with  scenarios
    scenarios = 7

    fields = np.dtype([('LegalEntity', np.unicode_, 32), ('Division', np.unicode_, 32),
                       ('Scenarios', np.float64, (scenarios, ))])

    legal_entities = ['A', 'A', 'B', 'B', 'C']
    divisions = ['a', 'b', 'a', 'b', 'b']

    codes = list(zip(legal_entities, divisions))

    data = np.random.uniform(0., 1., (len(codes), scenarios))

    zora = Zora.create(fields, codes, data)

    # The dataset looks like the following
    print(zora)

    # We can multiply it from the left with scalars ...
    z = zora * 2
    print(zora * 2)

    # ... and with column vectors, for example ...
    # ... for this we generate a columns vector with some weights ...
    numrows = zora.a.shape[0]

    weights = np.expand_dims(np.array(list(range(numrows))), 1)
    # ... the weights ...
    print(weights)

    # ... left side multiplication works fine too with this
    print(zora * weights)

    # Let's show inplace multiplication ...
    # Which we apply on a copy, so that we can still compare ...
    z = zora.copy()
    z *= 2

    # Is pretty fine too, ...
    print(zora)
    print(z)

    # Now it becomes a bit special ...
    # ... when multiplying from the left.
    # It works fine with a scalar..
    z = 2 * zora
    print(z)

    # But becomes special with np.ndarrays ...
    print('-------------------------------------')
    print('-------------------------------------')
    print('The following result ...')
    z = weights * zora
    print(z)

    # which is now the same ...
    print('-------------------------------------')
    print('... should be the same as this one ...')

    z = zora * weights
    print(z)

关于python - 如何从左侧实现对象与 Numpy 数组的乘法?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62023723/

相关文章:

python - 如何追加屏蔽数组

python - Numpy 数组插入第二个数组中的每隔一个元素

Python 复杂对象索引 - 访问特定嵌套元素时遇到问题

python - 使用数据结构传递多个参数

python - 以编程方式获取 python 函数参数的描述

python - 如何使用PyPDF2设置注释的透明背景?

Python 导入在解释器中有效,在脚本 Numpy/Matplotlib 中不起作用

python - 为什么 numpy.prod() 对于我的一长串自然数错误地返回负结果或 0?

Python多维符号自动转置

python - 将整个数据框中的 NaN 值替换为其他值的平均值