numpy - 使用 map_fn Slow 进行元素采样

标签 numpy matrix tensorflow

假设我想对一个矩阵进行采样,其中每个条目都是从另一个矩阵中的条目定义的分布中采样的。我展开矩阵并将 map_fn 应用于每个元素。对于相对较小的矩阵 (128 x 128),以下内容给了我几个 PoolAllocator 警告 (GTX TITAN Black),并且不会在任何合理的时间内进行训练。

def sample(x):
   samples = tf.map_fn(lambda z:
                      tf.random_normal([1], mean=z,
                                       stddev=tf.sqrt(z * (1 - z))),
                      tf.reshape(x, [-1]))    # apply to each element

   return tf.cond(is_training, lambda: tf.reshape(samples, shape=tf.shape(x)),
                  lambda: tf.tanh(x))

是否有更好的方法来应用这样的元素运算?

最佳答案

如果您可以使用一次张量运算而不是像 tf.map_fn 这样的元素运算,您的代码将会运行得更快。

这里看起来您想要从每个元素的正态分布中进行采样,其中输入张量中每个值的分布参数都不同。尝试这样的事情:

def sample(x):
  samples = tf.random_normal(shape=[128, 128]) * tf.sqrt(x * (1 - x)) + x

tf.random_normal() 默认生成平均值为 0.0、标准差为 1.0 的正态分布。您可以使用逐点张量运算来修复每个元素的标准差(通过相乘)和平均值(通过相加)。事实上,如果你看看 tf.random_normal() 是如何实现的,你就会发现这正是它内部所做的。

(您可能还可以更好地使用 Python 条件来区分训练和测试时间。)

如果你打算经常做这种事情,你可以在 github 上提交一个功能请求,要求泛化 tf.random_normal 以接受具有更通用形状的张量 meanstddev 。我认为没有理由不支持这一点。

希望有帮助!

关于numpy - 使用 map_fn Slow 进行元素采样,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38135999/

相关文章:

Java以正确的形式输出矩阵

python - 从 TensorFlow 中的检查点恢复后修改变量名称

python - 使用 1D 变换实现 2D 傅里叶逆变换

python - Numpy:如何迭代地将 3D 数组堆叠在行中?

python - 迭代删除 numpy 数组中的行

python - 测试 Numpy 数组是否包含给定的行

algorithm - 任意矩阵乘法的复杂性

matrix - 关于 CRS 稀疏矩阵存储

python - 为什么返回 tf.py_func/tf.numpy_function 结果的 Lambda 函数没有输出形状,我该如何纠正这种行为?

python - 创建全卷积网络