numpy - Tensorflow/Numpy 中 torch.nn.functional.grid_sample 的等价物是什么?

标签 numpy tensorflow keras deep-learning pytorch

我是 pytorch 的新手,一直在尝试转换一些代码。找不到此特定功能。它存在于 tensorflow 中吗?

最佳答案

我认为 TensorFlow 中没有提供类似的东西。这是 2D 情况的可能实现(我没有考虑填充,但代码的行为应该类似于 border 模式)。请注意,与 PyTorch 版本不同,我假设输入维度顺序是 (batch_size, height, width, channels) (在 TensorFlow 中很常见)。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def grid_sample_2d(inp, grid):
    in_shape = tf.shape(inp)
    in_h = in_shape[1]
    in_w = in_shape[2]

    # Find interpolation sides
    i, j = grid[..., 0], grid[..., 1]
    i = tf.cast(in_h - 1, grid.dtype) * (i + 1) / 2
    j = tf.cast(in_w - 1, grid.dtype) * (j + 1) / 2
    i_1 = tf.maximum(tf.cast(tf.floor(i), tf.int32), 0)
    i_2 = tf.minimum(i_1 + 1, in_h - 1)
    j_1 = tf.maximum(tf.cast(tf.floor(j), tf.int32), 0)
    j_2 = tf.minimum(j_1 + 1, in_w - 1)

    # Gather pixel values
    n_idx = tf.tile(tf.range(in_shape[0])[:, tf.newaxis, tf.newaxis], tf.concat([[1], tf.shape(i)[1:]], axis=0))
    q_11 = tf.gather_nd(inp, tf.stack([n_idx, i_1, j_1], axis=-1))
    q_12 = tf.gather_nd(inp, tf.stack([n_idx, i_1, j_2], axis=-1))
    q_21 = tf.gather_nd(inp, tf.stack([n_idx, i_2, j_1], axis=-1))
    q_22 = tf.gather_nd(inp, tf.stack([n_idx, i_2, j_2], axis=-1))

    # Interpolation coefficients
    di = tf.cast(i, inp.dtype) - tf.cast(i_1, inp.dtype)
    di = tf.expand_dims(di, -1)
    dj = tf.cast(j, inp.dtype) - tf.cast(j_1, inp.dtype)
    dj = tf.expand_dims(dj, -1)

    # Compute interpolations
    q_i1 = q_11 * (1 - di) + q_21 * di
    q_i2 = q_12 * (1 - di) + q_22 * di
    q_ij = q_i1 * (1 - dj) + q_i2 * dj

    return q_ij

# Test it
inp = tf.placeholder(tf.float32, [None, None, None, None])
grid = tf.placeholder(tf.float32, [None, None, None, 2])
res = grid_sample_2d(inp, grid)
with tf.Session() as sess:
    # Make test image
    im_grid_i, im_grid_j = np.meshgrid(np.arange(6), np.arange(10), indexing='ij')
    im = im_grid_i + im_grid_j
    im = im / im.max()
    im = np.stack([im] * 3, axis=-1)
    # Test grid 1: complete image
    grid1 = np.stack(np.meshgrid(np.linspace(-1, 1, 15), np.linspace(-1, 1, 18), indexing='ij'), axis=-1)
    # Test grid 2: lower right corner
    grid2 = np.stack(np.meshgrid(np.linspace(0, 1, 15), np.linspace(.5, 1, 18), indexing='ij'), axis=-1)
    # Run
    res1, res2 = sess.run(res, feed_dict={inp: [im, im], grid: [grid1, grid2]})
    # Plot image and sampled grids
    plt.figure()
    plt.imshow(im)
    plt.figure()
    plt.imshow(res1)
    plt.figure()
    plt.imshow(res2)

这是生成的图像,首先是输入:

Input image

第一个网格结果,这是第一个图像,但形状不同:

First grid result

第二个网格结果,跨越右下角的一个区域:

Second grid result

关于numpy - Tensorflow/Numpy 中 torch.nn.functional.grid_sample 的等价物是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52888146/

相关文章:

python-3.x - 由于负索引查找,Keras 嵌入层中的 InvalidArgumentError

python - 用于收集 numpy 数组的高效查找表

python - np.concatenate 的内存错误

python - 使用 Keras 和 Hyperas 进行参数调整

python - 在 Tensorflow 中实现 Keras 模型时出现的问题

python - 如何在 TF2 中将 ImageDataGenerator 与 TensorFlow 数据集结合起来?

keras - fit_generator中的 “samples_per_epoch”和 “steps_per_epoch”有什么区别

python - 属性错误: 'numpy.ndarray' object has no attribute 'getA1'

python - Pandas 数据框的随机抽样(行和列)

python - 如何对二维 numpy 数组中的向量 [u,v] 进行阈值处理?