python - Tensorflow 数据集 API 中的过采样功能

标签 python tensorflow sampling tensorflow-datasets

请问目前数据集的API是否允许实现过采样算法?我处理高度不平衡的类(Class)问题。我在想在数据集解析过程中对特定类进行过度采样会很好,即在线生成。我已经看到了 rejection_resample 函数的实现,但是这会删除样本而不是复制它们,并且它会减慢批处理生成的速度(当目标分布与初始分布有很大不同时)。我想实现的是:举个例子,看看它的类概率决定是否复制它。然后调用 dataset.shuffle(...) dataset.batch(...) 得到迭代器。最好的(在我看来)方法是对低概率类别进行过度采样,并对最可能的类别进行子采样。我想在线进行,因为它更灵活。

最佳答案

此问题已在issue #14451 中解决. 只需在此处发布答案即可让其他开发人员更清楚地看到它。

示例代码对低频类进行过采样,对高频类进行欠采样,其中 class_target_prob 在我的例子中只是均匀分布。我想检查最近手稿的一些结论A systematic study of the class imbalance problem in convolutional neural networks

特定类的过采样是通过调用:

dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(oversample_classes(x))
)

这是完成所有事情的完整代码段:

# sampling parameters
oversampling_coef = 0.9  # if equal to 0 then oversample_classes() always returns 1
undersampling_coef = 0.5  # if equal to 0 then undersampling_filter() always returns True

def oversample_classes(example):
    """
    Returns the number of copies of given example
    """
    class_prob = example['class_prob']
    class_target_prob = example['class_target_prob']
    prob_ratio = tf.cast(class_target_prob/class_prob, dtype=tf.float32)
    # soften ratio is oversampling_coef==0 we recover original distribution
    prob_ratio = prob_ratio ** oversampling_coef 
    # for classes with probability higher than class_target_prob we
    # want to return 1
    prob_ratio = tf.maximum(prob_ratio, 1) 
    # for low probability classes this number will be very large
    repeat_count = tf.floor(prob_ratio)
    # prob_ratio can be e.g 1.9 which means that there is still 90%
    # of change that we should return 2 instead of 1
    repeat_residual = prob_ratio - repeat_count # a number between 0-1
    residual_acceptance = tf.less_equal(
                        tf.random_uniform([], dtype=tf.float32), repeat_residual
    )

    residual_acceptance = tf.cast(residual_acceptance, tf.int64)
    repeat_count = tf.cast(repeat_count, dtype=tf.int64)

    return repeat_count + residual_acceptance


def undersampling_filter(example):
    """
    Computes if given example is rejected or not.
    """
    class_prob = example['class_prob']
    class_target_prob = example['class_target_prob']
    prob_ratio = tf.cast(class_target_prob/class_prob, dtype=tf.float32)
    prob_ratio = prob_ratio ** undersampling_coef
    prob_ratio = tf.minimum(prob_ratio, 1.0)

    acceptance = tf.less_equal(tf.random_uniform([], dtype=tf.float32), prob_ratio)

    return acceptance


dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(oversample_classes(x))
)

dataset = dataset.filter(undersampling_filter)

dataset = dataset.repeat(-1)
dataset = dataset.shuffle(2048)
dataset = dataset.batch(32)

sess.run(tf.global_variables_initializer())

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

更新#1

这是一个简单的 jupyter notebook它在玩具模型上实现了上述过采样/欠采样。

关于python - Tensorflow 数据集 API 中的过采样功能,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47236465/

相关文章:

python - Pandas df.to_records() 返回一维 numpy 数组

python - 在 Django 中使用 templatetags 时导入错误

python - 如何将 numpy 数组存储为 tfrecord?

python - 在AWS Sagemaker中使用Tensorflow Estimator时,训练作业是否会自动将模型工件保存到/opt/ml/model?

python - 如何将不平衡库与 sklearn pipeline 一起使用?

python - Numpy:根据形状角选择任意形状

python - 为什么 get_weights 返回一个空列表?

linux - linux中的音频流采样率

c++ - 连续 WASAPI 环形缓冲区采样

Python:Selenium "no such element"XPath 或 ID