python - 如何为 GPU 设备指定或设置变量

标签 python deep-learning gpu multi-gpu jax

我是 JAX 新手,我想使用多个 GPU。 到目前为止,我的 JAX 可以看到两个 GPU(0 和 1)。

import jax
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
print(jax.local_devices())
>>>
# prints: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]

当我创建 NumPy 对象时,它将始终位于 GPU 设备 0 中,我认为这是默认设备。

nmp = jax.numpy.ones(4)
print(nmp.device())
>>>
# Prints: gpu:0

如何将变量 nmp 发送到另一个 GPU gpu:1 中?

最佳答案

使用.device_put()

import jax
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
devices = jax.local_devices()
print(devices) # >>> [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]

nmp = jax.numpy.ones(4)
print(nmp.device()) # >>> gpu:0

nmp = jax.device_put(nmp, jax.devices()[1])
print(nmp.device()) # >>> gpu:1

关于python - 如何为 GPU 设备指定或设置变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/73456614/

相关文章:

python - pytorch .cuda() 无法将张量获取到 cuda

memory-leaks - 从 Process Explorer 解释 GPU 信息

python - 带 pandas 的水平条形图左侧的输出表

python - gspread update_cells 写入错误的位置

deep-learning - keras 输入层(Nnoe,200,3),为什么没有?输入有3维,但得到了形状为(200, 3)的数组

python - 将特定的 TensorFlow 变量恢复到特定层(按名称恢复)

python - 检查两个句子是否包含Python中常见的palgiarist单词

python - 使用递归生成分形正方形

python - 如何下载本地镜像集以与 Keras 一起使用?

python - Keras - 如何在 CPU 上运行加载模型