我需要为可变数量的数组高效地实现笛卡尔积。
我尝试了 Iterators.jl
中的 product
函数,但性能欠佳。
我是一个 python 黑客并且使用过 this function来自 sklearn,并取得了良好的性能结果。
我曾尝试编写此函数的 Julia 版本,但无法生成与 python 函数相同的结果。
我的代码是:
function my_repeat(a, n)
# mimics numpy.repeat
m = size(a, 1)
out = Array(eltype(a), n * m)
out[1:n] = a[1]
for i=2:m
out[(i-1)*n+1:i*n] = a[i]
end
return out
end
function cartesian(arrs; out=None)
dtype = eltype(arrs[1])
n = prod([size(i, 1) for i in arrs])
if is(out, None)
out = Array(dtype, n, length(arrs))
end
m = int(n / size(arrs[1], 1))
out[:, 1] = my_repeat(arrs[1], m)
if length(arrs[2:]) > 0
cartesian(arrs[2:], out=out[1:m, 2:])
for j = 1:size(arrs[1], 1)-1
out[(j*m + 1):(j+1)*m, 2:] = out[1:m, 2:]
end
end
return out
end
我用以下方法测试它:
aa = ([1, 2, 3], [4, 5], [6, 7])
cartesian(aa)
返回值为:
12x3 Array{Float64,2}:
1.0 9.88131e-324 2.13149e-314
1.0 2.76235e-318 2.13149e-314
1.0 9.88131e-324 2.13676e-314
1.0 9.88131e-324 2.13676e-314
2.0 9.88131e-324 2.13149e-314
2.0 2.76235e-318 2.13149e-314
2.0 9.88131e-324 2.13676e-314
2.0 9.88131e-324 2.13676e-314
3.0 9.88131e-324 2.13149e-314
3.0 2.76235e-318 2.13149e-314
3.0 9.88131e-324 2.13676e-314
3.0 9.88131e-324 2.13676e-314
我认为这里的问题是,当我使用这一行时:cartesian(arrs[2:], out=out[1:m, 2:])
,关键字参数 out
不会在递归调用中就地更新。
可以看出,我对这个函数的 Python 版本做了一个非常天真的翻译(见上面的链接)。很可能存在内部语言差异,无法进行天真的翻译。我不认为这是真的,因为这来自 functions 的引用julia 文档部分:
Julia function arguments follow a convention sometimes called “pass-by-sharing”, which means that values are not copied when they are passed to functions. Function arguments themselves act as new variable bindings (new locations that can refer to values), but the values they refer to are identical to the passed values. Modifications to mutable values (such as Arrays) made within a function will be visible to the caller. This is the same behavior found in Scheme, most Lisps, Python, Ruby and Perl, among other dynamic languages.
我怎样才能让这个(或等效的)函数在 Julia 中工作?
最佳答案
Base 中有一个repeat
函数。
更短更快的变体可能会使用 Cartesian 包中的 @forcartesian
宏:
using Cartesian
function cartprod(arrs, out=Array(eltype(arrs[1]), prod([length(a) for a in arrs]), length(arrs)))
sz = Int[length(a) for a in arrs]
narrs = length(arrs)
@forcartesian I sz begin
k = sub2ind(sz, I)
for i = 1:narrs
out[k,i] = arrs[i][I[i]]
end
end
out
end
行的顺序与您的解决方案不同,但这也许无关紧要?
关于julia - 就地更新函数参数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/19498323/