python - Tensorflow:从大于 2 GB 的 numpy 数组创建小批量

标签 python tensorflow training-data tensorflow-datasets mini-batch

我正在尝试将 numpy 数组的小批量输入到我的模型中,但我受困于批处理。使用“tf.train.shuffle_batch”会引发错误,因为“图像”数组大于 2 GB。我试图绕过它并创建占位符,但是当我尝试提供数组时,它们仍然由 tf.Tensor 对象表示。我主要担心的是我在模型类下定义了操作,并且在运行 session 之前不会调用对象。有谁知道如何处理这个问题?

def main(mode, steps):
  config = Configuration(mode, steps)



  if config.TRAIN_MODE:

      images, labels = read_data(config.simID)

      assert images.shape[0] == labels.shape[0]

      images_placeholder = tf.placeholder(images.dtype,
                                                images.shape)
      labels_placeholder = tf.placeholder(labels.dtype,
                                                labels.shape)

      dataset = tf.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))

      # shuffle
      dataset = dataset.shuffle(buffer_size=1000)

      # batch
      dataset = dataset.batch(batch_size=config.batch_size)

      iterator = dataset.make_initializable_iterator()

      image, label = iterator.get_next()

      model = Model(config, image, label)

      with tf.Session() as sess:

          sess.run(tf.global_variables_initializer())

          sess.run(iterator.initializer, 
                   feed_dict={images_placeholder: images,
                          labels_placeholder: labels})

          # ...

          for step in xrange(steps):

              sess.run(model.optimize)

最佳答案

您正在使用 initializable iterator tf.Data 将数据提供给您的模型。这意味着您可以根据占位符对数据集进行参数化,然后为迭代器调用初始化操作以准备使用。

如果您使用可初始化迭代器或 tf.Data 中的任何其他迭代器将输入提供给您的模型,则不应使用 feed_dict 参数>sess.run 尝试进行数据馈送。相反,根据 iterator.get_next() 的输出定义您的模型,并从 sess.run 中省略 feed_dict

沿着这些线的东西:

iterator = dataset.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()

# use get_next outputs to define model
model = Model(config, image_batch, label_batch) 

# placeholders fed in while initializing the iterator
sess.run(iterator.initializer, 
            feed_dict={images_placeholder: images,
                       labels_placeholder: labels})

for step in xrange(steps):
     # iterator will feed image and label in the background
     sess.run(model.optimize) 

迭代器将在后台向您的模型提供数据,不需要通过 feed_dict 额外提供数据。

关于python - Tensorflow:从大于 2 GB 的 numpy 数组创建小批量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49053569/

相关文章:

python - flask-sqlalchemy 表编辑未显示

python - 损失: nan When build a model for bike sharing

python - TfLite LSTM 模型

image - Tensorflow、train_step 馈送不正确

python - IBM Watson nl-c 训练时间

python - 从 netCDF 文件读取数据时 Missing_value 属性丢失?

python - 在 Pandas 数据框中水平填充单元格值

Python 将正确格式化为数组的用户输入转换为 int 数组

tensorflow - 语义图像分割神经网络 (DeepLabV3+) 的内存过多问题

machine-learning - 如果我没有标记数据来训练我的模型,这就是无监督学习吗