python - 如何使用 tensorflow 应用动态形状 `scatter_nd`

标签 python python-3.x tensorflow

在tensorflow中,如何将动态形状应用于scatter_nd

当我使用具有动态形状的输入张量时,出现以下错误:

ValueError: Cannot convert a partially known TensorShape to a Tensor: (20, ?)

这是我使用的函数。当张量具有静态形状时它会起作用。但对于动态形状(例如 (?, 7)),它会失败。

def tf_zero_pad_columns(tensor, columns_list, num_output_columns):
    assert(tensor.shape.as_list()[1] == len(columns_list))
    assert(num_output_columns >= len(columns_list))

    tensor = tf.transpose(tensor)
    columns = tf.constant(np.array([columns_list]).T.astype('int32'))
    shape=tf.TensorShape((num_output_columns, tensor.get_shape()[1]))
    scattered = tf.scatter_nd(columns, tensor, shape=shape)
    return tf.transpose(scattered)

我还尝试用 -1 替换 tensor.get_shape()[1],但这在训练过程中会产生不同的错误:

InvalidArgumentError: Dimension -1 must be >= 0 [[Node: lambda_40/ScatterNd ....

<小时/>

编辑:

具有动态形状的示例输入(这会重现错误):

tensor = tf.placeholder(tf.float32, shape=(None, 7))
tf_zero_pad_columns(tensor, [11,12,13,4,5,6,7], 20)

具有静态形状的示例输入:

import numpy as np
tensor_np = np.tile(range(7), (4, 1)) + np.array(range(4))[:, None]
tensor = tf.constant(tensor_np)


tf_zero_pad_columns(tensor, [11,12,13,4,5,6,7], 20)

输出是:

array([[0, 0, 0, 0, 3, 4, 5, 6, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 4, 5, 6, 7, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 5, 6, 7, 8, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 6, 7, 8, 9, 0, 0, 0, 3, 4, 5, 0, 0, 0, 0, 0, 0]])

最佳答案

这对我有用:

def tf_zero_pad_columns(tensor, columns_list, num_output_columns):
    assert(tensor.shape.as_list()[1] == len(columns_list))
    assert(num_output_columns >= len(columns_list))

    tensor = tf.transpose(tensor)
    columns = tf.constant(np.array([columns_list]).T.astype('int32'))
    tensor_shape = tf.shape(tensor)[1]
    scattered = tf.scatter_nd(columns, tensor, shape=(num_output_columns, tensor_shape))
    return tf.transpose(scattered)

关于python - 如何使用 tensorflow 应用动态形状 `scatter_nd`,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54505292/

相关文章:

python - 将列表中的奇数和偶数移动到奇数和偶数位置

python - 将矩阵中位置低于 0 的所有元素转换为 0 (Python)

python - python 中多线程中运行的tensorflow 问题。该函数在没有线程的情况下运行良好,但在线程中则不行

python - python 中的 smtplib.server.sendmail 函数引发 UnicodeEncodeError : 'ascii' codec can't encode character

python - Keras 新手 : how to load a pretrained MalConv model to predict in my data?

c - tensorflow C API : How to modify the value in tensor

python - 在同一目录中导入对象时出现问题

python - 我们如何隐藏轴 matplotlib 中的第一个零

python - 使用 turtle 图形的强化学习算法不起作用

javascript - 仅匹配完整数字正则表达式