r - 如何使用apply函数计算两个矩阵之间的距离

标签 r matrix apply euclidean-distance

我正在尝试计算两个矩阵之间的欧氏距离。我已经实现了使用 2 个 for 循环但尝试矢量化计算以加快速度。我使用 pdist 作为基准来验证距离是否计算正确。

感谢这篇文章,https://medium.com/@souravdey/l2-distance-matrix-vectorization-trick-26aa3247ac6c ,我试图用这段代码在 r 中实现同样的事情:

dist <- sqrt(rowSums(xtest**2)+rowSums(xtrain**2)-2*xtrain %*% t(xtest))

但结果与 pdist 的结果不同。我不确定这有什么问题。

下面是一些代码

创建一些数据

xtest=matrix(cbind(c(0,0),c(1,31)),2,2,byrow=TRUE)
xtrain=matrix(cbind(c(9,2),c(4,15),c(7,8),c(-22,-2)),4,2,byrow=TRUE)

使用双循环计算

mydist <- function(xtest,xtrain) {
  euc.dist <- function(x1, x2) sqrt(sum((x1 - x2) ^ 2))
  dist <- matrix(,nrow=nrow(xtrain),ncol=nrow(xtest))
  for (i in 1:nrow(xtrain)){
    for (j in 1:nrow(xtest)){
      dist[i,j] <- euc.dist(xtrain[i,], xtest[j,])
    }
  }
  return (dist)
}
> mydist(xtest,xtrain)
          [,1]     [,2]
[1,]  9.219544 30.08322
[2,] 15.524175 16.27882
[3,] 10.630146 23.76973
[4,] 22.090722 40.22437

结果与使用 pdist 相同

> libdists <- pdist(xtrain,xtest)
> as.matrix(libdists)
          [,1]     [,2]
[1,]  9.219544 30.08322
[2,] 15.524175 16.27882
[3,] 10.630146 23.76973
[4,] 22.090721 40.22437

但是如果我用矩阵乘法就错了

> mydist2 <- function(xtest,xtrain) {
+   dist <- sqrt(rowSums(xtest**2)+rowSums(xtrain**2)-2*xtrain %*% t(xtest))
+   return (dist)
+ }
> mydist2(xtest,xtrain)
          [,1]     [,2]
[1,]  9.219544      NaN
[2,] 34.684290 16.27882
[3,] 10.630146      NaN
[4,] 38.078866 40.22437

我也尝试过使用mapply函数

> mydist3 <- function(xtest,xtrain) {
+   euc.dist <- function(x1, x2) sqrt(sum((x1 - x2) ^ 2))
+   dist <- mapply(euc.dist, xtest,xtrain)
+   return (dist)
+ }
> mydist3(xtest,xtrain)
[1]  9  3  7 53  2 14  8 33

我认为它是元素明智的,而不是将每一行作为一个向量来计算两个向量之间的距离。

任何建议将不胜感激!

最佳答案

使用两个 apply 实例,第二个实例嵌套在第一个实例中:

d1 <- apply(xtest, 1, function(x) apply(xtrain, 1, function(y) sqrt(crossprod(x-y))))

检查 pdist:

library(pdist)
d2 <- as.matrix(pdist(xtrain, xtest))

all.equal(d1, d2, tolerance = 1e-7)
## [1] TRUE

关于r - 如何使用apply函数计算两个矩阵之间的距离,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57823228/

相关文章:

java - 使用 BITwise 运算实现二维数组

将函数应用于 tibble 中的每个值(并返回 tibble)?

r - Shiny 的表格格式

r - 是否有 R 等价于其他语言的三重引号?

r - 从 R 中的索引命名嵌套列表的元素

r - 在 igraph r 包中使用地理坐标作为顶点坐标

c++ - cv::Mat::t () 和 cv::transpose() 之间的区别

c++ - 将相机变换应用于 OpenGL

r - 如何将数据帧 A 中的每一列重复除以数据帧 B 中同一列的中位数?

python - 在多个 pandas 数据帧上应用相同的函数来获取数据帧