python - 张量维度的大小范围 - tf.range

标签 python tensorflow neural-network deep-learning

我正在尝试为我正在实现的神经网络定义一个操作,但为此我需要迭代张量的维度。我在下面有一个小的工作示例。

X = tf.placeholder(tf.float32, shape=[None, 10])
idx = [[i] for i in tf.range(X.get_shape()[0])]

这会产生一个错误说明

ValueError: Cannot convert an unknown Dimension to a Tensor: ?

当使用相同的代码但使用 tf.shape 代替时,导致代码为

X = tf.placeholder(tf.float32, shape=[None, 10])
idx = [[i] for i in tf.range(tf.shape(X)[0])]

出现以下错误

TypeError: 'Tensor' object is not iterable.

我实现此 NN 的方式是,batch_size 直到代码末尾的训练函数才定义。这正是我构建图表本身的地方,所以此时 batch_size 是未知的,它不能固定为训练 batch_size 和测试集 batch_sizes 不同。

解决此问题的最佳方法是什么?这是阻止我的代码运行的最后一件事,因为我让它以固定的 batch_size 运行,尽管这些结果没有用。数周以来,我一直在研究 TensorFlow API 文档和堆栈溢出,但无济于事。

我还尝试在范围内输入占位符,所以当我运行测试/训练集时,代码如下

X = tf.placeholder(tf.float32, shape=[None, 10])
bs = tf.placeholder(tf.int32)

def My_Function(X):
    # Do some stuff to X
    idx = [[i] for i in tf.range(bs)]
    # return some tensor

A = tf.nn.relu(My_Function(X))

然而,这给出了与上面相同的错误

TypeError: 'Tensor' object is not iterable.

最佳答案

我认为您应该改用 tf.shape(x)。

x = tf.placeholder(..., shape=[None, ...])
batch_size = tf.shape(x)[0]  # Returns a scalar `tf.Tensor`

print x.get_shape()[0]  # ==> "?"

# You can use `batch_size` as an argument to other operators.
some_other_tensor = ...
some_other_tensor_reshaped = tf.reshape(some_other_tensor, [batch_size, 32, 32])

# To get the value, however, you need to call `Session.run()`.
sess = tf.Session()
x_val = np.random.rand(37, 100, 100)
batch_size_val = sess.run(batch_size, {x: x_val})
print x_val  # ==> "37"

参见:get the size of a variable batch dimension

关于python - 张量维度的大小范围 - tf.range,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43126116/

相关文章:

python - 仅匹配域名中最左边的通配符 - Python

python - 为什么我的代码(与 CUDA 链接)偶尔会导致 Python 中的段错误?

tensorflow - 使用 Google ML 引擎和 Google Storage 存储大量图像进行训练的最佳实践

python - CUDNN_STATUS_ALLOC_FAILED 导致 Tensorflow 崩溃

python - 获取神经网络的预测列表

machine-learning - 无法使用神经网络近似正弦函数

python - 我在 Django 中的 urls.py 中的 url 模式不匹配

python - pip 下载 + 为什么 pip 不下载最新版本

slice - 如何在 TensorFlow 中对 4 阶张量进行切片?

python - 了解 Tensorflow 对象检测 API,检查点类的 kwargs,什么是 `_base_tower_layers_for_heads` ?