如何在 jax 中生成 0 到 1 之间的随机数?
基本上,我希望从 jax
中的 numpy
复制以下函数。
np.random.random(1000)
最佳答案
jax 中的等价物是
from jax import random
key = random.PRNGKey(758493) # Random seed is explicit in JAX
random.uniform(key, shape=(1000,))
有关详细信息,请参阅 jax.random
module 的文档.
另请注意,JAX 的随机数生成器不维护任何类型的全局状态,因此您需要以不同于您在 NumPy 中习惯的方式来考虑它。有关这方面的更多背景信息,请参阅 JAX Sharp Bits: Random Numbers .
关于jax - 如何在jax中生成0到1之间的随机数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72320999/