python - 如何比较持有 numpy.ndarray 的数据类的相等性(bool(a==b) 引发 ValueError)?

标签 python numpy python-dataclasses

如果我创建一个包含 Numpy ndarray 的 Python 数据类,我将无法再使用自动生成的 __eq__

import numpy as np

@dataclass
class Instr:
    foo: np.ndarray
    bar: np.ndarray

arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

这是因为 ndarray.__eq__ 有时 通过比较 a[0] 返回真值的 ndarray > 到 b[0],依此类推,直到 2 中较长的一个。这是相当复杂且不直观的,事实上,只有当数组形状不同或具有不同形状时才会引发错误不同的值(value)观或其他什么。

如何安全地比较持有 Numpy 数组的 @dataclass

<小时/>

@dataclass__eq__ 实现是使用 eval() 生成的。堆栈跟踪中缺少其源代码,无法使用 inspect 查看,但它实际上使用元组比较,调用 bool(foo)。

import dis
dis.dis(Instr.__eq__)

摘录:

  3          12 LOAD_FAST                0 (self)
             14 LOAD_ATTR                1 (foo)
             16 LOAD_FAST                0 (self)
             18 LOAD_ATTR                2 (bar)
             20 BUILD_TUPLE              2
             22 LOAD_FAST                1 (other)
             24 LOAD_ATTR                1 (foo)
             26 LOAD_FAST                1 (other)
             28 LOAD_ATTR                2 (bar)
             30 BUILD_TUPLE              2
             32 COMPARE_OP               2 (==)
             34 RETURN_VALUE

最佳答案

解决方案是放入您自己的 __eq__ 方法并设置 eq=False 以便数据类不会生成自己的(尽管检查最后的 docs步骤不是必需的,但我认为无论如何明确一下都是很好的)。

import numpy as np

def array_eq(arr1, arr2):
    return (isinstance(arr1, np.ndarray) and
            isinstance(arr2, np.ndarray) and
            arr1.shape == arr2.shape and
            (arr1 == arr2).all())

@dataclass(eq=False)
class Instr:

    foo: np.ndarray
    bar: np.ndarray

    def __eq__(self, other):
        if not isinstance(other, Instr):
            return NotImplemented
        return array_eq(self.foo, other.foo) and array_eq(self.bar, other.bar)
<小时/>

编辑

针对通用数据类的通用且快速的解决方案,其中某些值是 numpy 数组,而另一些值则不是

import numpy as np
from dataclasses import dataclass, astuple

def array_safe_eq(a, b) -> bool:
    """Check if a and b are equal, even if they are numpy arrays"""
    if a is b:
        return True
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        return a.shape == b.shape and (a == b).all()
    try:
        return a == b
    except TypeError:
        return NotImplemented

def dc_eq(dc1, dc2) -> bool:
   """checks if two dataclasses which hold numpy arrays are equal"""
   if dc1 is dc2:
        return True
   if dc1.__class__ is not dc2.__class__:
       return NotImplmeneted  # better than False
   t1 = astuple(dc1)
   t2 = astuple(dc2)
   return all(array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))

# usage
@dataclass(eq=False)
class T:

   a: int
   b: np.ndarray
   c: np.ndarray

   def __eq__(self, other):
        return dc_eq(self, other)

关于python - 如何比较持有 numpy.ndarray 的数据类的相等性(bool(a==b) 引发 ValueError)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51743827/

相关文章:

python - 我如何正确处理多维 numpy 数组

Python 3.7 : dataclass does not raise `TypeError` for `eq=False`

python - 同一图中的多个等值线图

python - 检查相邻值是否在 Numpy 矩阵中

python - 使用 sklearn.cluster Kmeans 时出现内存错误

python - 如何在keras中手动获取与model.predict()相同的输出

python - 如何声明与数据类类型相同的python数据类成员字段

python - 如何输入提示动态类实例化,如 pydantic 和数据类?

python - numpy - 自动改变形状?

python - 如何从文件中绘制多条垂直线?