python - tensorflow 找到到真实点的最小距离

标签 python tensorflow

我有一个 Bx3 张量,foo,B= 批量大小 3D 点。通过一些奇思妙想,我得到了另一个张量,bar,其形状为 Bx6x3,其中每个 B 6x3 矩阵对应于 foo 中的一个点。该 6x3 矩阵由 6 个复值 3D 点组成。我想要做的是,对于每个 B 点,在 bar 中的 6 个点中找到与 中对应点最接近的实值点foo,最终得到一个新的 Bx3 min_bar,其中包含 bar 中与 foo 中的点最接近的点。

numpy 中,我可以使用掩码数组完成此壮举:

foo = np.array([
    [1,2,3],
    [4,5,6],
    [7,8,9]])
# here bar is only Bx2x3 for simplicity, but the solution generalizes
bar = np.array([
    [[2,3,4],[1+0.1j,2+0.1j,3+0.1j]],
    [[6,5,4],[4,5,7]],
    [[1j,1j,1j],[0,0,0]],
])

#mask complex elements of bar
bar_with_masked_imag = np.ma.array(bar)
candidates = bar_with_masked_imag.imag == 0
bar_with_masked_imag.mask = ~candidates

dists = np.sum(bar_with_masked_imag**2, axis=1)
mindists = np.argmin(dists, axis=1)
foo_indices = np.arange(foo.shape[0])
min_bar = np.array(
    bar_with_masked_imag[foo_indices,mindists,:], 
    dtype=float
)

print(min_bar)
#[[2. 3. 4.]
# [4. 5. 7.]
# [0. 0. 0.]]

但是,tensorflow 没有屏蔽数组等。我如何将其转换为 tensorflow ?

最佳答案

这是一种方法:

import tensorflow as tf
import math

def solution_tf(foo, bar):
    foo = tf.convert_to_tensor(foo)
    bar = tf.convert_to_tensor(bar)
    # Get real and imaginary parts
    bar_r = tf.cast(tf.real(bar), foo.dtype)
    bar_i = tf.imag(bar)
    # Mask of all real-valued points
    m = tf.reduce_all(tf.equal(bar_i, 0), axis=-1)
    # Distance to every corresponding point
    d = tf.reduce_sum(tf.squared_difference(tf.expand_dims(foo, 1), bar_r), axis=-1)
    # Replace distances of complex points with infinity
    d2 = tf.where(m, d, tf.fill(tf.shape(d), tf.constant(math.inf, d.dtype)))
    # Find smallest distances
    idx = tf.argmin(d2, axis=1)
    # Get points with smallest distances
    b = tf.range(tf.shape(foo, out_type=idx.dtype)[0])
    return tf.gather_nd(bar_r, tf.stack([b, idx], axis=1))

# Test
with tf.Graph().as_default(), tf.Session() as sess:
    foo = tf.constant([
        [1,2,3],
        [4,5,6],
        [7,8,9]], dtype=tf.float32)
    bar = tf.constant([
        [[2,3,4],[1+0.1j,2+0.1j,3+0.1j]],
        [[6,5,4],[4,5,7]],
        [[1j,1j,1j],[0,0,0]]], dtype=tf.complex64)
    sol_tf = solution_tf(foo, bar)
    print(sess.run(sol_tf))
    # [[2. 3. 4.]
    #  [4. 5. 7.]
    #  [0. 0. 0.]]

关于python - tensorflow 找到到真实点的最小距离,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57891175/

相关文章:

python - urllib : get utf-8 encoded site source code

python - Keras : AttributeError: 'int' object has no attribute 'ndim' when using model. 适合

machine-learning - 为什么在模式匹配(神经网络)中的手写脚本匹配过程中图像会变平

python - 我尝试在 TensorBoard 中使用 tf.summary.audio 打印音频,显示任何音频

python - 在python中使用类型提示注释路径的正确方法是什么?

Python Scrapy - 需要动态 HTML、div 和 span 内容

python - 如何获得我的暴力破解的百分比?

python - 从 Numpy 数组中删除列的有效方法?

python - 如何找到旋转图像边界框的新坐标以修改其xml文件以进行Tensorflow数据增强?

tensorflow - 关于使用 Keras 在 VGG16 中构建第一个输入层