jax - 如何在jax中生成0到1之间的随机数?

标签 jax

如何在 jax 中生成 0 到 1 之间的随机数? 基本上,我希望从 jax 中的 numpy 复制以下函数。

np.random.random(1000)

最佳答案

jax 中的等价物是

from jax import random
key = random.PRNGKey(758493)  # Random seed is explicit in JAX
random.uniform(key, shape=(1000,))

有关详细信息,请参阅 jax.random module 的文档.

另请注意,JAX 的随机数生成器不维护任何类型的全局状态,因此您需要以不同于您在 NumPy 中习惯的方式来考虑它。有关这方面的更多背景信息,请参阅 JAX Sharp Bits: Random Numbers .

关于jax - 如何在jax中生成0到1之间的随机数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72320999/

相关文章:

tensorflow - 简单来说,JAX、Trax 和 TensorRT 之间有什么区别?

python - Websockets 消息仅在最后发送,而不是在使用 async/await 的实例中发送,在嵌套 for 循环中产生

jax - 作为并行处理模型,xmap 与 pmap 有什么区别?

python - 给定批处理的样本找到 x,y 位置 - python, jax

python - JAX 与 JIT 和自定义差异化

python - 将稀疏矩阵与 jax 结合使用

python - 创建一个 3D 零张量,在 numpy/jax 中的每个切片上随机放置一个 '1'

python - Jax - sigmoid 的 autograd 总是返回 nan