machine-learning - 使用 Flux.jl 更新函数的参数

标签 machine-learning optimization julia flux-machine-learning

我正在使用 flux.jl,但在更新自定义函数的参数时遇到问题。

该函数定义如下目标:

    using Distributions
    using Flux.Tracker: gradient, param, Params
    using Flux.Optimise: Descent, ADAM, update!

    D = 2 
    num_samples = 100

    function log_density(params)
        mu, log_sigma = params
        d1 = Normal(0, 1.35)
        d2 = Normal(0, exp(log_sigma))
        d1_density = logpdf(d1, log_sigma)
        d2_density = logpdf(d2, mu)
        return d1_density + d2_density
    end


    function J(log_std)
        H = 0.5 * D * (1.0 + log(2 * pi)) + sum(log_std)
        return H
    end

    function objective(mu, log_std; D=2)
        samples = rand(Normal(), num_samples, D) .* sqrt.(log_std) .+ mu
        log_px = mapslices(log_density, samples; dims=2)
        elbo = J(log_std) + mean(log_px)
        return -elbo
    end

我尝试进行一次更新,如下所示:


    mu = param(reshape([-1, -1], 1, :))
    sigma = param(reshape([5, 5], 1, :))

    grads = gradient(() -> objective(mu, sigma), Params([mu, sigma]))

    opt = Descent(0.001)
    for p in (mu, sigma)
        update!(opt, p, grads[p])
    end

产生错误:

ERROR: Can't differentiate `setindex!`
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] setindex!(::TrackedArray{…,Array{Float64,2}}, ::Flux.Tracker.TrackedReal{Float64}, ::CartesianIndex{2}) at /Users/vasya/.julia/packages/Flux/T3PhK/src/tracker/lib/array.jl:63
 [3] macro expansion at ./broadcast.jl:838 [inlined]
 [4] macro expansion at ./simdloop.jl:73 [inlined]
 [5] copyto! at ./broadcast.jl:837 [inlined]
 [6] copyto! at ./broadcast.jl:792 [inlined]
 [7] materialize! at ./broadcast.jl:751 [inlined]
 [8] update!(::Descent, ::TrackedArray{…,Array{Float64,2}}, ::TrackedArray{…,Array{Float64,2}}) at /Users/vasya/.julia/packages/Flux/T3PhK/src/optimise/optimisers.jl:22
 [9] top-level scope at ./REPL[23]:2 [inlined]
 [10] top-level scope at ./none:0

我还尝试将 grads[p] 替换为 grads[p].data。这不会产生错误,但不会更新参数!

环境详细信息:
- Julia 版本 1.0.2
- Flux v0.7.0
- 发行版 v0.16.4

最佳答案

<小时/>

通过 Slack 进行的讨论明确了 update! 函数的正确用法。下面的代码使模块引用显式,并生成更新的参数(对于 Flux v0.7.0):

    using Distributions
    using Flux

    D = 2 
    num_samples = 100

    function log_density(params)
        mu, log_sigma = params
        d1 = Normal(0, 1.35)
        d2 = Normal(0, exp(log_sigma))
        d1_density = logpdf(d1, log_sigma)
        d2_density = logpdf(d2, mu)
        return d1_density + d2_density
    end

    function J(log_std)
        H = 0.5 * D * (1.0 + log(2 * pi)) + sum(log_std)
        return H
    end

    function objective(mu, log_std; D=2)
        samples = rand(Normal(), num_samples, D) .* sqrt.(log_std) .+ mu
        log_px = mapslices(log_density, samples; dims=2)
        elbo = J(log_std) + mean(log_px)
        return -elbo
    end

    mu = Flux.Tracker.param(reshape([-1, -1], 1, :))
    sigma = Flux.Tracker.param(reshape([5, 5], 1, :))

    grads = Flux.Tracker.gradient(() -> objective(mu, sigma), Flux.Tracker.Params([mu, sigma]))

    println(mu, sigma)

    opt = Flux.Optimise.Descent(0.01)
    for p in (mu, sigma)
        Flux.Tracker.update!(p, Flux.Optimise.update!(opt, p, Flux.data(grads[p])))
    end

    println(mu, sigma)

打印:

    [-1.0 -1.0] (tracked)[5.0 5.0] (tracked)
    [-198.742 -459.423] (tracked)[31.0583 225.657] (tracked)

关于machine-learning - 使用 Flux.jl 更新函数的参数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54316541/

相关文章:

string - 如何在 Julia 的字符串数组中查找子字符串或字符

java - 计算具有已知频率的 3 个或更多属性的方差

MATLAB感知器

python - 如何在 TensorFlow 中应用渐变裁剪?

performance - 为长度为 8 的一维卷积核加载 AVX 向量的高效代码

julia - 带数据框的条件语句 [Julia v1.0]

tensorflow - 如何为您训练的模型选择半精度(BFLOAT16 与 FLOAT16)?

python - 返回前的赋值在 python 中有成本吗?

c - 在 C 中,哪个更快 : if with returns, 或者如果有返回?

r - Julia 相当于 R 的 qnorm()?