我正在尝试矢量化以下内容:
n = torch.zeros_like(x)
for i in range(x.shape[0]):
for j in range(x.shape[1]):
for k in range(x.shape[2]):
n[i, j, k] = p[i, x[i, j, k], j, k]
我尝试做类似的事情
n = p[:, x, ...]
但我刚刚收到一个错误,提示我内存不足,这没什么帮助。我认为问题在于,它不是在正确的索引处获取 x 的值,而是尝试对整个 x 进行索引,但我不确定如果这是问题所在,我将如何解决这个问题。
最佳答案
这看起来像是 broadcasted fancy indices 的完美用例。 np.ogrid
是一个很有值(value)的工具,或者您可以手动调整范围:
i, j, k = np.ogrid[:x.shape[0], :x.shape[1], :x.shape[2]]
n = p[i, x, j, k]
这个黑魔法之所以有效,是因为 ogrid
的索引返回三个数组,它们广播成与 x
相同的形状。因此,从 p
中最终提取的内容将具有该形状。之后索引就变得微不足道了。另一种写法是:
i = np.arange(x.shape[0]).reshape(-1, 1, 1)
j = np.arange(x.shape[1]).reshape(1, -1, 1)
k = np.arange(x.shape[2]).reshape(1, 1, -1)
n = p[i, x, j, k]
关于python - 如何在 Python 中向量化这些嵌套循环?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58601366/