csv - 如何在 Keras 中保存编码输出

标签 csv machine-learning tensorflow neural-network keras

我想在具有 Tensorflow 后端的 Keras 自动编码器模型中的解码器阶段之前保存编码部分。

例如;

encoding_dim = 210
input_img = Input(shape=(5184,))

encoded = Dense(2592, activation='relu')(input_img)
encoded1 = Dense(encoding_dim, activation='relu')(encoded)

decoded = Dense(encoding_dim, activation='relu')(encoded1)
decoded = Dense(5184, activation='sigmoid')(decoded)

我想在自动编码器训练后将encoded1保存为csv文件。假设Dense的输出为(nb_samples, output_dim)

谢谢

最佳答案

尝试:

autoencoder = Model(input_img, decoded)
encoder = Model(input_img, encoded1)

autoencoder.compile(loss=my_loss, optimizer=my_optimizer)
autoencoder.fit(my_data, my_data, .. rest of fit params)

numpy.savetxt("encoded1.csv", encoder.predict(x), delimiter=",")

此外 - 我不知道您有什么样的数据,但我建议您使用最后一层的线性激活和 mse 损失函数。

关于csv - 如何在 Keras 中保存编码输出,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42486004/

相关文章:

java - 使用 Groovy 重新映射 CSV 结构

python - 使用 MNIST 训练模型的 Tensorflow 总是打印错误的数字

apache-spark - 如何在 pyspark 中测试/训练按列值而不是按行分割

python - 加载保存的优化器时出错。 keras python 覆盆子

mysql - 使用Cron将数据从mysql数据库导出到CSV,然后将所有数据获取到bigquery表

MySQL csv 导入截断字符

python - 修复 csv 中开头和中间带有双引号的字符串的额外双引号?

python - 如何使用 Scikit Learn 调整随机森林中的参数?

python - 由于使用 "lambda",无法加载保存的 Keras 模型

python-2.7 - 在 Heroku 上使用 python 运行tensorflow时出现导入错误?