python - 一个问题以及如何在创建模型时处理批处理

标签 python tensorflow keras keras-layer tf.keras

enter image description here

from keras_multi_head import MultiHeadAttention
import keras
from keras.layers import Dense,Input,Multiply
from keras import backend as K
from keras.layers.core import Dropout, Layer
from keras.models import Sequential,Model
import numpy as np
import tensorflow as tf
from self_attention_layer import Encoder



## multi source attention
class Multi_source_attention(keras.Model):

    def __init__(self,read_n,embed_dim,num_heads,ff_dim,num_layers):
        super().__init__()
        self.read_n = read_n
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.num_layers = num_layers
        self.get_weights = Dense(49, activation = 'relu',name = "get_weights")
    
        
    def compute_output_shape(self,input_shape):
        #([batch,7,7,256],[1,256])
        return input_shape


    def call(self,inputs):
        ## weights matrix

        #(1,49)
        weights_res = self.get_weights(inputs[1])
        #(1,7,7)
        weights = tf.reshape(weights_res,(1,7,7))
        #(256,7,7)
        weights = tf.tile(weights,[256,1,1])
      
        ## img from mobilenet
        img=tf.reshape(inputs[0],[-1,7,7])


        
        inter_res = tf.multiply(img,weights)
        inter_res = tf.reshape(inter_res, (-1,256,49))
        print(inter_res.shape)
        att = Encoder(self.embed_dim,self.num_heads,self.ff_dim,self.num_layers)(inter_res)

        return att

我尝试构建一个网络来实现图中圈出的部分。 LSTM **(1,256) 和之前的 Mobilenet (batch,7,7,256) 的输出。然后将 LSTM 的输出转换为 (7,7) 形式的权重矩阵。

但问题是 mobilenet 输出的输入形状有一个属性 batch。我不知道如何处理 “批处理” 或如何设置参数来限制批处理?

有人可以给我提示吗?

如果我删除函数 compute_output_shape(),就会出现一个错误 unimplementerror。 keras 官方文档告诉我不需要覆盖该函数。 有人可以解释一下吗?

最佳答案

Compute_output_shape 对于自定义层至关重要。如果调用函数 summary() ,则会生成相应的图形,其中显示每一层的输入和输出形状。 compute_output_shape 负责输出形状。

关于python - 一个问题以及如何在创建模型时处理批处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67170370/

相关文章:

python canvas在循环中创建图像

python - 使用数字后缀对不同长度的字符串进行排序

python - 训练损失极低,预测值始终相同

python - 使用 Keras 训练时的 Tensorflow InvalidArgumentError(索引)

c# - 将基于 Python 的 TensorFlow 集成到 .NET 应用程序中

python - Keras 模型训练在一段时间后会占用更多时间

python - re.sub 试图转义 repl 字符串?

python - 子图中的分页符? Matplotlib 多页子图

python - 在 keras 中制作自定义损失函数

tensorflow - 在自己的损失函数中计算预测导数