如何优化这段代码? 目前,它的运行速度因经过此循环的数据量而变慢。此代码运行 1-最近邻。它将根据p_data_set预测training_element的标签
# [x] , [[x1],[x2],[x3]], [l1, l2, l3]
def prediction(training_element, p_data_set, p_label_set):
temp = np.array([], dtype=float)
for p in p_data_set:
temp = np.append(temp, distance.euclidean(training_element, p))
minIndex = np.argmin(temp)
return p_label_set[minIndex]
最佳答案
使用 k-D tree用于快速最近邻查找,例如scipy.spatial.cKDTree
:
from scipy.spatial import cKDTree
# I assume that p_data_set is (nsamples, ndims)
tree = cKDTree(p_data_set)
# training_elements is also assumed to be (nsamples, ndims)
dist, idx = tree.query(training_elements, k=1)
predicted_labels = p_label_set[idx]
关于python - 如何优化这段代码以进行 nn 预测?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39625552/