python - Tensorflow:如何找到 tf.data.Dataset API 对象的大小

标签 python tensorflow tensorflow-datasets

我理解数据集 API 是一种迭代器,它不会将整个数据集加载到内存中,因此它无法找到数据集的大小。我说的是存储在文本文件或 tfRecord 文件中的大量数据的上下文。这些文件通常使用 tf.data.TextLineDataset 或类似的东西读取。使用 tf.data.Dataset.from_tensor_slices 可以轻松找到加载的数据集的大小。

我询问数据集大小的原因如下: 假设我的数据集大小是 1000 个元素。批量大小 = 50 个元素。然后训练步骤/批处理(假设 1 个纪元)= 20。在这 20 个步骤中,我想将我的学习率从 0.1 呈指数衰减到 0.01 作为

tf.train.exponential_decay(
    learning_rate = 0.1,
    global_step = global_step,
    decay_steps = 20,
    decay_rate = 0.1,
    staircase=False,
    name=None
)

在上面的代码中,我有“并且”想设置decay_steps = number of steps/batches per epoch = num_elements/batch_size。只有事先知道数据集中的元素数量,才能计算出这一点。

提前知道大小的另一个原因是使用 tf.data.Dataset.take()tf.data.Dataset.skip( ) 方法。

PS:我不是在寻找蛮力方法,例如遍历整个数据集并更新计数器以计算元素数量或 putting a very large batch size and then finding the size of the resultant dataset

最佳答案

您可以使用以下方式轻松获取数据样本的数量:

dataset.__len__()

你可以像这样获取每个元素:

for step, element in enumerate(dataset.as_numpy_iterator()):
...     print(step, element)

您还可以获得一个样本的形状:

dataset.element_spec

如果你想获取特定的元素,你也可以使用分片方法。

关于python - Tensorflow:如何找到 tf.data.Dataset API 对象的大小,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50919390/

相关文章:

tensorflow - 嵌入文本数据的 TFRecords

python - 如何访问 tf.data.Dataset.list_files() 收集的文件名?

python - 获取两个列表之间的交集

c++ - 嵌入 Python 并向解释器添加 C 函数

python - 将我的聊天机器人插入网站

python - 霍夫曼树中的唯一标识符节点

python - Tensorflow 可以用于多元函数的全局最小化吗?

tensorflow - 当我尝试在 jetson tx1 中加载卷积预训练模型时,tensorflow 中出现错误

python - ModuleNotFoundError : No module named 'pegasus'

python - model.fit() 不接受 tf.data.Dataset 的输入形状