python - 获取具有可变序列长度的激活时出现 Tensorflow GRU 单元错误

标签 python python-2.7 tensorflow recurrent-neural-network gated-recurrent-unit

我想在一些时间序列数据上运行 GRU 单元,根据最后一层的激活对它们进行聚类。我对 GRU 单元实现做了一个小改动

def __call__(self, inputs, state, scope=None):
"""Gated recurrent unit (GRU) with nunits cells."""
with vs.variable_scope(scope or type(self).__name__):  # "GRUCell"
  with vs.variable_scope("Gates"):  # Reset gate and update gate.
    # We start with bias of 1.0 to not reset and not update.
    r, u = array_ops.split(1, 2, linear([inputs, state], 2 * self._num_units, True, 1.0))
    r, u = sigmoid(r), sigmoid(u)
  with vs.variable_scope("Candidate"):
    c = tanh(linear([inputs, r * state], self._num_units, True))
  new_h = u * state + (1 - u) * c

  # store the activations, everything else is the same
  self.activations = [r,u,c]
return new_h, new_h

在此之后,我按以下方式连接激活,然后在调用此 GRU 单元的脚本中返回它们

@property
def activations(self):
    return self._activations


@activations.setter
def activations(self, activations_array):
    print "PRINT THIS"         
    concactivations = tf.concat(concat_dim=0, values=activations_array, name='concat_activations')
    self._activations = tf.reshape(tensor=concactivations, shape=[-1], name='flatten_activations')

我按以下方式调用 GRU 单元

outputs, state = rnn.rnn(cell=cell, inputs=x, initial_state=initial_state, sequence_length=s)

其中 s 是批处理长度数组,其中包含输入批处理的每个元素中的时间戳数。

最后我使用

fetched = sess.run(fetches=cell.activations, feed_dict=feed_dict)

执行时出现如下错误

追溯(最近的调用最后): 文件“xxx.py”,第 162 行,位于 获取 = sess.run(获取=cell.activations,feed_dict=feed_dict) 运行中的文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第 315 行 返回 self._run(无,获取,feed_dict) 文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第 511 行,在 _run feed_dict_string) 文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第 564 行,在 _do_run 目标列表) 文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第 588 行,在 _do_call 六.reraise(e_type,e_value,e_traceback) 文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第 571 行,在 _do_call 返回 fn(*args) _run_fn 中的文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第 555 行

返回 tf_session.TF_Run( session 、feed_dict、fetch_list、target_list) tensorflow.python.pywrap_tensorflow.StatusNotOK:参数无效:为 RNN/cond_396/ClusterableGRUCell/flatten_activations:0 返回的张量无效。

有人可以深入了解如何在最后一步通过可变长度序列从 GRU 单元中获取激活吗?谢谢。

最佳答案

要从最后一步获取激活,您希望激活成为状态的一部分,状态由 tf.rnn 返回。

关于python - 获取具有可变序列长度的激活时出现 Tensorflow GRU 单元错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36949019/

相关文章:

python - 无法使用 Python 在 GPU (Jetson Nano) 上运行 tflite 模型

python - 返回字典中的三个最大值

python - 如何在Python脚本中将随机参数传递给另一个Python程序?

linux - 如何修复 "openerp.osv.expression: Field '销售订单'(sale_id)无法搜索: non-stored function field without fnct_search"

docker - Windows 7 jupyter 笔记本 - ImportError : No module named 'tensorflow'

python - 读入大型 CSV 文件并输入 TensorFlow

python - 使用 pytest,为什么单个测试结果与运行所有测试不同?

python csv writer 添加额外的引号

python - Flask-login is_active 不会保持更改

python - 使用列表根据多列中的值有条件地填充新列