我是 Scala 的新手,我想以相同的性能水平翻译我的 Java 代码。
给定 n 个浮点 vector 和一个附加 vector ,我必须计算所有 n 个点积并获得最大值。
使用 Java 对我来说非常简单
public static void main(String[] args) {
int N = 5000000;
int R = 200;
float[][] t = new float[N][R];
float[] u = new float[R];
Random r = new Random();
for (int i = 0;i<N;i++) {
for (int j = 0;j<R;j++) {
if (i == 0) {
u[j] = r.nextFloat();
}
t[i][j] = r.nextFloat();
}
}
long ts = System.currentTimeMillis();
float maxScore = -1.0f;
for (int i = 0;i < N;i++) {
float score = 0.0f;
for (int j = 0; i < R;i++) {
score += u[j] * t[i][j];
}
if (score > maxScore) {
maxScore = score;
}
}
System.out.println(System.currentTimeMillis() - ts);
System.out.println(maxScore);
}
我的机器上的计算时间是 6 毫秒。
现在我必须用 Scala 来做
val t = Array.ofDim[Float](N,R)
val u = Array.ofDim[Float](R)
// Filling with random floats like in Java
val ts = System.currentTimeMillis()
var maxScore: Float = -1.0f
for ( i <- 0 until N) {
var score = 0.0f
for (j <- 0 until R) {
score += u(j) * t(i)(j)
}
if (score > maxScore) {
maxScore = score
}
}
println(System.currentTimeMillis() - ts)
println(maxScore);
上面的代码在我的机器上花费了超过秒的时间。 我的想法是Scala没有Java中float[]这样的原始数组结构,取而代之的是一个集合。在索引 i 处的访问似乎比 Java 中使用原始数组的访问慢。
下面的代码更慢:
val maxScore = t.map( r => r zip u map Function.tupled(_*_) reduceLeft (_+_)).max
需要 26 秒
我应该如何有效地迭代我的 2 个数组来计算这个?
非常感谢
最佳答案
好吧,很抱歉,但奇怪的是你的 Java 实现有多快,而不是你的 Scala 有多慢 - 6 毫秒遍历 100 亿(!)个单元格听起来好得令人难以置信 - 事实上 - 你在 Java 实现中有一个拼写错误,这使得这段代码做的事情少了很多:
而不是 for (int j = 0; j < R;j++)
, 你有 for (int j = 0; i < R;i++)
- 这使得内部循环运行仅 200 次而不是 100 亿次...
如果您修复此问题 - Scala 和 Java 的性能相当。
顺便说一句,这实际上是 Scala 的一个优势——它更难获得 for (j <- 0 until R)
错了:)
关于java - 与 Java 相比,Scala 点积非常慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40086585/