python - TensorFlow 整个数据集存储在图中

标签 python tensorflow

我正在致力于使用 Cifar-10 数据集开发 CNN,并将数据提供给网络,我正在使用数据集 API 来使用带有句柄占位符的可输入迭代器:https://www.tensorflow.org/programmers_guide/datasets#creating_an_iterator 。就我个人而言,我真的很喜欢这种方法,因为它提供了一种清晰而简单的方法来将数据馈送到网络并在测试集和验证集之间切换。但是,当我在训练结束时保存图表时,创建的 .meta 文件与我开始时的测试数据一样大。我使用这些操作来提供对输入占位符和输出运算符的访问:

tf.get_collection("validation_nodes")
tf.add_to_collection("validation_nodes", input_data)
tf.add_to_collection("validation_nodes", input_labels)
tf.add_to_collection("validation_nodes", predict)

然后使用以下命令保存图表: 训练前:

saver = tf.train.Saver()

训练后:

save_path = saver.save(sess, "./my_model")

有没有办法阻止 TensorFlow 存储图中的所有数据?提前致谢!

最佳答案

您正在为数据集创建一个tf.constant,这就是将其添加到图形定义中的原因。解决方案是使用可初始化迭代器并定义占位符。在开始对图表运行操作之前要做的第一件事就是向其提供数据集。有关示例,请参阅“创建迭代器”部分下的程序员指南。

https://www.tensorflow.org/programmers_guide/datasets#creating_an_iterator

我做的完全一样,所以这里是我用来准确实现您的描述的代码相关部分的复制/粘贴(使用可初始化迭代器训练/测试 cifar10 集):

  def build_datasets(self):
    """ Creates a train_iterator and test_iterator from the two datasets. """
    self.imgs_4d_uint8_placeholder = tf.placeholder(tf.uint8, [None, 32, 32, 3], 'load_images_placeholder')
    self.imgs_4d_float32_placeholder = tf.placeholder(tf.float32, [None, 32, 32, 3], 'load_images_float32_placeholder')
    self.labels_1d_uint8_placeholder = tf.placeholder(tf.uint8, [None], 'load_labels_placeholder')
    self.load_data_train = tf.data.Dataset.from_tensor_slices({
      'data': self.imgs_4d_uint8_placeholder,
      'labels': self.labels_1d_uint8_placeholder
    })
    self.load_data_test = tf.data.Dataset.from_tensor_slices({
      'data': self.imgs_4d_uint8_placeholder,
      'labels': self.labels_1d_uint8_placeholder
    })
    self.load_data_adversarial = tf.data.Dataset.from_tensor_slices({
      'data': self.imgs_4d_float32_placeholder,
      'labels': self.labels_1d_uint8_placeholder
    })

    # Train dataset pipeline
    dataset_train = self.load_data_train
    dataset_train = dataset_train.shuffle(buffer_size=50000)
    dataset_train = dataset_train.repeat()
    dataset_train = dataset_train.map(self._img_augmentation, num_parallel_calls=8)
    dataset_train = dataset_train.map(self._img_preprocessing, num_parallel_calls=8)
    dataset_train = dataset_train.batch(self.hyperparams['batch_size'])
    dataset_train = dataset_train.prefetch(2)
    self.iterator_train = dataset_train.make_initializable_iterator()

    # Test dataset pipeline
    dataset_test = self.load_data_test
    dataset_test = dataset_test.map(self._img_preprocessing, num_parallel_calls=8)
    dataset_test = dataset_test.batch(self.hyperparams['batch_size'])
    self.iterator_test = dataset_test.make_initializable_iterator()



  def init(self, sess):
    self.cifar10 = Cifar10()    # a class I wrote for loading cifar10
    self.handle_train = sess.run(self.iterator_train.string_handle())
    self.handle_test = sess.run(self.iterator_test.string_handle())
    sess.run(self.iterator_train.initializer, feed_dict={self.handle: self.handle_train,
                                                         self.imgs_4d_uint8_placeholder: self.cifar10.train_data,
                                                         self.labels_1d_uint8_placeholder: self.cifar10.train_labels})

关于python - TensorFlow 整个数据集存储在图中,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49903653/

相关文章:

python - 值错误 ("Tensor %s is not an element of this graph."% obj)

python - 我需要将 Pandas df 转换为具有制表符分隔分隔和多行的文本字符串

python - python 中的 matplotlib 条件背景色

python - 如何将 CHATBOT 未回答的问题存储在文本文件中

tensorflow - 如何训练反向嵌入,如 vec2word?

python - 可以为分布式 Tensorflow 虚拟化 NVIDIA GeForce GTX 1070 显卡吗?

machine-learning - 如何加载检查点文件并使用略有不同的图形结构继续训练

Python最准确测量时间(毫秒)的方法

python - MySQL Python 错误 1064,MySQL 语法错误

tensorflow - 矩阵机器学习的 MSE 损失