python - Tensorflow build() 如何从 tf.keras.layers.Layer 工作

标签 python tensorflow keras

我想知道是否有人知道 build()函数从 tf.keras.layers.Layer 开始工作引擎盖下的类。根据documentation :

build is called when you know the shapes of the input tensors and can do the rest of the initialization


所以对我来说,类(class)的行为似乎与此类似:
class MyDenseLayer:
  def __init__(self, num_outputs):
    self.num_outputs = num_outputs

  def build(self, input_shape):
    self.kernel = self.add_weight("kernel",
                                  shape=[int(input_shape[-1]), self.num_outputs])

  def __call__(self, input):
    self.build(input.shape) ## build is called here when input shape is known
    return tf.matmul(input, self.kernel)
我无法想象build()将永远被召唤 __call__ ,但它是唯一传入输入的地方。有谁知道这到底是如何工作的?

最佳答案

Layer.build()方法通常用于实例化层的权重。见source code for tf.keras.layers.Dense 例如,请注意在该函数中创建了权重和偏差张量。 Layer.build()方法采用 input_shape参数,权重和偏差的形状通常取决于输入的形状。Layer.call()另一方面,方法实现了层的前向传递。您不想覆盖 __call__ , 因为这是在基类 tf.keras.layers.Layer 中实现的.在自定义层中,您应该实现 call() .Layer.call()不打电话 Layer.build() .但是,Layer().__call__()如果该层尚未构建( source ),则调用它,这将设置属性 self.built = True防止Layer.build()从再次被调用。换句话说,Layer.__call__()只来电 Layer.build()第一次调用它。

关于python - Tensorflow build() 如何从 tf.keras.layers.Layer 工作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63383594/

相关文章:

python - 支持思科路由器,使用 NAPALM,使用 SSH 远程登录

tensorflow - 当前一层中 mask_zero=True 时,由于连接层,Keras 图像字幕模型无法编译

python - 如何打印输出到 keras model.fit 的数据,特别是使用 petastorm 数据集时

Python:Keras 上的训练和预测回归问题

python - 我可以使用 VIM 在 Python 中查看函数的文档字符串吗?

Python MySQLdb 更新查询失败

python - 使用 beautifulsoup 访问未标记的文本

tensorflow - GPU 未被 TensorFlow 捕获

python - Conda 不使用已安装的包,而是使用外部包

tensorflow - TensorFlow 对象检测配置文件中的 "num_examples: 2000"是什么意思?