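Background:
I am developing a program that first shifts different channels of a tensor by different distances along the "column" (width) dimension, and then sums along the "channel" dimension to merge all channels into one. Specifically, given a tensor x of size (B, C, H, W) and a step size S, where B, C, H and W denote the batch size, number of channels, height and width, the i-th channel of x is shifted by (i-1)*S, and then the C channels are summed into 1.
Here is a 1D toy example. Suppose I have a 3-channel tensor x: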
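In formula form, this is a restatement of the definition above using a 0-based channel index c (so channel c is shifted by cS), with out-of-range terms taken as zero:

```latex
y_{b,0,h,w} = \sum_{c=0}^{C-1} x_{b,c,h,\,w-cS}, \qquad 0 \le w < W + (C-1)S,
\qquad \text{where } x_{b,c,h,w'} = 0 \text{ if } w' < 0 \text{ or } w' \ge W.
```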
x = torch.tensor(
[[1,1,1],
[2,2,2],
[3,3,3]]
)
Now I set the step size to 1 and shift the tensor:
x_shifted = torch.tensor(
[[1,1,1,0,0],
[0,2,2,2,0],
[0,0,3,3,3]]
)
Here, the first channel is shifted by 0, the second by 1, and the third by 2. Finally, all three channels are summed and merged into a single channel:
y = torch.tensor(
[[1,3,6,5,3]]
)
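The toy example can be reproduced with a short PyTorch sketch. Padding each channel on both sides with F.pad is only one illustrative way to build x_shifted and is not taken from the question:

```python
import torch
import torch.nn.functional as F

# Toy example from the text: C = 3 channels of width W = 3, step size S = 1
x = torch.tensor([[1, 1, 1],
                  [2, 2, 2],
                  [3, 3, 3]])
C, W = x.shape
S = 1

# Pad channel i with i*S zeros on the left and (C-1-i)*S zeros on the right,
# so every shifted row has the common width W + (C-1)*S
x_shifted = torch.stack([F.pad(x[i], (i * S, (C - 1 - i) * S)) for i in range(C)])

# Sum over the channel dimension, keeping a single channel
y = x_shifted.sum(dim=0, keepdim=True)
print(x_shifted)
print(y)  # tensor([[1, 3, 6, 5, 3]])
```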
Question:
I have implemented the original pipeline for 2D image tensors with the following code:
import torch
import torch.nn.functional as F
from time import time
#############################################
# Parameters
#############################################
B = 16
C = 28
H = 256
W = 256
S = 2
T = 1000
device = torch.device('cuda')
seed = 2023
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
#############################################
# Method 1
#############################################
alpha = torch.zeros(B, 1, 1, W+(C-1)*S, device=device)
for i in range(C):
    alpha[..., (i*S):(i*S+W)] += 1
def A(x, mask):
    z = x * mask
    y = torch.zeros(B, 1, H, W+(C-1)*S, device=x.device)
    for i in range(C):
        y[..., (i*S):(i*S+W)] += z[:, (i):(i+1)]
    return y
def A_pinv(y, mask):
    z = y / alpha.to(y.device)
    x = torch.cat([z[..., (i*S):(i*S+W)] for i in range(C)], dim=1) / mask
    return x
#############################################
# Method 2
#############################################
kernel = torch.zeros(1, C, 1, (C-1)*S+1, device=device)
for i in range(C):
    kernel[:, C-i-1, :, i*S] = 1
def A_fast(x, mask):
    return F.conv2d(x * mask, kernel.to(x.device), padding=(0, (C-1)*S))
def A_pinv_fast(y, mask):
    return F.conv_transpose2d(y / alpha.to(y.device), kernel, padding=(0, (C-1)*S)) / mask
#############################################
# Test 1
#############################################
start_time = time()
MAE = 0
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A(x, mask)
    x_init = A_pinv(y, mask)
    y_init = A(x_init, mask)
    MAE += (y_init - y).abs().mean().item()
MAE /= T
end_time = time()
print('---')
print('Test 1')
print('Running Time:', end_time - start_time)
print('MAE:', MAE)
#############################################
# Test 2
#############################################
start_time = time()
MAE = 0
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A_fast(x, mask)
    x_init = A_pinv_fast(y, mask)
    y_init = A_fast(x_init, mask)
    MAE += (y_init - y).abs().mean().item()
MAE /= T
end_time = time()
print('---')
print('Test 2')
print('Running Time:', end_time - start_time)
print('MAE:', MAE)
Here, Method 1 implements the process with a for loop, while Method 2, I believe, implements the same process equivalently with a 2D convolution.
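The claimed equivalence can be checked on small tensors. This sketch uses small CPU-only sizes (an assumption for illustration; the question uses B=16, C=28, H=W=256 on CUDA) and omits the mask:

```python
import torch
import torch.nn.functional as F

# Small illustrative sizes, CPU only
B, C, H, W, S = 2, 4, 3, 5, 2
torch.manual_seed(0)
x = torch.rand(B, C, H, W, dtype=torch.float64)

# Method 1: explicit shift-and-add loop
y_loop = torch.zeros(B, 1, H, W + (C - 1) * S, dtype=x.dtype)
for i in range(C):
    y_loop[..., i * S:i * S + W] += x[:, i:i + 1]

# Method 2: a single conv2d whose kernel holds one 1 per input channel.
# conv2d computes a cross-correlation, so channel c needs its 1 at tap (C-1-c)*S
kernel = torch.zeros(1, C, 1, (C - 1) * S + 1, dtype=x.dtype)
for i in range(C):
    kernel[:, C - i - 1, :, i * S] = 1
y_conv = F.conv2d(x, kernel, padding=(0, (C - 1) * S))

print(y_conv.shape)                    # torch.Size([2, 1, 3, 11])
print(torch.allclose(y_loop, y_conv))  # True
```

The convolution kernel is almost entirely zeros, so most of its multiply-adds do no useful work, which is one plausible reason it loses to the loop.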
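More specifically, the functions A and A_pinv implement the forward compression process and its "pseudo-inverse", respectively. The "fast" versions in Method 2 are expected to be faster because they run in parallel.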
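On the 1D toy example, the "pseudo-inverse" spreads each output column evenly over the channels that overlap it (alpha counts them). It does not recover x itself, but it satisfies A(A_pinv(y)) = y, which is exactly what the MAE in Tests 1 and 2 measures. A minimal sketch, assuming a mask of all ones:

```python
import torch

# 1D toy example (C=3, W=3, S=1) laid out as (B, C, H, W) = (1, 3, 1, 3); mask assumed to be all ones
C, W, S = 3, 3, 1
out_W = W + (C - 1) * S
x = torch.tensor([1.0, 2.0, 3.0]).view(1, C, 1, 1).expand(1, C, 1, W)

def A(x):
    # Shift channel i by i*S and sum over channels
    y = torch.zeros(x.shape[0], 1, x.shape[2], out_W, dtype=x.dtype)
    for i in range(C):
        y[..., i * S:i * S + W] += x[:, i:i + 1]
    return y

# alpha[w] counts how many shifted channels overlap output column w
alpha = torch.zeros(1, 1, 1, out_W)
for i in range(C):
    alpha[..., i * S:i * S + W] += 1

def A_pinv(y):
    # Divide by the overlap count, then read each channel's window back out
    z = y / alpha
    return torch.cat([z[..., i * S:i * S + W] for i in range(C)], dim=1)

y = A(x)
x_init = A_pinv(y)
print(alpha.flatten())            # tensor([1., 2., 3., 2., 1.])
print(y.flatten())                # tensor([1., 3., 6., 5., 3.])
print(torch.equal(A(x_init), y))  # True
```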
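However, when I run the code, I find that Method 1 is still much faster than Method 2. My question is: can Method 1 be sped up effectively? More specifically, can the for loop be parallelized to make the "shift + summation" process faster?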
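Best Answer
Convolution with a large kernel is not necessarily efficient.
torch.scatter_add_ can sum the shifted elements directly.
I did not write the pseudo-inverse (I assume it is only there to check correctness?). Instead, I compared the new method against your Method 1 and Method 2.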
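The idea behind scatter_add_ is visible on the 1D toy example: compute the destination column of every element once, then let a single scatter_add_ call accumulate everything that lands on the same column. A minimal CPU sketch (variable names are illustrative):

```python
import torch

# 1D toy example from the question: C=3 channels of width W=3, step size S=1
x = torch.tensor([[1, 1, 1],
                  [2, 2, 2],
                  [3, 3, 3]])
C, W = x.shape
S = 1
out_W = W + (C - 1) * S

# Destination column of element (i, w) after shifting channel i by i*S
index = torch.arange(W) + torch.arange(C).view(C, 1) * S  # shape (C, W)

# One call adds every source element into its destination column;
# elements that map to the same column are summed
y = torch.zeros(out_W, dtype=x.dtype)
y.scatter_add_(0, index.flatten(), x.flatten())
print(index.tolist())  # [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
print(y.tolist())      # [1, 3, 6, 5, 3]
```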
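Not part of the answer: since torch.gather reads from the same flat positions that scatter_add_ writes to, Method 1's A_pinv could plausibly be expressed with a single gather. This is a hedged sketch; A_pinv_faster is a hypothetical name, and the sizes are small and CPU-only:

```python
import torch

# Small illustrative sizes, CPU only
B, C, H, W, S = 2, 3, 4, 5, 2
out_W = W + (C - 1) * S

# Same flat index layout as the answer's Method 3
indices = (torch.arange(W) + torch.arange(C).view(C, 1, 1) * S
           + torch.arange(H).view(1, H, 1) * out_W).view(1, C * H * W).expand(B, -1)

def A_faster(x, mask):
    y = torch.zeros(B, H * out_W, dtype=x.dtype)
    y.scatter_add_(1, indices, (x * mask).reshape(B, C * H * W))
    return y.view(B, 1, H, out_W)

# alpha[w]: number of channels overlapping output column w (same role as in Method 1)
alpha = torch.zeros(1, 1, 1, out_W)
for i in range(C):
    alpha[..., i * S:i * S + W] += 1

def A_pinv_faster(y, mask):
    # Hypothetical gather-based counterpart of A_pinv: for every (c, h, w),
    # read back the normalized output element it was added into
    z = (y / alpha).reshape(B, H * out_W)
    x = torch.gather(z, 1, indices).view(B, C, H, W)
    return x / mask

x = torch.rand(B, C, H, W, dtype=torch.float64)
mask = torch.rand(1, 1, H, W, dtype=torch.float64) + 0.5  # keep the mask away from zero
y = A_faster(x, mask)
print(torch.allclose(A_faster(A_pinv_faster(y, mask), mask), y))  # True
```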
out_W = W + (C-1)*S
i_list = torch.arange(C, dtype=torch.long, device=device)
y_list = torch.arange(H, dtype=torch.long, device=device)
x_list = torch.arange(W, dtype=torch.long, device=device)
indices = x_list + i_list.view(C, 1, 1)*S + y_list.view(1, H, 1)*(out_W)
indices = indices.view(1, C*H*W).expand(B, C*H*W)
"""
functionally equivalent to:
for i in range(C):
    for y in range(H):
        for x in range(W):
            indices[i*H*W+y*W+x] = x + i*S + y*(out_W)
"""
def A_faster(x, mask):
    y = torch.zeros(B, H*out_W, device=x.device)
    y.scatter_add_(1, indices, (x*mask).view(B, C*H*W))
    return y.view(B, 1, H, out_W)
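To see that the broadcast expression builds the same indices as the triple loop in the docstring, and that A_faster matches the shift-and-add loop, both can be checked on small CPU tensors. Sizes are illustrative and the mask is omitted:

```python
import torch

# Small illustrative sizes, CPU only
B, C, H, W, S = 2, 3, 4, 5, 2
out_W = W + (C - 1) * S

# Broadcast construction from the answer, shape (C, H, W)
i_list = torch.arange(C)
y_list = torch.arange(H)
x_list = torch.arange(W)
indices = x_list + i_list.view(C, 1, 1) * S + y_list.view(1, H, 1) * out_W

# Triple loop from the docstring
ref = torch.empty(C * H * W, dtype=torch.long)
for i in range(C):
    for y in range(H):
        for x in range(W):
            ref[i * H * W + y * W + x] = x + i * S + y * out_W
print(torch.equal(indices.view(-1), ref))  # True

# Scatter-based sum vs. the shift-and-add loop
data = torch.rand(B, C, H, W, dtype=torch.float64)
y_fast = torch.zeros(B, H * out_W, dtype=data.dtype)
y_fast.scatter_add_(1, indices.view(1, -1).expand(B, -1), data.view(B, -1))
y_fast = y_fast.view(B, 1, H, out_W)

y_loop = torch.zeros(B, 1, H, out_W, dtype=data.dtype)
for i in range(C):
    y_loop[..., i * S:i * S + W] += data[:, i:i + 1]
print(torch.allclose(y_fast, y_loop))  # True
```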
Surprisingly, your Method 1 holds up well even for larger C (or scatter does not scale well).
For C=28:
---
Test 1
Running Time: 1.4626126289367676
---
Test 2
Running Time: 2.808514356613159
---
Test 3
Running Time: 1.3663663864135742
---
|Test1 - Test2|: tensor(9.2172e-07, device='cuda:0')
---
|Test1 - Test3|: tensor(7.5425e-09, device='cuda:0')
---
|Test2 - Test3|: tensor(9.2173e-07, device='cuda:0')
For C=512 (Method 2 was skipped because it is too slow):
---
Test 1
Running Time: 27.37247085571289
---
Test 3
Running Time: 24.335933446884155
---
|Test1 - Test3|: tensor(3.9411e-08, device='cuda:0')
Full test code:
import torch
import torch.nn.functional as F
from time import time
#############################################
# Parameters
#############################################
B = 16
C = 28
H = 256
W = 256
S = 2
T = 1000
device = torch.device('cuda')
seed = 2023
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
#############################################
# Method 1
#############################################
alpha = torch.zeros(B, 1, 1, W+(C-1)*S, device=device)
for i in range(C):
    alpha[..., (i*S):(i*S+W)] += 1
def A(x, mask):
    z = x * mask
    y = torch.zeros(B, 1, H, W+(C-1)*S, device=x.device)
    for i in range(C):
        y[..., (i*S):(i*S+W)] += z[:, (i):(i+1)]
    return y
def A_pinv(y, mask):
    z = y / alpha.to(y.device)
    x = torch.cat([z[..., (i*S):(i*S+W)] for i in range(C)], dim=1) / mask
    return x
#############################################
# Method 2
#############################################
kernel = torch.zeros(1, C, 1, (C-1)*S+1, device=device)
for i in range(C):
    kernel[:, C-i-1, :, i*S] = 1
def A_fast(x, mask):
    return F.conv2d(x * mask, kernel.to(x.device), padding=(0, (C-1)*S))
def A_pinv_fast(y, mask):
    return F.conv_transpose2d(y / alpha.to(y.device), kernel, padding=(0, (C-1)*S)) / mask
#############################################
# Method 3
#############################################
out_W = W + (C-1)*S
i_list = torch.arange(C, dtype=torch.long, device=device)
y_list = torch.arange(H, dtype=torch.long, device=device)
x_list = torch.arange(W, dtype=torch.long, device=device)
indices = x_list + i_list.view(C, 1, 1)*S + y_list.view(1, H, 1)*(out_W)
indices = indices.view(1, C*H*W).expand(B, C*H*W)
"""
functionally equivalent to:
for i in range(C):
    for y in range(H):
        for x in range(W):
            indices[i*H*W+y*W+x] = x + i*S + y*(out_W)
"""
def A_faster(x, mask):
    y = torch.zeros(B, H*out_W, device=x.device)
    y.scatter_add_(1, indices, (x*mask).view(B, C*H*W))
    return y.view(B, 1, H, out_W)
#############################################
# Test 1
#############################################
torch.cuda.synchronize()
start_time = time()
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A(x, mask)
torch.cuda.synchronize()
end_time = time()
print('---')
print('Test 1')
print('Running Time:', end_time - start_time)
#############################################
# Test 2
#############################################
torch.cuda.synchronize()
start_time = time()
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A_fast(x, mask)
torch.cuda.synchronize()
end_time = time()
print('---')
print('Test 2')
print('Running Time:', end_time - start_time)
#############################################
# Test 3
#############################################
torch.cuda.synchronize()
start_time = time()
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A_faster(x, mask)
torch.cuda.synchronize()
end_time = time()
print('---')
print('Test 3')
print('Running Time:', end_time - start_time)
error = 0
for _ in range(T):
    error += (A(x, mask) - A_fast(x, mask)).abs().mean()
error /= T
print('---')
print('|Test1 - Test2|: ', error)
error = 0
for _ in range(T):
    error += (A(x, mask) - A_faster(x, mask)).abs().mean()
error /= T
print('---')
print('|Test1 - Test3|: ', error)
error = 0
for _ in range(T):
    error += (A_fast(x, mask) - A_faster(x, mask)).abs().mean()
error /= T
print('---')
print('|Test2 - Test3|: ', error)
For more on python - Parallelize and speed up a tensor addition loop, see the similar question on Stack Overflow: https://stackoverflow.com/questions/76963279/