python - 如何优化这段代码以进行 nn 预测?

标签 python performance numpy scipy nearest-neighbor

如何优化这段代码? 目前,它的运行速度因经过此循环的数据量而变慢。此代码运行 1-最近邻。它将根据p_data_set预测training_element的标签

#               [x] ,           [[x1],[x2],[x3]],    [l1, l2, l3]
def prediction(training_element, p_data_set, p_label_set):
    temp = np.array([], dtype=float)
    for p in p_data_set:
        temp = np.append(temp, distance.euclidean(training_element, p))

    minIndex = np.argmin(temp)
    return p_label_set[minIndex]

最佳答案

使用 k-D tree用于快速最近邻查找,例如scipy.spatial.cKDTree :

from scipy.spatial import cKDTree

# I assume that p_data_set is (nsamples, ndims)
tree = cKDTree(p_data_set)

# training_elements is also assumed to be (nsamples, ndims)
dist, idx = tree.query(training_elements, k=1)

predicted_labels = p_label_set[idx]

关于python - 如何优化这段代码以进行 nn 预测?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39625552/

相关文章:

python - 如何在多线程 python 应用程序中创建全局错误处理程序

python - 使用二维数组作为索引从二维 numpy 数组中提取元素

performance - Julia:有效地将压缩的元组映射到元组

python - Sounddevice ValueError : could not broadcast input array from shape (2048) into shape (2048, 1)

python - ndarray连接错误

python - 如何在 PyCharm 项目 View 中移动项目文件夹?

python - MySQL并行选择请求(python)

jquery - 基于数值属性的 DOM 元素更高效的 jQuery 排序方法

java - Android如何多次调用一个方法并在它们之间有延迟

python - 关于使用空 ndarray 进行 numpy 索引的奇怪事情