python - Python3 重写字典行为

标签 python python-3.x numpy

我是一个使用 Python 的初学者,我正在尝试使用字典中的搜索功能来搜索带有点坐标 (2) 的 numpy 数组的键。所以,我想要的是:一个字典,其键是 numpy 数组,其值是整数。然后,in 运算符将用于使用某种容差度量(numpy.allclose 函数)来比较键。我知道 numpy 数组不是可哈希的,所以我必须重写 getitemsetitem 函数(基于我在 How to properly subclass dict and override __getitem__ & __setitem__ 中找到的内容)。但是我如何使这些可散列以将它们添加为字典中的键?在这种情况下,如何覆盖 in 运算符的行为?

感谢您的帮助!

最佳答案

Numpy 数组不可散列,但元组可以。因此,如果将数组转换为元组,则可以对数组进行哈希处理。理论上,如果您也预先对其进行舍入,则可以利用快速查找的优势,因为您现在拥有离散点。但是在重新翻译期间您会遇到解析问题,因为四舍五入是使用十进制基数完成的,但数字是二进制存储的。可以通过将其转换为缩放整数来规避此问题,但这会稍微减慢一切。

最后,您只需要编写一个在数组和元组之间动态来回转换的类,就可以开始了。
实现可能如下所示:

import numpy as np

class PointDict(dict):

    def __init__(self, precision=5):
        super(PointDict, self).__init__()
        self._prec = 10**precision

    def decode(self, tup):
        """
        Turns a tuple that was used as index back into a numpy array.
        """
        return np.array(tup, dtype=float)/self._prec

    def encode(self, ndarray):
        """
        Rounds a numpy array and turns it into a tuple so that it can be used
        as index for this dict.
        """
        return tuple(int(x) for x in ndarray*self._prec)

    def __getitem__(self, item):
        return self.decode(super(PointDict, self).__getitem__(self.encode(item)))

    def __setitem__(self, item, value):
        return super(PointDict, self).__setitem__(self.encode(item), value)

    def __contains__(self, item):
        return super(PointDict, self).__contains__(self.encode(item))

    def update(self, other):
        for item, value in other.items():
            self[item] = value

    def items(self):
        for item in self:
            yield (item, self[item])

    def __iter__(self):
        for item in super(PointDict, self).__iter__():
            yield self.decode(item)

当查找很多点时,带有矢量化批量写入/查找的纯 numpy 解决方案可能会更好。然而,这个解决方案很容易理解和实现。

关于python - Python3 重写字典行为,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/32071339/

相关文章:

python - 使用每个切片的百分位数过滤多维 numpy 数组

python - 使用 minidom.toprettyxml 时的空行

python - 如何通过函数在执行命令行脚本(行以 `!` 开头)的 Google Colaboratory 单元中抑制输出

python - 使用 python netaddr cidr_merge 汇总相邻子网

python - 为给定索引替换数据框中的值

Python:使用 sklearn 时为 "ValueError: setting an array element with a sequence"

pandas - 如何在 Python 中查找不包括周末和某些假期的两个日期之间的小时数?营业时间套餐

python - 将两个系列合并为一个索引不匹配的系列

python - 如何矢量化 Pandas 函数,该函数计算属于一个组且介于两个日期之间的行?

python - 使用 python 3.0 的 Numpy