python - for循环内的神经网络

标签 python tensorflow keras

我有这样的东西

for q in range(10):
   # generate some samples
   x = Input(batch_shape=(n_batch, xx.shape[1]))
   x = Dense(20)(x)
   x = LeakyReLU(alpha=0.001)(x)
   y = Dense(1)(x)
   y = LeakyReLU(alpha=0.001)(y)
   model = Model(inputs=x, outputs=y) 
   model.compile(loss='mean_squared_error', optimizer='Adam', metrics=['accuracy'])
   for i in range(10):
      model.fit(x, y, epochs=1, batch_size=n_batch, verbose=0, shuffle=False)
      model.reset_states()

我想知道神经网络是否是为每个 q 从头开始​​构建的,还是保留了前一个 q 的所有内容?如果它保留,我如何为每个 q 分别重置和构建、编译和拟合神经网络?

最佳答案

当你用keras或tensorflow创建一个层时,tensorflow会在它的图中添加一个或多个节点,每次你添加优化器、损失函数或激活函数时,它都会做同样的事情并为它们添加一个节点。

当您调用model.fit()时,tensorflow 会从其根开始执行其图。如果您在循环中添加节点,则先前的节点将不会被删除。它们会占用内存空间,并会降低您的性能。

该怎么办?非常简单,重新初始化权重并重新使用相同的节点。您的代码不会发生太大变化,只需使用 for 循环向下移动示例生成并定义一个函数来重新初始化即可。

我还降低了第二个 for 循环,只是将纪元数增加到 10,如果您有理由将其放在那里,则可以将其放回去。

def reset_weights(model):
    session = K.get_session()
    for layer in model.layers: 
        if hasattr(layer, 'kernel_initializer'):
            layer.kernel.initializer.run(session=session)

x = Input(batch_shape=(n_batch, xx.shape[1]))
x = Dense(20)(x)
x = LeakyReLU(alpha=0.001)(x)
y = Dense(1)(x)
y = LeakyReLU(alpha=0.001)(y)
model = Model(inputs=x, outputs=y) 
model.compile(loss='mean_squared_error', optimizer='Adam', metrics=['accuracy'])
for q in range(10):
    #generate some samples
    model.fit(x, y, epochs=10, batch_size=n_batch, verbose=1, shuffle=False)
    model.reset_states()
    reset_weights(model)

关于python - for循环内的神经网络,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53001124/

相关文章:

tensorflow - 如何使用 tensorflow 纠正 keras 的这个自定义损失函数?

python - py3k : Maximum Number In Given List - short form

python - Dask + pyinstaller 失败

python - Keras:batch_size 的类型错误

python - 将 batchnorm(TensorFlow) 的 is_training 变为 False

python - Tensorflow:难以置信的巨大稀疏分类交叉熵

machine-learning - 多类分割的广义骰子损失: keras implementation

python - Django 模板 : How to Pass PK/IDs of Displayed Objects Back to the Views. 模板中的 py 文件?

python - 根据 XML 架构 (xsd) 验证具有大文本元素的 XML

python - 为什么要在 tensorflow 中循环评估测试数据?