python - 绘图模型不显示模型层,仅显示模型名称

标签 python python-3.x tensorflow keras tensorflow2.x

我正在尝试使用 TensorFlow2 构建一些模型,因此我创建了一个模型类,如下所示:

import tensorflow as tf

class Dummy(tf.keras.Model):
    def __init__(self, name="dummy"):
        super(Dummy, self).__init__()
        self._name = name

        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        return self.dense2(x)

model = Dummy()
model.build(input_shape=(None,5))

现在我想绘制模型,同时使用 summary() 返回我期望的内容,plot_model(model, show_shapes=True, Expand_nested=True) 仅返回一个带有模型名称的 block 。

如何返回模型的图表?

最佳答案

弗朗索瓦·肖莱 (Francois Chollet) 说道:

You can do all these things (printing input / output shapes) in a Functional or Sequential model because these models are static graphs of layers.

In contrast, a subclassed model is a piece of Python code (a call method). There is no graph of layers here. We cannot know how layers are connected to each other (because that's defined in the body of call, not as an explicit data structure), so we cannot infer input / output shapes.

有两种解决方案:

  1. 您可以按顺序构建模型/使用函数式 API。
  2. 您将“call”函数包装到函数模型中,如下所示:

类子类(模型):

def __init__(self):
    ...
def call(self, x):
    ...

def model(self):
    x = Input(shape=(24, 24, 3))
    return Model(inputs=[x], outputs=self.call(x))


if __name__ == '__main__':
    sub = subclass()
    sub.model().summary()

答案取自此处:model.summary() can't print output shape while using subclass model

此外,这是一篇好文章:https://medium.com/tensorflow/what-are-symbolic-and-imperative-apis-in-tensorflow-2-0-dfccecb01021

关于python - 绘图模型不显示模型层,仅显示模型名称,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60416449/

相关文章:

Python 类型.MethodType()

python - 使用 SqlAlchemy 将数据保存到数据库中,对象不可下标

python - 仅 1 个标量的 Tensorboard 摘要标量错误

python - 如何使用二维列表中的列数据创建新列表

python - 用于 Python 的 Directory Walker

python - 无法从文本文件中打印特定行

tensorflow - 如何在本地部署在 amazon sagemaker 上训练的模型?

multidimensional-array - Tensorflow nn.conv3d() 和 max_pool3d

python - 使用 Python 从 PNG 或 JPG 创建一个在 XP 上运行的 ICO

python argparse 在描述后打印用法文本