说这是我的样本张量
sample = torch.tensor(
[[2, 7, 3, 1, 1],
[9, 5, 8, 2, 5],
[0, 4, 0, 1, 4],
[5, 4, 9, 0, 0]]
)
我想要一个新的张量,它将由示例张量中的 2 行串联组成。
所以我有一个张量,其中包含我想要连接成新张量的单行的行号对
cat_indices = torch.tensor([[0, 1], [1, 2], [0, 2], [2, 3]])
我目前使用的方法是这个
torch.cat((sample[cat_indices[:,0]], sample[cat_indices[:,1]]), dim=1)
这给出了期望的结果
tensor([[2, 7, 3, 1, 1, 9, 5, 8, 2, 5],
[9, 5, 8, 2, 5, 0, 4, 0, 1, 4],
[2, 7, 3, 1, 1, 0, 4, 0, 1, 4],
[0, 4, 0, 1, 4, 5, 4, 9, 0, 0]])
这是内存和计算效率最高的方法吗?我不确定,因为我对 cat_indices
进行了两次调用,然后进行了串联操作。
我觉得应该有一种方法可以通过某种 View 来做到这一点。也许是高级索引。我尝试过 sample[cat_indices[:,0], cat_indices[:,1]]
或 sample[cat_indices[0], cat_indices[1]]
但我无法使 View 正确显示。
最佳答案
你的速度应该很快。另一种选择是
sample[cat_indices].reshape(cat_indices.shape[0],-1)
您必须在您的机器上对性能进行基准测试,看看哪个更好。
关于Pytorch:通过提取张量行进行一系列串联的计算和内存效率最高的方法?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70854880/