python - 连接 numpy 数组时出现值错误

标签 python numpy keras concatenation generative-adversarial-network

我正在加载 mnist 数据集,如下所示,

(X_train, y_train), (X_test, y_test) = mnist.load_data()

但是,由于我需要加载和训练自己的数据集,因此我编写了如下小脚本,它将给出准确的训练和测试值

def load_train(path):
X_train = []
y_train = []
print('Read train images')
for j in range(10):
    files = glob(path + "*.jpeg")
    for fl in files:
        img = get_im(fl)
        print(fl)
        X_train.append(img)
        y_train.append(j)

return np.asarray(X_train), np.asarray(y_train)

相关模型在训练时生成大小为 (64, 28, 28, 1) 的 numpy 数组。我按如下方式连接生成图像中的 image_batch,

    X = np.concatenate((image_batch, generated_images))

但是我收到以下错误,

ValueError: all the input arrays must have same number of dimensions

img_batch 的大小为 (64, 28, 28) generated_images 的大小为 (64, 28, 28, 1)

如何扩展 X_train 中 img_batch 的维度以便与 generated_images 连接?或者是否有其他方法来加载自定义图像来代替 loadmnist?

最佳答案

Python中有一个名为np.expand_dims()的函数,它可以沿着参数中提供的轴扩展任何数组的维度。在您的情况下,请使用 img_batch = np.expand_dims(img_batch, axis=3)

另一种方法是使用@Ioannis Nasios 建议的reshape 函数。 img_batch = img_batch.reshape(64,28,28,1)

关于python - 连接 numpy 数组时出现值错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52113872/

相关文章:

python - 如何使用函数中的参数来定义要在函数中使用的数据框的名称?

python - Tweepy 速率错误

python - time.sleep -- 休眠线程或进程?

python - 单维 tensorflow 占位符

python - 从截断的正态分布中抽样

python - 将一对二的 x-y 数据分为顶部和底部集

python - Keras Multilabel Multiclass 单个标签准确率

python - Lambda层: An op outside of the function building code is being passed a “Graph” tensor

python - 根据其他数组从 numpy 数组中选择值

tensorflow - 训练损失在 12 个时期后增加