我是 JAX 新手,我想使用多个 GPU。 到目前为止,我的 JAX 可以看到两个 GPU(0 和 1)。
import jax
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
print(jax.local_devices())
>>>
# prints: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]
当我创建 NumPy 对象时,它将始终位于 GPU 设备 0 中,我认为这是默认设备。
nmp = jax.numpy.ones(4)
print(nmp.device())
>>>
# Prints: gpu:0
如何将变量 nmp
发送到另一个 GPU gpu:1
中?
最佳答案
import jax
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
devices = jax.local_devices()
print(devices) # >>> [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]
nmp = jax.numpy.ones(4)
print(nmp.device()) # >>> gpu:0
nmp = jax.device_put(nmp, jax.devices()[1])
print(nmp.device()) # >>> gpu:1
关于python - 如何为 GPU 设备指定或设置变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/73456614/