我有以下 python 函数:
def npnearest(u: np.ndarray, X: np.ndarray, Y: np.ndarray, distance: 'callbale'=npdistance):
'''
Finds x1 so that x1 is in X and u and x1 have a minimal distance (according to the
provided distance function) compared to all other data points in X. Returns the label of x1
Args:
u (np.ndarray): The vector (ndim=1) we want to classify
X (np.ndarray): A matrix (ndim=2) with training data points (vectors)
Y (np.ndarray): A vector containing the label of each data point in X
distance (callable): A function that receives two inputs and defines the distance function used
Returns:
int: The label of the data point which is closest to `u`
'''
xbest = None
ybest = None
dbest = float('inf')
for x, y in zip(X, Y):
d = distance(u, x)
if d < dbest:
ybest = y
xbest = x
dbest = d
return ybest
其中,npdistance
只是给出两点之间的距离,即
def npdistance(x1, x2):
return(np.sum((x1-x2)**2))
我想通过直接在 numpy
中执行最近邻搜索来优化 npnearest
。这意味着该函数不能使用 for/while
循环。
谢谢
最佳答案
由于您不需要使用确切的函数,因此您可以简单地更改总和以在特定轴上工作。这将返回一个包含计算结果的新列表,您可以调用 argmin 来获取最小值的索引。使用它并查找您的标签:
import numpy as np
def npdistance_idx(x1, x2):
return np.argmin(np.sum((x1-x2)**2, axis=1))
Y = ["label 0", "label 1", "label 2", "label 3"]
u = np.array([[1, 5.5]])
X = np.array([[1,2], [1, 5], [0, 0], [7, 7]])
idx = npdistance_idx(X, u)
print(Y[idx]) # label 1
关于python - 使用numpy优化python函数而不使用for循环,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58575179/