Pytorch:通过提取张量行进行一系列串联的计算和内存效率最高的方法?

标签 pytorch

说这是我的样本张量

sample = torch.tensor(
    [[2, 7, 3, 1, 1],
        [9, 5, 8, 2, 5],
        [0, 4, 0, 1, 4],
        [5, 4, 9, 0, 0]]
)

我想要一个新的张量,它将由示例张量中的 2 行串联组成。

所以我有一个张量,其中包含我想要连接成新张量的单行的行号对

cat_indices = torch.tensor([[0, 1], [1, 2], [0, 2], [2, 3]])

我目前使用的方法是这个

torch.cat((sample[cat_indices[:,0]], sample[cat_indices[:,1]]), dim=1)

这给出了期望的结果

tensor([[2, 7, 3, 1, 1, 9, 5, 8, 2, 5],
        [9, 5, 8, 2, 5, 0, 4, 0, 1, 4],
        [2, 7, 3, 1, 1, 0, 4, 0, 1, 4],
        [0, 4, 0, 1, 4, 5, 4, 9, 0, 0]])

这是内存和计算效率最高的方法吗?我不确定,因为我对 cat_indices 进行了两次调用,然后进行了串联操作。

我觉得应该有一种方法可以通过某种 View 来做到这一点。也许是高级索引。我尝试过 sample[cat_indices[:,0], cat_indices[:,1]]sample[cat_indices[0], cat_indices[1]] 但我无法使 View 正确显示。

最佳答案

你的速度应该很快。另一种选择是

sample[cat_indices].reshape(cat_indices.shape[0],-1)

您必须在您的机器上对性能进行基准测试,看看哪个更好。

关于Pytorch:通过提取张量行进行一系列串联的计算和内存效率最高的方法?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70854880/

相关文章:

python - 使用 Pytorch Lightning 时如何将指标(例如验证损失)记录到 TensorBoard?

python - 如何在 PyTorch 模型训练后清除 GPU 内存而不重新启动内核

python - 即使安装了 CUDA 和带有 cuda 的 pytorch,Pytorch cuda 也不可用。怎么修?

pytorch - 如何使用 PyTorch DataLoader 进行强化学习?

c++ - PyTorch C++ API中的 `randperm`不应该返回默认类型为int的张量吗?

python - 值错误: expected 2D or 3D input (got 1D input) PyTorch

gpu - 在两个不同的 GPU 上运行相同的深度学习代码时会出现非常奇怪的行为

python - 如何提取两个张量之间不等价条目的索引?

python - PyTorch 相当于 tf.dynamic_partition

python - Pytorch 句子数据加载器