python - 无法写入 TensorArray 索引,因为值形状与 TensorArray 的推断元素形状不兼容

标签 python tensorflow

我有一个字符串张量(名为句子),我想在其中获取其单词的嵌入:

sentence = tf.map_fn(lambda x: tf.string_split([x], delimiter=' ').values, sentence, dtype=tf.string)

我使用上面的代码对批处理中的所有句子应用字符串拆分。然后,我在单词表中应用查找来获取这些张量内每个单词的单词索引:

sentence = tf.map_fn(lambda x: tf.cast(word_table.lookup(x), tf.int32), sentence, dtype=tf.int32)

批量大小1运行时,我运行代码没有任何问题。但是,当批量大小大于 1 时,我总是会收到以下错误,该错误指向上面的第一个代码片段。

InvalidArgumentError (see above for traceback): TensorArray sentence_splitter/map/TensorArray_1_1: Could not write to TensorArray index 10 because the value shape is [4] which is incompatible with the TensorArray's inferred element shape: [6] (consider setting infer_shape=False).

我不明白 Tensorflow 试图通过此错误表达什么!如果有人能解释这个错误,那就太好了。谢谢!

最佳答案

当你的batch size大于1时,在这段代码之后

sentence = tf.map_fn(lambda x: tf.string_split([x], delimiter=' ').values, sentence, dtype=tf.string)

tf.string_split()函数作用于不同的句子,产生不同数量的分割结果。各个维度的不兼容导致最终结果无法存储到张量中,从而出现错误。这清楚吗?

关于python - 无法写入 TensorArray 索引,因为值形状与 TensorArray 的推断元素形状不兼容,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54052242/

相关文章:

python - 如何用python删除字母[A-Z]

python - 有没有办法使用 pgeocode 获取国家代码?

python - Keras神经网络输出函数参数/如何构造损失函数?

python - 是否有明确定义的 next_batch 函数?

由列表索引的python数据框

python - Pandas 数据框 : turn date in column into value in row

python - pandas 数据帧的准确内存使用估计

python - 针对住房回归示例关闭 Cloud ML 预测

python-3.x - 无法使用经过训练的 Tensorflow 模型

tensorflow - 数据增强层的行为是什么?