这个问题的主题是将神经网络的张量与 Lua 的 torch/nn 和 torch/nngraph 库连接起来。几周前我开始使用 Lua 编码,所以我的经验非常少。在下面的文本中,我将 lua 表称为数组。
上下文
我正在使用循环神经网络进行语音识别。
在网络中的某个时刻,有 N
个 m
个张量数组。
a = {a1, a2, ..., aM},
b = {b1, b2, ..., bM},
... N times
其中 ai
和 bi
是张量,{}
表示数组。
需要做的是按元素连接所有这些数组,以便 output
是一个 M
张量数组,其中 output[i]
是将第二个维度上的 N 个数组中的每个第 i 个张量连接起来的结果。
output = {z1, z2, ..., zM}
示例
||
用于表示张量
x = {|1 1|, |2 2|}
|1 1| |2 2|
Tensors of size 2x2
y = {|3 3 3|, |4 4 4|}
|3 3 3| |4 4 4|
Tensors of size 2x3
|
| Join{x,y}
\/
z = {|1 1 3 3 3|, |2 2 4 4 4|}
|1 1 3 3 3| |2 2 4 4 4|
Tensors of size 2x5
因此,大小为 2x2 的 x
的第一个张量与大小为 2x3 的 y
的第一个张量在第二个维度上连接,每个数组的第二个张量也是如此产生 z
2x5 张量数组。
问题
现在这是一个基本的串联,但我似乎无法在 torch/nn 库中找到允许我这样做的模块。我当然可以编写自己的模块,但如果已经存在的模块可以做到这一点,那么我宁愿这样做。
我知道连接表的唯一现有模块是(显然)JoinTable。它需要一系列张量并将它们连接在一起。我想按元素连接张量数组。
此外,当我们向网络提供输入时,N
数组中的张量数量会发生变化,因此上述上下文中的 m
并不是恒定的。
想法
我认为为了使用 JoinTable 模块,我可以将数组转换为张量,然后在转换后的 N
张量上使用 JoinTable
。但话又说回来,我需要一个模块来执行此类转换,并需要另一个模块来转换回数组,以便将其提供给网络的下一层。
最后的手段
编写一个新模块,迭代所有给定的数组并按元素连接。当然这是可行的,但这篇文章的全部目的是找到一种方法来避免编写臭模块。我觉得很奇怪这样的模块还不存在。
结论
我最终决定按照我在最后的手段中写的那样去做。我编写了一个新模块,它迭代所有给定的数组并按元素连接。
不过,@fmguler 给出的答案是相同的,而无需编写新模块。
最佳答案
你可以像这样使用 nn.SelectTable 和 nn.JoinTable 来做到这一点;
require 'nn'
x = {torch.Tensor{{1,1},{1,1}}, torch.Tensor{{2,2},{2,2}}}
y = {torch.Tensor{{3,3,3},{3,3,3}}, torch.Tensor{{4,4,4},{4,4,4}}}
res = {}
res[1] = nn.JoinTable(2):forward({nn.SelectTable(1):forward(x),nn.SelectTable(1):forward(y)})
res[2] = nn.JoinTable(2):forward({nn.SelectTable(2):forward(x),nn.SelectTable(2):forward(y)})
print(res[1])
print(res[2])
如果您希望在模块中完成此操作,请将其包装在 nnGraph 中;
require 'nngraph'
x = {torch.Tensor{{1,1},{1,1}}, torch.Tensor{{2,2},{2,2}}}
y = {torch.Tensor{{3,3,3},{3,3,3}}, torch.Tensor{{4,4,4},{4,4,4}}}
xi = nn.Identity()()
yi = nn.Identity()()
res = {}
--you can loop over columns here>>
res[1] = nn.JoinTable(2)({nn.SelectTable(1)(xi),nn.SelectTable(1)(yi)})
res[2] = nn.JoinTable(2)({nn.SelectTable(2)(xi),nn.SelectTable(2)(yi)})
module = nn.gModule({xi,yi},res)
--test like this
result = module:forward({x,y})
print(result)
print(result[1])
print(result[2])
--gives the result
th> print(result)
{
1 : DoubleTensor - size: 2x5
2 : DoubleTensor - size: 2x5
}
th> print(result[1])
1 1 3 3 3
1 1 3 3 3
[torch.DoubleTensor of size 2x5]
th> print(result[2])
2 2 4 4 4
2 2 4 4 4
[torch.DoubleTensor of size 2x5]
关于lua - torch/nn - 按元素连接张量数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37747810/