julia - Flux 的自定义渐变而不是使用 Zygote A.D

标签 julia flux

我有一个机器学习模型,其中模型参数的梯度是解析的,不需要自动微分。但是,我仍然希望能够利用 Flux 中的不同优化器,而不必依赖 Zygote 进行区分。这是我的代码的一些片段。

W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)

θ = Flux.Params([b, c, U, W])

opt = ADAM(0.01)

然后我有一个函数来计算我的模型参数的解析梯度,θ .
function gradients(x) # x = one input data point or a batch of input data points
    # stuff to calculate gradients of each parameter
    # returns gradients of each parameter

然后,我希望能够执行以下操作。
grads = gradients(x)
update!(opt, θ, grads)

我的问题是:我的 gradient(x) 是什么形式/类型函数需要返回才能执行update!(opt, θ, grads) ,我该怎么做?

最佳答案

如果您不使用 Params然后 grads只需要是渐变。唯一的要求是 θgrads大小相同。

例如,map((x, g) -> update!(opt, x, g), θ, grads)哪里θ == [b, c, U, W]grads = [gradients(b), gradients(c), gradients(U), gradients(W)] (不太确定 gradients 期望什么作为您的输入)。

更新:但要回答您原来的问题,gradients需要返回一个 Grads在这里找到的对象:https://github.com/FluxML/Zygote.jl/blob/359e586766129878ca0e56121037ed80afda6289/src/compiler/interface.jl#L88

所以像

# within gradient function body assuming gb is the gradient w.r.t b
g = Zygote.Grads(IdDict())
g.grads[θ[1]] = gb # assuming θ[1] == b

但不使用 Params调试起来可能更简单。唯一的问题是没有 update!这将适用于一系列参数,但您可以轻松定义自己的:
function Flux.Optimise.update!(opt, xs::Tuple, gs)
    for (x, g) in zip(xs, gs)
        update!(opt, x, g)
    end
end

# use it like this
W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)

θ = (b, c, U, W)

opt = ADAM(0.01)
x = # generate input to gradients
grads = gradients(x) # return tuple (gb, gc, gU, gW)
update!(opt, θ, grads)

更新 2:

另一种选择是仍然使用 Zygote 来获取梯度,以便它自动设置 Grads对象,但要使用自定义伴随,以便它使用您的分析函数来计算伴随。假设您的 ML 模型定义为名为 f 的函数,所以 f(x)为输入返回模型的输出 x .我们还假设 gradients(x)返回分析梯度 w.r.t. x就像你在问题中提到的那样。那么下面的代码仍然会使用 Zygote 的 AD 来填充 Grads对象正确,但它将使用您为函数计算梯度的定义f :
W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)

θ = Flux.Params([b, c, U, W])

f(x) = # define your model
gradients(x) = # define your analytical gradient

# set up the custom adjoint
Zygote.@adjoint f(x) = f(x), Δ -> (gradients(x),)

opt = ADAM(0.01)
x = # generate input to model
y = # output of model
grads = Zygote.gradient(() -> Flux.mse(f(x), y), θ)
update!(opt, θ, grads)

请注意,我使用了 Flux.mse作为上面的损失示例。这种方法的一个缺点是 Zygote 的 gradient函数需要标量输出。如果你的模型被传递到一些会输出标量错误值的损失,那么 @adjoint是最好的方法。这适用于您进行标准机器学习的情况,唯一的变化是您希望 Zygote 计算 f 的梯度。使用您的函数进行分析。

如果您正在做一些更复杂的事情并且无法使用 Zygote.gradient ,那么第一种方法(不使用 Params )是最合适的。 Params实际上只是为了与 Flux 的旧 AD 向后兼容而存在,因此最好尽可能避免它。

关于julia - Flux 的自定义渐变而不是使用 Zygote A.D,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61255285/

相关文章:

dataframe - 查找列值在集合中的行(类似于 pandas isin 或 R %in%)

python - 用于图像处理和语音识别的 Julia

reactjs - 在无 Flux 应用程序中使用 React `context` 访问模型更改器(mutator)是否合理?

vue.js - 在 vue/vuex(/flux?) 中使用 ES6 类是一种反模式吗?

julia - 创建一个行为类似于另一个类型的原始类型

Julia 中的 C 结构类型对应关系

statistics - Julia 的移动平均线

java - WebFlux 扩展未检索第二个请求

java - 如何使用 Spring webflux 将实时进度发送到 webclient?

javascript - React.js fadeIn + 渲染每个元素的延迟