python - 当在 keras 中训练时 a 的形状和权重不同时必须指定轴

标签 python tensorflow keras deep-learning

我正在尝试训练一个网络,其中我使用两个生成器,一个用于训练,一个用于验证。这些只是针对不可抗拒地产生样本的函数。

我在验证结束时收到以下错误:

File "/home/ubuntu/tensorflow/lib/python3.5/site-packages/numpy/lib/function_base.py", 
line 1142, in average "Axis must be specified when shapes of a and weights "

我查看了代码,keras.engine 中的函数 training_generator 包含以下行

averages.append(np.average([out[i] for out in outs_per_batch], weights=batch_sizes))

查看np.average 的定义,当权重和数组的长度不同时,该函数需要axis。我调试了代码,并通过将 axis=0np.squeeze 放在 out[i] 上,它“””“有效” “”,仅在收集验证摘要统计信息后停止几行。我一直在想我的代码中其他地方有错误。

这是我的发电机

def batch_generator(batch_size, folder):
    files = listdir(folder)
    print("Folder " + folder + " with " + str(len(files)) + " files.")
    np.random.shuffle(files)
    while True:
        np.random.shuffle(files)
        for i in range(batch_size, len(files), batch_size):
            batch = files[(i-batch_size):(i)]
            batch = tensor_generator(folder, files=batch)
            yield (batch, batch)

def tensor_generator(folder, files=None):
    if files is None:
        files = listdir(folder)
    verbose = len(files)>100
    if verbose:
        pbar = tqdm(total=len(files), unit='img')
    tensor = []
    for f in files:
        f = SimpleITK.ReadImage(join(folder, f))
        f = SimpleITK.GetArrayFromImage(f)
        f = (f + 1000)/4000
        tensor.append(f)
        if verbose: pbar.update(1)
    if verbose: pbar.close()
    return np.stack(tensor, axis=0)

这是拟合函数

    self.autoencoder.fit_generator(
            generator=x_train,
            steps_per_epoch=iters,
            epochs=epochs,
            callbacks=[log, rop],
            validation_data=x_test,
            validation_steps=10)

知道哪里出了问题吗?

最佳答案

我遇到了同样的问题。尽管我不知道是什么导致了这个问题,但我已经解决了这个奇怪的问题。

您只需要将代码validation_data=x_test 更改为validation_data=next(x_test)。这意味着您只需要在验证数据生成器上添加 next()

关于python - 当在 keras 中训练时 a 的形状和权重不同时必须指定轴,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51408162/

相关文章:

python - 在 Python 的交互式 shell 中处理没有 try block 的异常

python - 使用keras的句子相似度

python - 如何根据权重/偏差重现 Keras 模型?

python - 将 CSV 文件转换为 TF 记录

keras - 使用 Keras train_on_batch 时将直方图摘要添加到张量板

python - 创建日志记录处理程序以连接到 Oracle?

python - Kivy 未检测到目录中的文件

python - Django:如何过滤某个科室的患者?

python - 找不到 Bazel 构建包

python - Tensorflow:将变量嵌入矩阵并求解