我正在尝试获取基于不同组的张量列表,
例如,
x = tensor([ 0.3018, -0.0079, 1.4995, -1.4422, 1.6007])
indices = torch.tensor([0,0,1,1,2])
res = func(x,indices)
我希望我的结果是
res= [[0.3018, -0.0079], [1.4995, -1.4422], [1.6007]]
我想知道如何才能达到这个结果,我检查了gather
和index_select
,
但我无法得到上面那样的结果。
谢谢!
最佳答案
怎么样
res = [x[indices == i_] for i_ in indices.unique()]
关于pytorch - 从掩码索引获取张量列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66997166/