tensorflow - 将二进制掩码转换为 tensorflow 中的边界框

标签 tensorflow tensor

我有一个二进制掩码作为 tensorflow 中的张量。

如何使用 tensorflow 操作将此二进制掩码转换为边界框?

最佳答案

经过一番努力,我终于解决了这个问题。请注意,给出的解决方案仅适用于单个对象,但是通过一些调整,您也可以将其应用于多个对象。

基本上,您想检查沿整个轴是否有任何真实像素。您从边缘开始,进一步向内移动,直到到达至少有一个真实像素的轴。对左、右、上、下执行此操作。

image = tf.io.read_file('mask.png')
image = tf.io.decode_png(image)
image = tf.image.resize(image, size=(300, 300), method='nearest')

rows = tf.math.count_nonzero(image, axis=0, keepdims=None, dtype=tf.bool) # return true if any pixels in the given row is true
rows = tf.squeeze(rows, axis=1) #make a scalar

columns = tf.math.count_nonzero(image, axis=1, keepdims=None, dtype=tf.bool)
columns = tf.squeeze(columns, axis=1)

def indicies_by_value(value): return tf.where(tf.equal(value, True))[:,-1] #return all the indices where mask is present along given axis

#coordinates
y_min = indicies_by_value(columns)[0] #first true pixel along axis
y_max = indicies_by_value(columns)[-1] #last true pixel along axis
x_min = indicies_by_value(rows)[0]
x_max = indicies_by_value(rows)[-1]

#apply the bounding box
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
img = tf.expand_dims(image, axis=-0)
img = tf.reshape(img, shape=[1, 300, 300, 1])

box = tf.stack([y_min, x_min, y_max, x_max], axis=0)
box = tf.math.divide(box, 300)
box = box.numpy()
boxes = box.reshape([1,1,4])

colors = np.array([[0.5, 0.9, 0.5], [0.5, 0.9, 0.5]])
boundning_box = tf.image.draw_bounding_boxes(img, boxes, colors)

tf.keras.preprocessing.image.save_img('boxed.png', boundning_box.numpy()[0])

关于tensorflow - 将二进制掩码转换为 tensorflow 中的边界框,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63804939/

相关文章:

tensorflow - 获取 TensorFlow 训练的模型中某些权重的值

python - Tensorflow 对象检测 - 避免重叠框

machine-learning - TensorFlow:训练好的模型存储在哪里以及如何访问?

python - Conv2dTranspose 产生错误的输出形状

python - Pytorch 批量矩阵向量外积

python - 删除 Torch 张量中的行

python - "RuntimeError: expected scalar type Double but found Float"在 Pytorch CNN 训练中

pytorch - 我怎样才能制作 torch 张量?

python - 在 tensorflow 中创建一个 float64 变量

python - 了解 PyTorch 中 index_put 的行为