我想使用自定义比较器对原始 Java 数组进行排序,但出现类型错误。我认为comparator
函数正在创建一个 Comparator<java.lang.Object>
而不是Comparator<Long>
,但我不知道如何解决这个问题。
这是一个最小的例子:
x.core=> (def x (double-array [4 3 5 6 7]))
#'x.core/x
x.core=> (java.util.Arrays/sort x (comparator #(> %1 %2)))
ClassCastException [D cannot be cast to [Ljava.lang.Object; x.core/eval1524 (form-init5588058267991397340.clj:1)
我尝试过向比较器函数添加不同的类型提示,但坦率地说,我对这门语言还比较陌生,基本上只是在扔飞镖。
我特意简化了上面的示例,以关注关键问题,即类型错误。在下面的部分中,我尝试提供更多细节来激发问题并演示为什么我使用自定义比较器。
动机
我想做的是重复 R 的 order
函数,其工作原理如下:
> x = c(7, 2, 5, 3, 1, 4, 6)
> order(x)
[1] 5 2 4 6 3 7 1
> x[order(x)]
[1] 1 2 3 4 5 6 7
正如您所看到的,它返回将对其输入向量进行排序的索引排列。
这是 Clojure 中的一个可行解决方案:
(defn order
"Permutation of indices sorted by x"
[x]
(let [v (vec x)]
(sort-by #(v %) (range (count v)))))
x.core=> (order [7 2 5 3 1 4 6])
(4 1 3 5 2 6 0)
(请注意,R 的索引为 1,而 Clojure 的索引为 0。)诀窍是对一个向量(即 x [0, 1, ..., (count x)]
的索引按向量 x 本身进行排序。
R 与 Clojure 性能
不幸的是,我对该解决方案的性能感到困扰。 R 解决方案速度快得多:
> x = runif(1000000)
> system.time({ y = order(x) })
user system elapsed
0.041 0.004 0.046
对应的Clojure代码:
x.core=> (def x (repeatedly 1000000 rand))
#'x.core/x
x.core=> (time (def y (order x)))
"Elapsed time: 2857.216452 msecs"
#'x.core/y
原始数组是解决方案吗?
我发现原始数组的排序时间往往与 R 相当:
> x = runif(1000000)
> system.time({ y = sort(x) })
user system elapsed
0.061 0.005 0.069
对比
x.core=> (def x (double-array (repeatedly 1000000 rand)))
#'x.core/x
x.core=> (time (java.util.Arrays/sort x))
"Elapsed time: 86.827277 msecs"
nil
这是我尝试将自定义比较器与 java.util.Arrays 类一起使用的动机。我希望速度能与 R 相当。
我应该补充一点,我可以使用带有 ArrayList 的自定义比较器,如下所示,但性能并不比我的起始函数更好:
(defn order2
[x]
(let [v (vec x)
compx (comparator (fn [i j] (< (v i) (v j))))
ix (java.util.ArrayList. (range (count v)))]
(java.util.Collections/sort ix compx)
(vec ix)))
即使您只是想提供一些一般性的 Clojure 建议,我们也将不胜感激。我仍在学习这门语言,并且从中获得了很多乐趣。 :-)
<小时/>编辑
根据下面 Carcigenicate 的回答,
(defn order
[x]
(let [ix (int-array (range (count x)))]
(vec (-> (java.util.Arrays/stream ix)
(.boxed)
(.sorted (fn [i j] (< (aget x i) (aget x j))))
(.mapToInt
(proxy [java.util.function.ToIntFunction] []
(applyAsInt [^long d] d)))
(.toArray)))))
可以工作:
x.core=> (def x (double-array [5 3 1 3.14 -10]))
#'x.core/x
x.core=> (order x)
[4 2 1 3 0]
x.core=> (map #(aget x %) (order x))
(-10.0 1.0 3.0 3.14 5.0)
不幸的是它非常慢。我想原语可能根本不是答案。
最佳答案
这是使用带有随机主元的快速排序的 order
函数的 Clojure 实现。它相当接近 R:使用具有一百万个 double 的基准,我得到的计时大部分在 520-530 毫秒范围内,而 R 通常徘徊在 500 毫秒左右。
更新:使用非常基本的双线程版本(2x 快速排序,然后是生成输出向量的合并步骤),我的计时得到明显改善 - 最差的基准平均值为 415 毫秒,否则我往往会得到 325-365 毫秒范围内的结果。有关双线程版本,请参阅此消息的末尾,或者如果您更喜欢要点形式的任一版本,这里是 – two-threaded , single-threaded。
请注意,作为中间步骤,它将输入倒入 double 组中,并最终返回一个长整型向量。在我的盒子上,将一百万个 double 倒入向量中似乎只需要 30 毫秒多一点,因此如果您对数组结果感到满意,可以跳过该步骤。
主要的复杂性是 invokePrim
- 从 Clojure 1.9.0-RC1 开始,该位置的常规函数调用将导致装箱。其他方法也是可能的,但这种方法有效并且看起来足够简单。
请参阅本消息末尾的一些基准测试结果。第一次运行的下分位数结果实际上是最佳报告结果
(defn order2 [xs]
(let [rnd (java.util.Random.)
a1 (double-array xs)
a2 (long-array (alength a1))]
(dotimes [i (alength a2)]
(aset a2 i i))
(letfn [(quicksort [^long l ^long h]
(if (< l h)
(let [p (.invokePrim ^clojure.lang.IFn$LLL partition l h)]
(quicksort l (dec p))
(quicksort (inc p) h))))
(partition ^long [^long l ^long h]
(let [pidx (+ l (.nextInt rnd (- h l)))
pivot (aget a1 pidx)]
(swap1 a1 pidx h)
(swap2 a2 pidx h)
(loop [i (dec l)
j l]
(if (< j h)
(if (< (aget a1 j) pivot)
(let [i (inc i)]
(swap1 a1 i j)
(swap2 a2 i j)
(recur i (inc j)))
(recur i (inc j)))
(let [i (inc i)]
(when (< (aget a1 h) (aget a1 i))
(swap1 a1 i h)
(swap2 a2 i h))
i)))))
(swap1 [^doubles a ^long i ^long j]
(let [tmp (aget a i)]
(aset a i (aget a j))
(aset a j tmp)))
(swap2 [^longs a ^long i ^long j]
(let [tmp (aget a i)]
(aset a i (aget a j))
(aset a j tmp)))]
(quicksort 0 (dec (alength a1)))
(vec a2))))
基准测试结果(注意。第一次运行使用问题文本中定义的 x
- (def x (repeatedly 1000000 rand))
;它还使用 c/bench
,而以下运行使用 c/quick-bench
):
user> (c/bench (order2 x))
Evaluation count : 120 in 60 samples of 2 calls.
Execution time mean : 522.485408 ms
Execution time std-deviation : 33.490530 ms
Execution time lower quantile : 470.089782 ms ( 2.5%)
Execution time upper quantile : 575.687990 ms (97.5%)
Overhead used : 15.378363 ns
nil
user> (let [x (repeatedly 1000000 rand)]
(c/quick-bench (order2 x)))
Evaluation count : 6 in 6 samples of 1 calls.
Execution time mean : 527.020004 ms
Execution time std-deviation : 14.846061 ms
Execution time lower quantile : 507.175127 ms ( 2.5%)
Execution time upper quantile : 543.675752 ms (97.5%)
Overhead used : 15.378363 ns
nil
user> (let [x (repeatedly 1000000 rand)]
(c/quick-bench (order2 x)))
Evaluation count : 6 in 6 samples of 1 calls.
Execution time mean : 513.476501 ms
Execution time std-deviation : 12.828449 ms
Execution time lower quantile : 497.164534 ms ( 2.5%)
Execution time upper quantile : 525.094463 ms (97.5%)
Overhead used : 15.378363 ns
nil
user> (let [x (repeatedly 1000000 rand)]
(c/quick-bench (order2 x)))
Evaluation count : 6 in 6 samples of 1 calls.
Execution time mean : 529.826816 ms
Execution time std-deviation : 21.454522 ms
Execution time lower quantile : 508.547461 ms ( 2.5%)
Execution time upper quantile : 552.592925 ms (97.5%)
Overhead used : 15.378363 ns
nil
来自同一个盒子的一些 R 计时用于比较:
> system.time({ y = order(x) })
user system elapsed
0.512 0.004 0.514
> system.time({ y = order(x) })
user system elapsed
0.496 0.000 0.496
> system.time({ y = order(x) })
user system elapsed
0.508 0.000 0.510
> system.time({ y = order(x) })
user system elapsed
0.508 0.000 0.513
> system.time({ y = order(x) })
user system elapsed
0.496 0.000 0.499
> system.time({ y = order(x) })
user system elapsed
0.500 0.000 0.502
更新:双线程 Clojure 版本:
(defn order3 [xs]
(let [rnd (java.util.Random.)
a1 (double-array xs)
a2 (long-array (alength a1))]
(dotimes [i (alength a2)]
(aset a2 i i))
(letfn [(quicksort [^long l ^long h]
(if (< l h)
(let [p (.invokePrim ^clojure.lang.IFn$LLL partition l h)]
(quicksort l (dec p))
(quicksort (inc p) h))))
(partition ^long [^long l ^long h]
(let [pidx (+ l (.nextInt rnd (- h l)))
pivot (aget a1 pidx)]
(swap1 a1 pidx h)
(swap2 a2 pidx h)
(loop [i (dec l)
j l]
(if (< j h)
(if (< (aget a1 j) pivot)
(let [i (inc i)]
(swap1 a1 i j)
(swap2 a2 i j)
(recur i (inc j)))
(recur i (inc j)))
(let [i (inc i)]
(when (< (aget a1 h) (aget a1 i))
(swap1 a1 i h)
(swap2 a2 i h))
i)))))
(swap1 [^doubles a ^long i ^long j]
(let [tmp (aget a i)]
(aset a i (aget a j))
(aset a j tmp)))
(swap2 [^longs a ^long i ^long j]
(let [tmp (aget a i)]
(aset a i (aget a j))
(aset a j tmp)))]
(let [lim (alength a1)
mid (quot lim 2)
f1 (future (quicksort 0 (dec mid)))
f2 (future (quicksort mid (dec lim)))]
@f1
@f2
(loop [out (transient [])
i 0
j mid]
(cond
(== i mid)
(persistent!
(if (== j lim)
out
(reduce (fn [out j]
(conj! out (aget a2 j)))
out
(range j lim))))
(== j lim)
(persistent!
(reduce (fn [out i]
(conj! out (aget a2 i)))
out
(range i mid)))
:else
(let [ie (aget a1 i)
je (aget a1 j)]
(if (< ie je)
(recur (conj! out (aget a2 i)) (inc i) j)
(recur (conj! out (aget a2 j)) i (inc j))))))))))
这方面的一些基准测试结果:
user> (let [x (repeatedly 1000000 rand)]
(c/quick-bench (order3 x)))
Evaluation count : 6 in 6 samples of 1 calls.
Execution time mean : 325.351056 ms
Execution time std-deviation : 3.511578 ms
Execution time lower quantile : 321.947510 ms ( 2.5%)
Execution time upper quantile : 330.375038 ms (97.5%)
Overhead used : 15.378363 ns
nil
user> (let [x (repeatedly 1000000 rand)]
(c/quick-bench (order3 x)))
Evaluation count : 6 in 6 samples of 1 calls.
Execution time mean : 339.422989 ms
Execution time std-deviation : 19.929177 ms
Execution time lower quantile : 318.996436 ms ( 2.5%)
Execution time upper quantile : 366.113347 ms (97.5%)
Overhead used : 15.378363 ns
nil
user> (let [x (repeatedly 1000000 rand)]
(c/quick-bench (order3 x)))
Evaluation count : 6 in 6 samples of 1 calls.
Execution time mean : 415.171336 ms
Execution time std-deviation : 13.624262 ms
Execution time lower quantile : 393.242455 ms ( 2.5%)
Execution time upper quantile : 428.881001 ms (97.5%)
Overhead used : 15.378363 ns
Found 1 outliers in 6 samples (16.6667 %)
low-severe 1 (16.6667 %)
Variance from outliers : 13.8889 % Variance is moderately inflated by outliers
nil
user> (let [x (repeatedly 1000000 rand)]
(c/quick-bench (order3 x)))
Evaluation count : 6 in 6 samples of 1 calls.
Execution time mean : 324.547827 ms
Execution time std-deviation : 5.196817 ms
Execution time lower quantile : 318.541727 ms ( 2.5%)
Execution time upper quantile : 331.878289 ms (97.5%)
Overhead used : 15.378363 ns
nil
user> (c/bench (order3 x))
Evaluation count : 180 in 60 samples of 3 calls.
Execution time mean : 361.529793 ms
Execution time std-deviation : 45.285047 ms
Execution time lower quantile : 307.535934 ms ( 2.5%)
Execution time upper quantile : 446.679687 ms (97.5%)
Overhead used : 15.378363 ns
Found 1 outliers in 60 samples (1.6667 %)
low-severe 1 (1.6667 %)
Variance from outliers : 78.9377 % Variance is severely inflated by outliers
nil
关于clojure - 在 Clojure 上使用自定义比较器对原始数组进行排序,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47254742/