tensorflow - 如何在 Tensorflow 中的 Dataset 对象中显示类分布

标签 tensorflow dataset multilabel-classification

我正在使用我自己的图像处理多类分类任务。

filenames = [] # a list of filenames
labels = [] # a list of labels corresponding to the filenames
full_ds = tf.data.Dataset.from_tensor_slices((filenames, labels))

这个完整的数据集将被打乱并分成训练数据集、有效数据集和测试数据集
full_ds_size = len(filenames)
full_ds = full_ds.shuffle(buffer_size=full_ds_size*2, seed=128) # seed is used for reproducibility

train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)

train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)  
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)

现在我正在努力理解每个类在 train_ds、valid_ds 和 test_ds 中的分布情况。一个丑陋的解决方案是迭代数据集中的所有元素并计算每个类的出现次数。有没有更好的方法来解决它?

我丑陋的解决方案:
def get_class_distribution(dataset):
    class_distribution = {}
    for element in dataset.as_numpy_iterator():
        label = element[1]

        if label in class_distribution.keys():
            class_distribution[label] += 1
        else:
            class_distribution[label] = 0

    # sort dict by key
    class_distribution = collections.OrderedDict(sorted(class_distribution.items())) 
    return class_distribution


train_ds_class_dist = get_class_distribution(train_ds)
valid_ds_class_dist = get_class_distribution(valid_ds)
test_ds_class_dist = get_class_distribution(test_ds)

print(train_ds_class_dist)
print(valid_ds_class_dist)
print(test_ds_class_dist)

最佳答案

下面的答案假设:

  • 有五个类(class)。
  • 标签是从 0 到 4 的整数。

  • 它可以根据您的需要进行修改。

    定义一个计数器函数:

    def count_class(counts, batch, num_classes=5):
        labels = batch['label']
        for i in range(num_classes):
            cc = tf.cast(labels == i, tf.int32)
            counts[i] += tf.reduce_sum(cc)
        return counts
    

    使用 reduce手术:

    initial_state = dict((i, 0) for i in range(5))
    counts = train_ds.reduce(initial_state=initial_state,
                             reduce_func=count_class)
    
    print([(k, v.numpy()) for k, v in counts.items()])
    

    关于tensorflow - 如何在 Tensorflow 中的 Dataset 对象中显示类分布,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60876805/

    相关文章:

    C#,循环遍历数据集并显示数据集列中的每条记录

    c# - DataGrid ASP.net C# 中的中型 Blob

    python - Scikit-learn 多目标

    keras - LSTM 文本分类准确率低 Keras

    Python Scikit 学习 : Multilabel Classification ValueError: could not convert string to float:

    tensorflow - 如何控制keras镜像策略中状态指标的缩减策略

    python - 导入错误 : No module named tensorflow in Spyder

    dataset - 推荐系统需要多少数据?

    matplotlib - 尝试分割图像颜色时出错 : numpy. ndarray' 对象没有属性 'mask'

    python - 在 Tensorflow NN 模型中将权重初始化为单位矩阵