python - 在 tensorflow 中对不平衡数据集进行二次采样

标签 python tensorflow tensorflow-datasets

这里是 Tensorflow 初学者。这是我的第一个项目,我正在使用预定义的估算器。

我有一个极其不平衡的数据集,其中积极结果大约占总数据的 0.1%,我怀疑这种不平衡会极大地影响我的模型的性能。作为解决这个问题的第一次尝试,由于我有大量数据,我想扔掉大部分底片以创建一个平衡的数据集。我可以看到两种实现方法:预处理数据以仅保留千分之一的负数,然后将其保存在新文件中,然后再将其传递到 tensorflow ,例如使用 pyspark;并要求 tensorflow 仅使用它找到的一千个负数中的一个。

我尝试对最后一个想法进行编码,但没有成功。我修改了我的输入函数,使其看起来像

def train_input_fn(data_file="../data/train_input.csv", shuffle_size=100_000, batch_size=128):
    """Generate an input function for the Estimator."""

    dataset = tf.data.TextLineDataset(data_file)  # Extract lines from input files using the Dataset API.
    dataset = dataset.map(parse_csv, num_parallel_calls=3)
    dataset = dataset.shuffle(shuffle_size).repeat().batch(batch_size)

    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()

    # TRY TO IMPLEMENT THE SELECTION OF NEGATIVES
    thrown = 0
    flag = np.random.randint(1000)
    while labels == 0 and flag != 0:
        features, labels = iterator.get_next()
        thrown += 1
        flag = np.random.randint(1000)
    print("I've thrown away {} negative examples before going for label {}!".format(thrown, labels))
    return features, labels

这当然是行不通的,因为迭代器不知道它们里面有什么,所以 labels==0 条件永远不会满足。另外,标准输出中只有一个打印,这意味着该函数仅被调用一次(这意味着我仍然不明白 tensorflow 的真正工作原理)。不管怎样,有没有办法实现我想要的?

PS:我怀疑以前的代码即使按预期工作,也会返回不到初始负数的千分之一,因为每次发现正数时都会重新开始计数。这是一个小问题,到目前为止,我什至可以在标志内找到一个神奇的数字,它可以给我预期的结果,而不必过多担心它的数学之美。

最佳答案

通过对代表性不足的类别进行过采样,而不是丢弃代表性过高的类别中的数据,您可能会获得更好的结果。通过这种方式,您可以保持代表性过高的类别中的差异。您不妨使用您拥有的数据。

实现这一目标的最简单方法可能是创建两个数据集,每个类一个。然后,您可以使用 Dataset.interleave 从两个数据集中均匀采样。

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#interleave

关于python - 在 tensorflow 中对不平衡数据集进行二次采样,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49735127/

相关文章:

Python Django 奇怪地映射到意外的 URL

python - Pandas - 如何设置规则来选择要删除的重复项

python - 管理多个 session 和图表的合理方法

python - 具有稀疏数据的 tensorflow 训练

tensorflow - 将 Tensorflow 数据集 API 创建的数据集拆分为训练和测试?

python - tf.data.Dataset - map() 和 cache() 方法的行为

python - Tensorflow:如何使用 Mul 操作创建 tf.NodeDef()?

python - 处理 XML 数据的理想数据结构

python - 如何使用 TensorFlow 2.0 打乱两个 numpy 数据集?

python - 在图形执行模式下拆分 tensorflow tf.data 数据集的示例