我正在使用带有 sklearn.neighbors.BallTree
的自定义度量函数,但我遇到了问题,因为 BallTree
似乎在传递数据之前更改了数据到我的度量函数。下面是一个示例来说明这一点:
from sklearn.neighbors import BallTree
import numpy as np
np.random.seed(0)
data = np.random.randint(0, 20, size=(2, 3))
def metric(x, y):
print('Data passed to metric')
print(x)
print(y)
return 1
print('Original data')
print(data)
BallTree(data, metric=metric)
这给了我
Original data
[[12 15 0]
[ 3 3 7]]
Data passed to metric
[7.5 9. 3.5]
[12. 15. 0.]
Data passed to metric
[7.5 9. 3.5]
[3. 3. 7.]
在将数据传递给 metric
之前,BallTree
会进行哪些预处理?有办法关掉它吗?它甚至似乎改变了对 metric
调用之间的数据...
(我的真实用例 - 我使用 Levenstein 距离作为度量并使用字符串。但是,由于我无法直接传入字符串,因此我将每个字符转换为预定义的标记并传入一个数组 token 。由于数据被修改,我无法再撤消编码以将字符串返回到我的度量函数中,以便我可以正确计算 Levenstein 距离。如果您在使用时有更好的解决方案来查找最近邻居字符串而不是数字数据,我也很高兴听到这一点)。
最佳答案
事实并非如此。
BallTree
对象不会更改您的数据。
- 它会创建您数据的副本,因为:
Note: if X is a C-contiguous array of doubles then data will not be copied. Otherwise, an internal copy will be made.
- 它计算对象和树节点边界之间的距离。如下所示,您可以使用
get_arrays
函数获取内部数组,通过查看源代码您发现边界是[7.5, 9. , 3.5]
,它是与您的对象进行比较的代码。
Source :
def get_arrays(self):
return (self.data_arr, self.idx_array_arr,
self.node_data_arr, self.node_bounds_arr)
输出:
bt.get_arrays()
Out[x]:
(array([[12., 15., 0.],
[ 3., 3., 7.]]), array([0, 1]), array([(0, 2, 1, 1.)],
dtype=[('idx_start', '<i8'), ('idx_end', '<i8'), ('is_leaf', '<i8'), ('radius', '<f8')]), array([[[7.5, 9. , 3.5]]]))
因此,您的指标将应用于数据和节点,而不仅仅是您的数据本身,并且节点与您的数据不同。您可以尝试词嵌入,它允许您计算距离而无需解码数据。不确定您要做什么,但也许基于树的模型并不是适合您的用例的最佳方法。
关于python - sklearn BallTree 更改传递给指标的数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49638947/