python - 如何为二维张量的每一行计算掩码均值?

标签 python tensorflow mean

我有一个像这样的二维张量:

[[1. 0. 0. 2. 1. 0. 1.]
 [0. 0. 0. 1. 0. 0. 0.]
 [2. 0. 2. 1. 1. 3. 0.]]

我想计算每一行中每个非零元素的平均值,因此结果将是:

[1.25 1.   1.8 ]

我如何使用 TensorFlow 做到这一点?

最佳答案

计算每行屏蔽均值的一种方法是使用 tf.math.unsorted_segment_mean .本质上,您可以每行有一个段 ID,然后用一个额外的替换屏蔽元素的段 ID。

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.constant([[1., 0., 0., 2., 1., 0., 1.],
                     [0., 0., 0., 1., 0., 0., 0.],
                     [2., 0., 2., 1., 1., 3., 0.]], tf.float32)
    s = tf.shape(x)
    num_rows = s[0]
    num_cols = s[1]
    # Mask for selected elements
    mask = tf.not_equal(x, 0)
    # Make per-row ids
    row_id = tf.tile(tf.expand_dims(tf.range(num_rows), 1), [1, num_cols])
    # Id is replaced for masked elements
    seg_id = tf.where(mask, row_id, num_rows * tf.ones_like(row_id))
    # Take segmented mean discarding last value (mean of masked elements)
    out = tf.math.unsorted_segment_mean(tf.reshape(x, [-1]), tf.reshape(seg_id, [-1]),
                                        num_rows + 1)[:-1]
    print(sess.run(out))
    # [1.25 1.   1.8 ]

但是,由于在这种情况下掩码恰好用于非零元素,您也可以只使用 tf.math.count_nonzero :

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.constant([[1., 0., 0., 2., 1., 0., 1.],
                     [0., 0., 0., 1., 0., 0., 0.],
                     [2., 0., 2., 1., 1., 3., 0.]], tf.float32)
    x_sum = tf.reduce_sum(x, axis=1)
    x_count = tf.cast(tf.count_nonzero(x, axis=1), x.dtype)
    # Using maximum avoids problems when all elements are zero
    out = x_sum / tf.maximum(x_count, 1)
    print(sess.run(out))
    # [1.25 1.   1.8 ]

关于python - 如何为二维张量的每一行计算掩码均值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56460378/

相关文章:

python - 如何使用 Python 以编程方式在 Kafka Schema Registry 中注册 Avro Schema

tensorflow - undefined symbol : _ZTIN10tensorflow8OpKernelE

tensorflow - 在 Tensorflow 上使用 Keras 进行图像分类 : how to find which images are misclassified during training?

python - 取列表列表的平均值,忽略零值

Python 设计指南 :

python - 在 Pandas 数据框中保留大小为 1 >= k 的 block

python - 如何仅选择 Tensorflow 数据集的一部分并更改维度

python - Pandas 滚动适用于允许nan

r - 如何计算 R 中增加时间窗口的方法

python - 为主要功能测试设置命令行参数