python - 如何针对不同的输入重用计算图?

标签 python theano

我已经设置了主要的计算流程,可以使用它进行训练

train = theano.function(inputs=[x], outputs=[cost], updates=updates)

同样,我有一个预测函数

predict = theano.function(inputs=[x], outputs=[output])

这两个函数都接受输入x并通过相同的计算图发送它。

我现在想修改一些东西,以便在训练时,我可以使用噪声输入进行训练,所以我有类似的东西

input = get_corrupted_input(self.theano_rng, x, 0.5)

在计算开始时。

但这也会影响我的预测函数,因为它的输入也会被损坏。如何为 trainpredict 重用相同的代码,但只为前者提供噪声输入?

最佳答案

您可以这样组织代码:

import numpy
import theano
import theano.tensor as tt
import theano.tensor.shared_randomstreams


def get_cost(x, y):
    return tt.mean(tt.sum(tt.sqr(x - y), axis=1))


def get_output(x, w, b_h, b_y):
    h = tt.tanh(tt.dot(x, w) + b_h)
    y = tt.dot(h, w.T) + b_y
    return y


def corrupt_input(x, corruption_level):
    rng = tt.shared_randomstreams.RandomStreams()
    return rng.binomial(size=x.shape, n=1, p=1 - corruption_level,
                        dtype=theano.config.floatX) * x


def compile(input_size, hidden_size, corruption_level, learning_rate):
    x = tt.matrix()
    w = theano.shared(numpy.random.randn(input_size,
                      hidden_size).astype(theano.config.floatX))
    b_h = theano.shared(numpy.zeros(hidden_size, dtype=theano.config.floatX))
    b_y = theano.shared(numpy.zeros(input_size, dtype=theano.config.floatX))
    cost = get_cost(x, get_output(corrupt_input(x, corruption_level), w, b_h, b_y))
    updates = [(p, p - learning_rate * tt.grad(cost, p)) for p in (w, b_h, b_y)]
    train = theano.function(inputs=[x], outputs=cost, updates=updates)
    predict = theano.function(inputs=[x], outputs=get_output(x, w, b_h, b_y))
    return train, predict


def main():
    train, predict = compile(input_size=3, hidden_size=2,
                             corruption_level=0.2, learning_rate=0.01)


main()

请注意,get_output 被调用两次。对于train函数,它提供了损坏的输入,但对于predict函数,它提供了干净的输入。 get_output 需要包含您所说的“相同的计算图”。我刚刚在其中放置了一个小型自动编码器,但您可以在其中放置任何您想要的内容。

假设损坏的输入与输入具有相同的形状,get_output 函数不会关心其输入是 x 还是 x 的损坏版本。因此 get_output 可以共享,但不需要包含损坏代码。

关于python - 如何针对不同的输入重用计算图?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33858990/

相关文章:

python - Theano 在将 Python 作为模块运行时无法使用 GPU (python -m)

python - Flask-CLI 无法运行应用程序?

python - 属性错误: Object has no attribute 'listbox

在循环内计算索引的 Python theano

python - 如何在 nolearn、lasagne 中定义成本函数?

python - 在 Keras 中编译模型后如何动态卡住权重?

python - 在 Plotly 图中重新排序 Axis

python - 如何从Python中的单个input()函数获取多个输入?

python - 如何将 cProfile 与 nosetest --with-profile 一起使用?

theano - Theano.function中 'givens'变量的用途