tensorflow - 从 TFRecord 保存和读取可变大小列表

标签 tensorflow

将稀疏向量存储到 TFRecord 的最佳方法是什么?我的稀疏向量仅包含 1 和 0,因此我决定只保存“1”所在位置的索引,如下所示:

example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'label': self._int64_feature(label),
                'features' : self._int64_feature_list(values)
            }
        )
    )

这里,values 是包含“ones”索引的列表。这个 values 数组有时包含数百个元素,有时根本没有。之后,我只需将序列化示例保存到 tfrecord。后来,我像这样读取 tfrecord:

features = tf.parse_single_example(
    serialized_example,
    features={
        # We know the length of both fields. If not the
        # tf.VarLenFeature could be used
        'label': tf.FixedLenFeature([], dtype=tf.int64),
        'features': tf.VarLenFeature(dtype=tf.int64)
    }
)

label = features['label']
values = features['features']

这不起作用,因为 values 数组被识别为稀疏数组,并且我没有获取已保存的数据。在 tfrecords 中存储稀疏张量的最佳方法是什么以及如何读取它?

最佳答案

如果您只是序列化 1 的位置,您应该能够通过一些技巧得到正确的稀疏张量:

解析后的稀疏张量features['features']看起来像这样:

features['features'].indices: [[batch_id, 位置]...]

其中position是一个无用的枚举。

但您确实希望 feature['features'] 看起来像 [[batch_id, one_position], ...]

其中 one_position 是您在稀疏张量中指定的实际值。

所以:

indices = features['features'].indices
indices = tf.transpose(indices) 
# Now looks like [[batch_id, batch_id, ...], [position, position, ...]]
indices = tf.stack([indices[0], features['features'].values])
# Now looks like [[batch_id, batch_id, ...], [one_position, one_position, ...]]
indices = tf.transpose(indices)
# Now looks like [[batch_id, one_position], [batch_id, one_position], ...]]
features['features'] = tf.SparseTensor(
   indices=indices,
   values=tf.ones(shape=tf.shape(indices)[:1])
   dense_shape=1 + tf.reduce_max(indices, axis=[0])
)

瞧! features['features'] 现在表示一个矩阵,该矩阵是串联的一批稀疏向量。

注意:如果您想将其视为稠密张量,则必须执行 tf.sparse_to_dense 并且稠密张量将具有形状 [None, None] (这使得使用起来有点困难]。如果您知道最大可能的向量长度,您可能需要对其进行硬编码:dense_shape=[batch_size, max_vector_length]

关于tensorflow - 从 TFRecord 保存和读取可变大小列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37270697/

相关文章:

python - 值错误 : Feature not in features dictionary

javascript - 检查输入 : expected dense_Dense5_input to have 4 dimension(s). 时出错,但得到形状为 5、2、5 的数组

machine-learning - tensorflow 中的多热编码(谷歌云机器学习,tf estimator api)

python - 在 Ubuntu 14.04(64 位)上安装 TensorFlow-0.9.0rc0 在此平台上不受支持

tensorflow - 如何将 saved_model.pb 转换为 EvalSavedModel?

string - tensorflow,如何将 tf.string SparseTensor 连接到一维密集张量

tensorflow - 如何在使用自定义训练循环时使用类权重来计算自定义损失函数(即不使用 .fit )

tensorflow - `batch` 在 Tensorflow 数据集管道创建中解决了什么问题以及它如何与训练中使用的(迷你)批量大小交互?

python - 两个 Conv 层之间的 Dropout 和 Batchnormalization

python - TensorFlow 中的延迟加载实现