python - 无法在 TensorFlow 中转换部分转换的张量

标签 python tensorflow

TensorFlow 中有很多方法需要指定形状,例如 truncated_normal:

tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)

我有一个用于输入形状 [None, 784] 的占位符,其中第一个维度是 None,因为批量大小可能会有所不同。我可以使用固定的批量大小,但它仍然与测试/验证集大小不同。

我无法将此占位符提供给 tf.truncated_normal 因为它需要完全指定的张量形状。让 tf.truncated_normal 接受不同张量形状的简单方法是什么?

最佳答案

您只需要将其作为单个示例输入,但以批处理的形式输入。所以这意味着给形状增加一个额外的维度,例如

batch_size = 32 # set this to the actual size of your batch
tf.truncated_normal((batch_size, 784), mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)

这样,它将“适合”占位符。

如果您希望 batch_size 改变,您也可以使用:
tf.truncated_normal(tf.shape(input_tensor), mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)

input_tensor 可以是占位符,也可以是任何张量将添加此噪声的地方。

关于python - 无法在 TensorFlow 中转换部分转换的张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35289773/

相关文章:

python : run multiple queries in parallel and get the first finished

python - 绘制直方图箱的箱线图以进行直方图比较

tensorflow - 如何在 TensorFlow 2.0 中调试 Keras?

python - Tensorflow:如何保存变量并将它们加载到不同的变量?

python - 用图形可视化带有字符串的 numpy bool 数组

Python人类可读的日期差异

python - 如何将张量视为图像?

tensorflow - 多头注意层 - Keras 中的包装多头层是什么?

gpu - cublas 的tensorflow运行错误

python - 无法在 Google Cloud SQL 上插入数据,但 UID 字段仍在自动递增