假设我想对一个矩阵进行采样,其中每个条目都是从另一个矩阵中的条目定义的分布中采样的。我展开矩阵并将 map_fn 应用于每个元素。对于相对较小的矩阵 (128 x 128),以下内容给了我几个 PoolAllocator 警告 (GTX TITAN Black),并且不会在任何合理的时间内进行训练。
def sample(x):
samples = tf.map_fn(lambda z:
tf.random_normal([1], mean=z,
stddev=tf.sqrt(z * (1 - z))),
tf.reshape(x, [-1])) # apply to each element
return tf.cond(is_training, lambda: tf.reshape(samples, shape=tf.shape(x)),
lambda: tf.tanh(x))
是否有更好的方法来应用这样的元素运算?
最佳答案
如果您可以使用一次张量运算而不是像 tf.map_fn 这样的元素运算,您的代码将会运行得更快。
这里看起来您想要从每个元素的正态分布中进行采样,其中输入张量中每个值的分布参数都不同。尝试这样的事情:
def sample(x):
samples = tf.random_normal(shape=[128, 128]) * tf.sqrt(x * (1 - x)) + x
tf.random_normal() 默认生成平均值为 0.0、标准差为 1.0 的正态分布。您可以使用逐点张量运算来修复每个元素的标准差(通过相乘)和平均值(通过相加)。事实上,如果你看看 tf.random_normal() 是如何实现的,你就会发现这正是它内部所做的。
(您可能还可以更好地使用 Python 条件来区分训练和测试时间。)
如果你打算经常做这种事情,你可以在 github 上提交一个功能请求,要求泛化 tf.random_normal 以接受具有更通用形状的张量 mean
和 stddev
。我认为没有理由不支持这一点。
希望有帮助!
关于numpy - 使用 map_fn Slow 进行元素采样,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38135999/