python - 将标量输入 Tensorflow 2 模型的正确方法

标签 python tensorflow deep-learning

在我的 Tensorflow 2 模型中,我希望批量大小是参数化的,这样我就可以动态构建具有适当批量大小的张量。我有以下代码:

batch_size_param = 128

tf_batch_size = tf.keras.Input(shape=(), name="tf_batch_size", dtype=tf.int32)
batch_indices = tf.range(0, tf_batch_size, 1)

md = tf.keras.Model(inputs={"tf_batch_size": tf_batch_size}, outputs=[batch_indices])
res = md(inputs={"tf_batch_size": batch_size_param})

代码在 tf.range 中抛出错误:

ValueError: Shape must be rank 0 but is rank 1
 for 'limit' for '{{node Range}} = Range[Tidx=DT_INT32](Range/start, tf_batch_size, Range/delta)' with input shapes: [], [?], []

我认为问题在于 tf.keras.Input 会自动尝试在第一维扩展输入数组,因为它期望输入的部分形状没有批量大小和将根据输入数组的形状附加批量大小,在我的例子中是一个标量。我可以将标量值作为常量整数输入到 tf.range 中,但这一次,在编译模型图后我将无法更改它。

有趣的是,即使我也检查了文档,我也未能找到将标量输入 TF-2 模型的正确方法。那么,处理这种情况的最佳方法是什么?

最佳答案

不要使用 tf.keras.Input,只需通过子类化来定义模型。

import tensorflow as tf


class ScalarModel(tf.keras.Model):
    def __init__(self):
        super().__init__()

    def call(self, x):
        return tf.range(0, x, 1)


print(ScalarModel()(10))
# tf.Tensor([0 1 2 3 4 5 6 7 8 9], shape=(10,), dtype=int32)

关于python - 将标量输入 Tensorflow 2 模型的正确方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66546321/

相关文章:

python - 无法通过python Spark连接MysqlDB

python - 嵌套函数的性能开销是多少?

python - 急切执行: gradient computation

machine-learning - 我无法让我的 tensorflow 梯度下降线性回归算法工作

python - TF.Keras model.predict 比直接 Numpy 慢?

tensorflow - 为什么在训练模型时我的 Keras 损失没有变化?

python - 未检测到 theano g++,但我已成功安装 g++

python - 从字典数组中提取数字数组

python - 如何根据列数据类型过滤数据框

python - 使用元图 tensorflow 运行模型失败