python - 如何在 tensorflow 数据集中加载 numpy 数组

标签 python tensorflow tensorflow-datasets

我正在尝试从 numpy 数组开始在tensorflow 1.14中创建一个数据集对象(我有一些无法更改此特定项目的遗留代码),但每次我尝试时,我都会将所有内容复制到我的图表上并用于这就是为什么当我创建事件日志文件时它很大(在本例中为 719 MB)。

最初我尝试使用这个函数“tf.data.Dataset.from_tensor_slices()”,但它不起作用,然后我读到这是一个常见问题,有人建议我尝试使用生成器,因此我尝试使用下面的代码,但我再次得到了一个巨大的事件文件(再次为 719 MB)

def fetch_batch(x, y, batch):
    i = 0
    while i < batch:
        yield (x[i,:,:,:], y[i])
        i +=1

train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train  
images = images/255

training_dataset = tf.data.Dataset.from_generator(fetch_batch, 
    args=[images, np.int32(labels), batch_size], output_types=(tf.float32, tf.int32), 
    output_shapes=(tf.TensorShape(features_shape), tf.TensorShape(labels_shape)))

file_writer = tf.summary.FileWriter("/content", graph=tf.get_default_graph())

我知道在这种情况下我可以使用tensorflow_datasets API,它会更容易,但这是一个更普遍的问题,它涉及如何创建一般数据集,而不仅仅是使用mnist数据集。 你能向我解释一下我做错了什么吗?谢谢您

最佳答案

我猜这是因为您在 from_generator 中使用了 args。这肯定会将提供的 args 放入图表中。

您可以做的是定义一个函数,该函数将返回一个生成器,该生成器将迭代您的集合,例如(尚未测试):

def data_generator(images, labels):
  def fetch_examples():
    i = 0
    while True:
      example = (images[i], labels[i])
      i += 1
      i %= len(labels)
      yield example
  return fetch_examples

这将在您的示例中给出:

train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train  
images = images/255

training_dataset = tf.data.Dataset.from_generator(data_generator(images, labels), output_types=(tf.float32, tf.int32), 
    output_shapes=(tf.TensorShape(features_shape), tf.TensorShape(labels_shape))).batch(batch_size)

file_writer = tf.summary.FileWriter("/content", graph=tf.get_default_graph())

请注意,我将 fetch_batch 更改为 fetch_examples,因为您可能希望使用数据集实用程序 (.batch) 进行批处理。

关于python - 如何在 tensorflow 数据集中加载 numpy 数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59020136/

相关文章:

python - 如何在 Tensorflow 中读取二进制文件

python - 在Python中填充聚合列

python - 为什么 skimage.io.imsave 和 opencv.VideoWriter 之间的颜色不同

python - Keras 编写一个接受图像的循环层

c++ - 编译用于ARM64-v8a的Tensorflow C++ API

python - 类型错误 : fit_generator() got an unexpected keyword argument 'nb_val_samples'

python - TensorFlow 中的嵌套结构是什么?

python - tf.data : Parallelize loading step

python - Tensorflow:连接多个 tf.Dataset 非常慢

python - TensorFlow:如何以及为何使用 SavedModel