python - 使用 tf.data 读取 CSV 文件很慢,改用 tfrecords?

标签 python tensorflow tensorflow-datasets

我有很多 CSV 文件,每条记录包含约 6000 列。第一列是标签,其余列应视为特征向量。我是 Tensorflow 的新手,我不知道如何将数据读入具有所需格式的 Tensorflow Dataset。我目前正在运行以下代码:

DEFAULTS = []
n_features = 6170
for i in range(n_features+1):
  DEFAULTS.append([0.0])

def parse_csv(line):
    # line = line.replace('"', '')
    columns = tf.decode_csv(line, record_defaults=DEFAULTS)  # take a line at a time
    features = {'label': columns[-1], 'x': tf.stack(columns[:-1])}  # create a dictionary out of the features
    labels = features.pop('label')  # define the label

    return features, labels


def train_input_fn(data_file=sample_csv_file, batch_size=128):
    """Generate an input function for the Estimator."""
    # Extract lines from input files using the Dataset API.
    dataset = tf.data.TextLineDataset(data_file)
    dataset = dataset.map(parse_csv)
    dataset = dataset.shuffle(10000).repeat().batch(batch_size)
    return dataset.make_one_shot_iterator().get_next()

每个 CSV 文件大约有 10K 条记录。我已尝试对 train_input_fn 进行示例评估,如 labels = train_input_fn()[1].eval(session=sess)。这会获得 128 个标签,但大约需要 2 分钟

我是在使用一些多余的操作还是有更好的方法来做到这一点?

PS:我在 Spark Dataframe 中有原始数据。因此,如果可以加快速度,我也可以使用 TFRecords。

最佳答案

你做得对。但更快的方法是使用 TFRecords,如以下步骤所示:

  1. 使用tf.python_io.TFRecordWriter: -- 要读取 csv 文件并将其写入 tfrecord 文件,如下所示:Tensorflow create a tfrecords file from csv .

  2. 从 tfrecord 读取: --

    def _parse_function(proto):
       f = {
           "features": tf.FixedLenSequenceFeature([], tf.float32, default_value=0.0, allow_missing=True),
           "label": tf.FixedLenSequenceFeature([], tf.float32, default_value=0.0, allow_missing=True)
           }
           parsed_features = tf.parse_single_example(proto, f)
           features = parsed_features["features"]
           label = parsed_features["label"]
           return features, label
    
    
    dataset = tf.data.TFRecordDataset(['csv.tfrecords'])
    dataset = dataset.map(_parse_function)
    dataset = dataset.shuffle(10000).repeat().batch(128)
    iterator = dataset.make_one_shot_iterator()
    features, label = iterator.get_next()
    

我在随机生成的 csv 上运行了两种情况 (csv vs tfrecords)。 csv 直接读取的 10 个批处理(每个批处理 128 个样本)的总时间约为 204s,而 tfrecord 的总时间约为 0.22s

关于python - 使用 tf.data 读取 CSV 文件很慢,改用 tfrecords?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50383963/

相关文章:

tensorflow - 评估 Tensorflow 张量

python - 如何将 grad 方法添加到 theano Op 中?

python 使用包含模式复制文件

python - 随机 Sprite 运动和碰撞Python

python 在您的分发包中包含一个文件夹

tensorflow - 无法在 Intel i7 930 CPU 上从源代码编译 TensorFlow; GTS-250 显卡

python - Tensorflow - 在 tensorflow.models.embeddings 中没有名为 'embeddings' 的模块

python - 如何向现有 Tensorflow 数据集对象添加/更改组件名称?

machine-learning - 高偏差卷积神经网络不会随着更多层/滤波器而得到改善

python - Keras 模型未能减少损失