python - TF2/Keras 切片张量使用 [ :, :, 0]

标签 python tensorflow keras deep-learning tensorflow2.0

在 TF 2.0 Beta 中我正在尝试:

x = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)
print(x.shape) # (None, 240, 2)
a = x[:, :, 0]
print(a.shape) # <unknown>

在 TF 1.x 中我可以这样做:

x = tf1.placeholder(tf1.float32, (None, 240, 2)
a = x[:, :, 0]

它会工作得很好。如何在 TF 2.0 中实现这一目标?我认为

tf.split(x, 2, axis=2)

可能有用,但我想使用切片而不是硬编码 2(轴 2 的暗淡)。

最佳答案

不同之处在于,从Input返回的对象表示一个层,而不是任何类似于占位符或张量的东西。因此,上面 tf 2.0 代码中的 x 是一个图层对象,而 tf 1.x 代码中的 x 是张量的占位符。

您可以定义一个切片层来执行该操作。有开箱即用的层可用,但对于像这样的简单切片,Lambda 层非常易于阅读,并且可能最接近您在 tf 1.x 中习惯的切片方式.

类似这样的事情:

input_lyr = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)
sliced_lyr = tf.keras.layers.Lambda(lambda x: x[:,:,0])

您可以在 keras 模型中使用它,如下所示:

model = tf.keras.models.Sequential([
    input_lyr,
    sliced_lyr,
    # ...
    # <other layers>
    # ...
])

当然,以上内容特定于 keras 模型。相反,如果您有一个张量而不是 keras 图层对象,则切片的工作方式与以前完全相同。像这样的事情:

my_tensor = tf.random.uniform((8,240,2))
sliced = my_tensor[:,:,0]

print(my_tensor.shape)
print(sliced.shape)

输出:

(8, 240, 2)
(8, 240)

正如预期

关于python - TF2/Keras 切片张量使用 [ :, :, 0],我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57438747/

相关文章:

python - GRU加载模型错误,ValueError : GRU(reset_after=False) is not compatible with GRU(reset_after=True)

python - 如何在 Altair 中创建分组条形图?

python - 将十六进制值转换为 ASCII

python - 跳过 Elif 语句

python - 如何处理 CNN 训练 Keras 的数千张图像

python - 在没有自定义对象的情况下保存完整的 tf.keras 模型?

python - Pytorch 与 Keras : Pytorch model overfits heavily

python - 在 ctype 结构中使用无符号整数时报告无效 JSON

python - 去除 Bazel-ing TensorFlow Serving

python - n 批处理 > 1 的多元正态分布