python - sklearn BallTree 更改传递给指标的数据

标签 python scikit-learn

我正在使用带有 sklearn.neighbors.BallTree 的自定义度量函数,但我遇到了问题,因为 BallTree 似乎在传递数据之前更改了数据到我的度量函数。下面是一个示例来说明这一点:

from sklearn.neighbors import BallTree
import numpy as np

np.random.seed(0)
data = np.random.randint(0, 20, size=(2, 3))
def metric(x, y):
    print('Data passed to metric')
    print(x)
    print(y)
    return 1

print('Original data')
print(data)
BallTree(data, metric=metric)

这给了我

Original data
[[12 15  0]
 [ 3  3  7]]
Data passed to metric
[7.5 9.  3.5]
[12. 15.  0.]
Data passed to metric
[7.5 9.  3.5]
[3. 3. 7.]

在将数据传递给 metric 之前,BallTree 会进行哪些预处理?有办法关掉它吗?它甚至似乎改变了对 metric 调用之间的数据...

(我的真实用例 - 我使用 Levenstein 距离作为度量并使用字符串。但是,由于我无法直接传入字符串,因此我将每个字符转换为预定义的标记并传入一个数组 token 。由于数据被修改,我无法再撤消编码以将字符串返回到我的度量函数中,以便我可以正确计算 Levenstein 距离。如果您在使用时有更好的解决方案来查找最近邻居字符串而不是数字数据,我也很高兴听到这一点)。

最佳答案

事实并非如此。

BallTree 对象不会更改您的数据。

  1. 它会创建您数据的副本,因为:

Note: if X is a C-contiguous array of doubles then data will not be copied. Otherwise, an internal copy will be made.

  • 它计算对象和树节点边界之间的距离。如下所示,您可以使用 get_arrays 函数获取内部数组,通过查看源代码您发现边界是 [7.5, 9. , 3.5],它是与您的对象进行比较的代码。
  • Source :

    def get_arrays(self):
            return (self.data_arr, self.idx_array_arr,
                    self.node_data_arr, self.node_bounds_arr)
    

    输出:

    bt.get_arrays()                                                                                                                                                                                           
    Out[x]:                                                                                                                                                                                                           
    (array([[12., 15.,  0.],                                                                                                                                                                                           
            [ 3.,  3.,  7.]]), array([0, 1]), array([(0, 2, 1, 1.)],                                                                                                                                                   
           dtype=[('idx_start', '<i8'), ('idx_end', '<i8'), ('is_leaf', '<i8'), ('radius', '<f8')]), array([[[7.5, 9. , 3.5]]])) 
    

    因此,您的指标将应用于数据和节点,而不仅仅是您的数据本身,并且节点与您的数据不同。您可以尝试词嵌入,它允许您计算距离而无需解码数据。不确定您要做什么,但也许基于树的模型并不是适合您的用例的最佳方法。

    关于python - sklearn BallTree 更改传递给指标的数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49638947/

    相关文章:

    python - sklearn 中的分类树给出不一致的答案

    python - LogisticRegression.predict_proba 的 scikit-learn 返回值

    python - 如何使用 Pandas 计算另一列中每个值在一列中的出现次数?

    python - 如何统计自定义类的实例数?

    javascript - python 中的 Selenium 错误:NoSuchElementException ./ancestor-or-self::form

    csv - scikit learn - 将存储为字符串的特征转换为数字

    scikit-learn - 多标签分类的特征选择(scikit-learn)

    python - 如何修改matplotlib-venn中的字体大小

    python - 无法检测到 vtk python

    python - 更改使用导出 graphviz 创建的决策 TreeMap 的颜色