python - 在 Keras 中保存模型时引发 'Unable to create group (name already exists)' 错误

标签 python tensorflow keras deep-learning

我正在使用来自 keras.applications 的 ResNet50 和 DenseNet121 构建模型融合,但在保存模型时引发错误。
如果我只使用 ResNet50 和 DenseNet121 的一个网络,例如 DenseNet only,没问题

与 ResNet50 和 DenseNet121 的融合:

img_input = Input(shape=input_shape)

densenet = app.DenseNet121(
    include_top=False,
    input_tensor=img_input,
    input_shape=input_shape,
    weights=base_weights)
resnet = app.ResNet50(
    include_top=False,
    input_tensor=img_input,
    input_shape=input_shape,
    weights=base_weights)

x1 = densenet.output
x1 = GlobalAveragePooling2D(name='dn_gap_last')(x1)
# then x1.shape is (batch, 1024)

x2 = resnet.output
x2 = Flatten()(x2)  # then x2.shape is (batch, 2048)

x = concatenate([x1, x2], axis=-1)
predictions = Dense(len(class_names), activation="sigmoid", name="predictions")(x)
model = Model(inputs=img_input, outputs=predictions)

并通过 ModelCheckpoint 保存模型

checkpoint = ModelCheckpoint(
                 output_weights_path,
                 save_weights_only=True,
                 save_best_only=True,
                 verbose=1,
            )

但在保存 mdoel 时引发错误
Epoch 00001: val_loss improved from inf to 0.72018, saving model to ./experiments/8/weights.h5
Traceback (most recent call last):
  File "train.py", line 229, in <module>
    main()
  File "train.py", line 212, in main
    shuffle=False,
  File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/keras/engine/training.py", line 2280, in fit_generator
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/keras/callbacks.py", line 77, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/keras/callbacks.py", line 445, in on_epoch_end
    self.model.save_weights(filepath, overwrite=True)
  File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/keras/engine/topology.py", line 2607, in save_weights
    save_weights_to_hdf5_group(f, self.layers)
  File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/keras/engine/topology.py", line 2878, in save_weights_to_hdf5_group
    g = f.create_group(layer.name)
  File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/h5py/_hl/group.py", line 50, in create_group
    gid = h5g.create(self.id, name, lcpl=lcpl)
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py/h5g.pyx", line 151, in h5py.h5g.create
ValueError: Unable to create group (name already exists)

最佳答案

如果我需要像你的情况一样使用像 tf.tile 这样的操作,我会用一个 lambda 层调用它。所以有效的代码如下

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras import Model

def my_fun(a):
  out = tf.tile(a, (1, tf.shape(a)[0]))
  return out

a = Input(shape=(10,))
#out = tf.tile(a, (1, tf.shape(a)[0]))
out = Lambda(lambda x : my_fun(x))(a)
model = Model(a, out)

x = np.zeros((50,10), dtype=np.float32)
print(model(x).numpy())

model.save('my_model.h5')

#load the model
new_model=tf.keras.models.load_model("my_model.h5")
任何遇到类似问题的人,请关注 GitHub issue与此问题相关的最终解决方案。谢谢!
2020/02/04 编辑
通过最近的代码修改,您可以使用 tf-nightly以“h5”格式保存模型,没有任何问题,如下所示。
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras import Model

a = Input(shape=(10,))
out = tf.tile(a, (1, tf.shape(a)[0]))
model = Model(a, out)

x = np.zeros((50,10), dtype=np.float32)
print(model(x).numpy())

model.save('./my_model', save_format='h5')

关于python - 在 Keras 中保存模型时引发 'Unable to create group (name already exists)' 错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55555311/

相关文章:

Python:循环遍历目录并使用文件名作为数据框名称保存每个文件

python matplotlib 图例仅显示列表的第一个条目

tensorflow - TensorFlow的 `conv2d_transpose()`操作是做什么的?

tensorflow - 如何在 tf contrib estimator 中使用 GPU

python - Keras 模型的矩阵大小错误

python - 在卷积输出外部添加另一层

python - future 警告 : Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated

python - 菜单中的图像和标签

python - 如何制作自定义的 TensorFlow tf.nn.conv2d()?

python - 变分自动编码器 : InvalidArgumentError: Incompatible shapes: [100, 5] 与 [100]