python-3.x - PyTorch 中的左移张量

标签 python-3.x torch

我有一个张量 a形状(1, N, 1) .我需要沿维度左移张量 1并添加一个新值作为替换。我找到了一种方法来完成这项工作,以下是代码。

a = torch.from_numpy(np.array([1, 2, 3]))
a = a.unsqueeze(0).unsqeeze(2)  # (1, 3, 1), my data resembles this shape, therefore the two unsqueeze
# want to left shift a along dim 1 and insert a new value at the end
# I achieve the required shifts using the following code
b = a.squeeze
c = b.roll(shifts=-1)
c[-1] = 4
c = c.unsqueeze(0).unsqueeze(2)
# c = [[[2], [3], [4]]]
我的问题是,有没有更简单的方法来做到这一点?谢谢。

最佳答案

你实际上不需要挤压和执行你的操作,然后解压你的输入张量 a .相反,您可以直接执行这两个操作,如下所示:

# No need to squeeze
c = torch.roll(a, shifts=-1, dims=1)
c[:,-1,:] = 4
# No need to unsqeeze
# c = [[[2], [3], [4]]]

关于python-3.x - PyTorch 中的左移张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64495404/

相关文章:

Python 3 多处理 - 我应该使用多少个进程?

python - 如何使用 python 在discord.py 上制作排行榜命令?

python-3.x - 使用 read_parquet 从 Parquet 文件中获取带有分类列的 Pandas DataFrame?

python - 从 python 调用 torch7 ( Lua ) 函数?

lua - Torch7-内存不足: you tried to allocate 0GB.购买新的RAM

python - Torch 中的 view() 和 unsqueeze() 有什么区别?

python-3.x - pandas DatetimeIndex 到 matplotlib x-ticks

linux - 当有两个 gpu 时,如何设置 Torch 只使用一个 gpu?

python - conv2d 之后的 PyTorch CNN 线性层形状

生成自定义类集合的 Pythonic 多重继承