tensorflow - tf.segment_max 错误 : segment ids are not increasing by 1

标签 tensorflow

tf.segment_max()和其他段操作一般require the segment IDs to be consecutive .将此操作应用于动态生成的批次并使用 tf.unique() 时定义段时,段 ID 可能不连续,从而产生错误。两种情况(错误/无错误)如下图所示:

with tf.Session() as sess:
    data = tf.constant([[1, 3, 5], [6, 2, 7], [9, 9, 2], [9, 5, 1]])
    #labels = ['a','r','r','d'] # this works
    labels = ['a','r','d','r']  # error: segment ids are not increasing by 1
    y, idx = tf.unique(labels) 
    maxs = tf.segment_max(data, idx)
    rval = sess.run([idx, maxs])
    print('indices: ', rval[0])
    print('maxs: ', rval[1])

如何处理非连续段 ID 的一般情况?

最佳答案

考虑到这个问题是很久以前发布的,我不确定 OP 已经找到了答案。但是,我正在为可能最终遇到相同问题的任何人发布解决方案。
这里的诀窍是使用 tf.math.unsorted_segment_max() .这很简单。例如,考虑以下代码-

# Data
data = tf.constant([0.4, 0.5, 0.2, 0.7, 0.1], dtype=tf.float32)
# Unique segements with Segment IDs
unique_elems, segment_ids = tf.unique(tf.constant([1, 2, 1, 1, 2], dtype=tf.int32))
# Count the number of unique segments.
nb_segments = tf.size(tf.unique(segment_ids)[0])
# Result
result = tf.math.unsorted_segment_max(
    data=data, segment_ids=segment_ids, num_segments=nb_segments
)
tf.print(result) 
# result turns out to be [0.7, 0.5]  
在上面的代码中,我们首先使用 tf.unique() 找到具有相同值的索引。方法。然后,我们使用 tf.unsorted_segment_max()data中找到最大值的方法对应于那些索引段的张量。
您可以阅读更多相关信息 here .

关于tensorflow - tf.segment_max 错误 : segment ids are not increasing by 1,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40109854/

相关文章:

python - tensorflow 错误: ValueError: None values not supported

tensorflow - 如何在 LSTM 层之间实现跳跃连接结构

python - Tensorflow 2.0 中的 tf.function 和 tf.while 循环

tensorflow - ValueError : Error when checking input: expected dense_1_input to have shape (3, )但得到形状为(2,)的数组

tensorflow - 将 Pyspark Dataframe 写入 TFrecords 文件

python - 使用 cv2 显示unet预测图像

python - 如何在 LSTM 中实现 Tensorflow 批量归一化

android - 使用 bazel 编译 tensorflow android demo 时缺少输入文件@androidsdk//:build-tools/26. 0.1/aapt

python - 在 tensorflow 中获取 ValueError 说我的形状不兼容

tensorflow - 从 TFRecord 保存和读取可变大小列表