python - 从 TensorFlow 中的截断正态分布中采样

标签 python tensorflow scipy

我正在将一堆计算从 Numpy/Scipy 移植到 TensorFlow,并且我需要从 truncated normal distribution 生成样本,相当于 scipy.stats.truncnorm.rvs()。

我认为生成这些样本的两种标准方法是拒绝采样或将截断的均匀分布样本提供给逆正态累积分布函数,但前者似乎很难在静态中实现计算图(在生成样本之前我们无法知道要运行多少个拒绝循环),并且我认为标准 TensorFlow 库中没有逆正态累积分布函数。

我意识到 TensorFlow 中有一个名为 truncated_normal() 的函数,但它只是在两个标准差处进行剪辑;您无法指定下限和上限。

有什么建议吗?

最佳答案

更简单的方法可能是使用 tf.py_func 将您习惯的函数嵌入到 TensorFlow 图中。如果您需要它,您可以像平常一样使用 numpy.random 模块设置种子:

import numpy as np
import tensorflow as tf
from scipy.stats import truncnorm

np.random.seed(seed=123)

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
size = tf.placeholder(tf.int32)

f = tf.py_func(lambda x, y, s: truncnorm.rvs(x, y, size=s),
               inp=[a, b, size],
                Tout=tf.float64)


with tf.Session() as sess:
    print(sess.run(f, {a:0, b:1, size:10}))

将打印:

[0.63638154 0.24732635 0.19533476 0.49072188 0.66066675 0.37031253
 0.9732229  0.62423404 0.42385328 0.34206036]

关于python - 从 TensorFlow 中的截断正态分布中采样,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48843484/

相关文章:

Python lambda 绑定(bind)到本地值

python - 尝试...除了处理 matplotlib 版本

tensorflow - 在 Keras 中使用扩张卷积

python - 最小化指数平滑中的 alpha

python - 在矩阵中插入一项

python - 在python中添加多个数组

java - 如何从 Java 中的示例对象创建张量?

python - 使用 Dataset api tensorflow 即时生成

python - 如何在 scipy 中以 cgs 单位获取物理常数?

python - 使用 scipy 构建 wav 文件并将其写入磁盘