我有一个字符串张量(名为句子),我想在其中获取其单词的嵌入:
sentence = tf.map_fn(lambda x: tf.string_split([x], delimiter=' ').values, sentence, dtype=tf.string)
我使用上面的代码对批处理中的所有句子应用字符串拆分。然后,我在单词表中应用查找来获取这些张量内每个单词的单词索引:
sentence = tf.map_fn(lambda x: tf.cast(word_table.lookup(x), tf.int32), sentence, dtype=tf.int32)
当批量大小为1运行时,我运行代码没有任何问题。但是,当批量大小大于 1 时,我总是会收到以下错误,该错误指向上面的第一个代码片段。
InvalidArgumentError (see above for traceback): TensorArray sentence_splitter/map/TensorArray_1_1: Could not write to TensorArray index 10 because the value shape is [4] which is incompatible with the TensorArray's inferred element shape: [6] (consider setting infer_shape=False).
我不明白 Tensorflow 试图通过此错误表达什么!如果有人能解释这个错误,那就太好了。谢谢!
最佳答案
当你的batch size大于1时,在这段代码之后
sentence = tf.map_fn(lambda x: tf.string_split([x], delimiter=' ').values, sentence, dtype=tf.string)
tf.string_split()函数作用于不同的句子,产生不同数量的分割结果。各个维度的不兼容导致最终结果无法存储到张量中,从而出现错误。这清楚吗?
关于python - 无法写入 TensorArray 索引,因为值形状与 TensorArray 的推断元素形状不兼容,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54052242/