python - 使用许多压缩的 numpy 文件的 Tensorflow 数据集

标签 python numpy tensorflow dataset data-handling

我有一个大型数据集,我想用它来进行 Tensorflow 训练。

数据以压缩的 numpy 格式存储(使用 numpy.savez_compressed)。每个文件的图像数量因其生成方式而异。

目前,我使用基于 Keras 序列的生成器对象进行训练,但我想完全迁移到不使用 Keras 的 Tensorflow。

我正在 TF 网站上查看数据集 API,但如何使用它来读取 numpy 数据并不明显。

我的第一个想法是这个

import glob
import tensorflow as tf
import numpy as np

def get_data_from_filename(filename):
   npdata = np.load(open(filename))
   return npdata['features'],npdata['labels']

# get files
filelist = glob.glob('*.npz')

# create dataset of filenames
ds = tf.data.Dataset.from_tensor_slices(filelist)
ds.flat_map(get_data_from_filename)

但是,这会将 TF 张量占位符传递给真正的 numpy 函数,并且 numpy 需要一个标准字符串。这会导致错误:

File "test.py", line 6, in get_data_from_filename
   npdata = np.load(open(filename))
TypeError: coercing to Unicode: need string or buffer, Tensor found

我正在考虑的另一个选项(但看起来很困惑)是创建一个基于 TF 占位符的 Dataset 对象,然后在我的 epoch-batch 循环期间从我的 numpy 文件中填充该对象。

有什么建议吗?

最佳答案

您可以定义一个包装器并像这样使用 pyfunc:

def get_data_from_filename(filename):
   npdata = np.load(filename)
   return npdata['features'], npdata['labels']

def get_data_wrapper(filename):
   # Assuming here that both your data and label is float type.
   features, labels = tf.py_func(
       get_data_from_filename, [filename], (tf.float32, tf.float32)) 
   return tf.data.Dataset.from_tensor_slices((features, labels))

# Create dataset of filenames.
ds = tf.data.Dataset.from_tensor_slices(filelist)
ds.flat_map(get_data_wrapper)

如果您的数据集非常大并且存在内存问题,您可以考虑使用 interleave 的组合或parallel_interleavefrom_generator方法来代替。 from_generator 方法在内部使用 py_func,因此您可以直接读取 np 文件,然后在 python 中定义生成器。

关于python - 使用许多压缩的 numpy 文件的 Tensorflow 数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53544809/

相关文章:

python - 了解如何为 numpy reshape() 指定新的形状参数

python - 类型错误 : __call__() takes from 1 to 2 positional arguments but 3 were given

python - tensorflow 2.0.0 : AttributeError: 'TensorSliceDataset' object has no attribute 'as_numpy_iterator'

python - 统一成本解决方案中的算法问题

python - 如何获取gzip压缩文件的随机访问

python - 数组中所有点之间的最小欧氏距离数组

Python深度学习: Shape of irregular multidimensional data sets

python - 获取 ZeroDivisionError : float division in python

python - 是否可以知道您是否在 ipython 中?

python - 如何将 numpy.savetxt 与包含数组的结构化数组一起使用