我已经在论坛上询问过这个问题,但这似乎足够小众,有自己的问题
我从 here 在线获取了带有余弦距离的片段。但输出似乎不正确...
这是我的代码(注意:我从 np.matmul
更改为 np.dot
但仍然没有区别。我也很困惑为什么我需要使用 transpose
。没有它它将无法工作......:
import PIL
from PIL import Image
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import img_to_array
from sklearn.metrics.pairwise import cosine_similarity
#load model and compile
facenet = load_model('facenet_keras.h5', compile='False')
facenet.compile(optimizer='adam', loss='categorical_crossentropy',metrics=['accuracy'])
def findCosineDistance(a, b):
x = np.dot(np.transpose(a),b)
y = np.dot(np.transpose(a),a)
z = np.dot(np.transpose(b),b)
return (1 - (x / (np.sqrt(y) * np.sqrt(z))))
def dist(a,b):
#prepare image for FaceNet
a,b = Image.open(a), Image.open(b)
a,b = np.array(a), np.array(b)
a,b = Image.fromarray(a), Image.fromarray(b)
a,b = a.resize((160,160)), b.resize((160,160))
a,b = img_to_array(a), img_to_array(b)
a = a.reshape((1,a.shape[0], a.shape[1], a.shape[2]))
b = b.reshape((1,b.shape[0], b.shape[1], b.shape[2]))
#get FaceNet embedding vector
a, b = facenet.predict(a), facenet.predict(b)
#compute distance metric
output = findCosineDistance(a,b)
#print(output)
#print((cosine_similarity(a, b)))
print(output)
输出:
c:/Users/Jerome Ariola/Desktop: RuntimeWarning: invalid value encountered in sqrt
return (1 - (x / (np.sqrt(y) * np.sqrt(z))))
[[ 0. -0.3677783 -0.1329441 ... 0.2845478 -0.33033693
nan]
[ 0.26888728 0. 0.17169017 ... 0.47692382 0.02737373
nan]
[ 0.1173439 -0.2072779 0. ... 0.36850178 -0.17422998
nan]
...
[-0.39771736 -0.9117675 -0.58353555 ... 0. -0.85943496
nan]
[ 0.24831063 -0.02814436 0.14837813 ... 0.4622023 0.
nan]
[ nan nan nan ... nan nan
0. ]]
最佳答案
FaceNet 的 predict()
方法似乎返回包含 NaN 值的人脸嵌入。在计算余弦相似度之前裁剪 NaN 值可能会有所帮助。使用下面的代码行进行相同的操作:
a, b = np.clip(a, -1000, 1000), np.clip(b, -1000, 1000)
注意:使用上述方法从a和b的取值范围中选择合适的阈值进行裁剪。
关于python - FaceNet 嵌入的余弦距离问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59519398/