python - Tensorflow:替换 tf.nn.rnn_cell._linear(input, size, 0, scope)

标签 python tensorflow neural-network

我正在尝试从 https://github.com/LantaoYu/SeqGAN 获取 SequenceGAN ( https://arxiv.org/pdf/1609.05473.pdf )运行。
在修复了明显的错误后,比如将 pack 替换为 stack,它仍然无法运行,因为 highway-network 部分需要 tf.nn.rnn_cell。 _linear 函数:

# highway layer that borrowed from https://github.com/carpedm20/lstm-char-cnn-tensorflow
def highway(input_, size, layer_size=1, bias=-2, f=tf.nn.relu):
    """Highway Network (cf. http://arxiv.org/abs/1505.00387).

    t = sigmoid(Wy + b)
    z = t * g(Wy + b) + (1 - t) * y
    where g is nonlinearity, t is transform gate, and (1 - t) is carry gate.
    """
    output = input_
    for idx in range(layer_size):
        output = f(tf.nn.rnn_cell._linear(output, size, 0, scope='output_lin_%d' % idx)) #tf.contrib.layers.linear instad doesn't work either.
        transform_gate = tf.sigmoid(tf.nn.rnn_cell._linear(input_, size, 0, scope='transform_lin_%d' % idx) + bias)
        carry_gate = 1. - transform_gate

        output = transform_gate * output + carry_gate * input_

    return output

tf.nn.rnn_cell._linear 函数在 Tensorflow 1.0 或 0.12 中似乎不再存在,我不知道用什么来代替它。我找不到这方面的任何新实现,也找不到关于 tensorflow 的 github 或(不幸的是非常稀疏)文档的任何信息。

有人知道函数的新挂件吗? 非常感谢!

最佳答案

我在使用 SkFlow 的 TensorFlowDNNRegressor 时遇到了这个错误。 第一次看到ruoho ruots的回答,有点懵。 但第二天我就明白他的意思了。

这是我的做法:

from tensorflow.python.ops import rnn_cell_impl

tf.nn.rnn_cell._linear 替换为 rnn_cell_impl._linear

关于python - Tensorflow:替换 tf.nn.rnn_cell._linear(input, size, 0, scope),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42437115/

相关文章:

python - TensorFlow DCGAN 模型 : stability and convergence problems

machine-learning - 为什么 Keras 中的多类分类时,binary_crossentropy 比 categorical_crossentropy 更准确?

python - 由于 ModuleNotFoundError : No module named 'wsgi' ,Flask 无法使用 Docker 启动服务器

python - 要求用户选择文件夹来读取Python中的文件?

c++ - 是否可以拆分 SWIG 模块进行编译,但在链接时重新加入它?

python - python中循环的向量化

machine-learning - Keras,稀疏矩阵问题

python - LSTM 网络在几次迭代后开始生成垃圾

python - 使用 tensorflow 的数据集管道,我如何*命名* `map` 操作的结果?

python - 如何使用tensorflow数据集API复制训练样本?