python - TensorFlow:使用 boolean_mask 有效计算 sqrt

标签 python tensorflow

为了提高效率,我只想计算低于阈值的张量的 sqrt。

例如,在 numpy 中,我有

import numpy as np
x = np.random.random(size=(10e6))
%timeit np.sqrt(x)
-> 10 ms ± 17.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

如果我戴口罩

x_m = x[x < 1e-3]
%timeit np.sqrt(x_m)
-> 8.94 µs ± 20.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

计算速度更快,正如预期的那样,因为 numpy 似乎只计算元素 x < 1e-3 的 sqrt。

但是,在 Tensorflow 中,我无法完成这项工作:

import tensorflow as tf
tf.InteractiveSession()
x_tf = tf.constant(x)
%timeit tf.sqrt(x_tf).eval()
-> 314 ms ± 1.82 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

如果我现在尝试使用 boolean_mask

mask = tf.boolean_mask(x_tf, x_tf < 1e-3)
%timeit tf.sqrt(mask).eval()
-> 341 ms ± 1.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

没有像 numpy 版本那样的加速。看起来 Tensorflow 中的 sqrt 仍然是针对原始张量 x_tf 的所有值计算的。

有没有办法只对屏蔽值运行运算(如 sqrt)?或者,从屏蔽张量中提取较短的张量?

最佳答案

您的措施存在两个问题:

  • 您没有计算 NumPy 中 bool 掩码的比较。
  • 您将在 TensorFlow 中的每次计时试验中创建新的图形节点。

这些应该是更具代表性的时间:

import numpy as np
import tensorflow as tf

np.random.seed(0)
x = np.random.random(size=int(10e6))
%timeit np.sqrt(x)
# 20.4 ms ± 581 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit np.sqrt(x[x < 1e-3])
# 9.96 ms ± 91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

with tf.Graph().as_default(), tf.Session():
    x_tf = tf.constant(x)
    x_tf_sqrt = tf.sqrt(x_tf)
    %timeit x_tf_sqrt.eval()
    # 16.8 ms ± 685 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    mask = tf.boolean_mask(x_tf, x_tf < 1e-3)
    mask_sqrt = tf.sqrt(mask)
    %timeit mask_sqrt.eval()
    # 103 µs ± 43.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

关于python - TensorFlow:使用 boolean_mask 有效计算 sqrt,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55935286/

相关文章:

python - 在 Keras/Tensorflow 自定义损失函数中使用额外的 *trainable* 变量

python - 根据感兴趣的日期范围作为参数输入,限制在 Pig Latin 中加载日志文件

python - 在 2D numpy 数组中找到给定角度的最近项目

python - 设置 Cx_Oracle

python - 修补 Python 包,作为 Pip 的依赖项安装

python - Pandas GroupBy 和日期范围内的平均值

tensorflow - 在 Google Cloud ML Engine 上训练时出现 "Import error:No module named Cython.Build"

tensorflow - 索引错误: index 5 is out of bounds for axis 1 with size 5

batch-file - 如何使用管道解码我的 tiff 图片以输入 tensorflow?

validation - 使用估算器时,将验证监视器替换为 tf.train.SessionRunHook