lua - torch/nn - 按元素连接张量数组

标签 lua join torch

这个问题的主题是将神经网络的张量与 Lua 的 torch/nn 和 torch/nngraph 库连接起来。几周前我开始使用 Lua 编码,所以我的经验非常少。在下面的文本中,我将 lua 表称为数组。

上下文

我正在使用循环神经网络进行语音识别。 在网络中的某个时刻,有 Nm 个张量数组。

a = {a1, a2, ..., aM},
b = {b1, b2, ..., bM}, 
... N times

其中 aibi 是张量,{} 表示数组。

需要做的是按元素连接所有这些数组,以便 output 是一个 M 张量数组,其中 output[i] 是将第二个维度上的 N 个数组中的每个第 i 个张量连接起来的结果。

output = {z1, z2, ..., zM}

示例

|| 用于表示张量

x = {|1 1|, |2 2|}
     |1 1|  |2 2|
     Tensors of size 2x2

y = {|3 3 3|, |4 4 4|}
     |3 3 3|  |4 4 4|
     Tensors of size 2x3
        |
        | Join{x,y}
        \/
z = {|1 1 3 3 3|, |2 2 4 4 4|}
     |1 1 3 3 3|  |2 2 4 4 4|
     Tensors of size 2x5

因此,大小为 2x2 的 x 的第一个张量与大小为 2x3 的 y 的第一个张量在第二个维度上连接,每个数组的第二个张量也是如此产生 z 2x5 张量数组。

问题

现在这是一个基本的串联,但我似乎无法在 torch/nn 库中找到允许我这样做的模块。我当然可以编写自己的模块,但如果已经存在的模块可以做到这一点,那么我宁愿这样做。

我知道连接表的唯一现有模块是(显然)JoinTable。它需要一系列张量并将它们连接在一起。我想按元素连接张量数组。

此外,当我们向网络提供输入时,N 数组中的张量数量会发生变化,因此上述上下文中的 m 并不是恒定的。

想法

我认为为了使用 JoinTable 模块,我可以将数组转换为张量,然后在转换后的 N 张量上使用 JoinTable 。但话又说回来,我需要一个模块来执行此类转换,并需要另一个模块来转换回数组,以便将其提供给网络的下一层。

最后的手段

编写一个新模块,迭代所有给定的数组并按元素连接。当然这是可行的,但这篇文章的全部目的是找到一种方法来避免编写臭模块。我觉得很奇怪这样的模块还不存在。


结论

我最终决定按照我在最后的手段中写的那样去做。我编写了一个新模块,它迭代所有给定的数组并按元素连接。

不过,@fmguler 给出的答案是相同的,而无需编写新模块。

最佳答案

你可以像这样使用 nn.SelectTable 和 nn.JoinTable 来做到这一点;

require 'nn'

x = {torch.Tensor{{1,1},{1,1}}, torch.Tensor{{2,2},{2,2}}}
y = {torch.Tensor{{3,3,3},{3,3,3}}, torch.Tensor{{4,4,4},{4,4,4}}}

res = {}
res[1] = nn.JoinTable(2):forward({nn.SelectTable(1):forward(x),nn.SelectTable(1):forward(y)})
res[2] = nn.JoinTable(2):forward({nn.SelectTable(2):forward(x),nn.SelectTable(2):forward(y)})

print(res[1])
print(res[2])

如果您希望在模块中完成此操作,请将其包装在 nnGraph 中;

require 'nngraph'

x = {torch.Tensor{{1,1},{1,1}}, torch.Tensor{{2,2},{2,2}}}
y = {torch.Tensor{{3,3,3},{3,3,3}}, torch.Tensor{{4,4,4},{4,4,4}}}

xi = nn.Identity()()
yi = nn.Identity()()
res = {}
--you can loop over columns here>>
res[1] = nn.JoinTable(2)({nn.SelectTable(1)(xi),nn.SelectTable(1)(yi)})
res[2] = nn.JoinTable(2)({nn.SelectTable(2)(xi),nn.SelectTable(2)(yi)})
module = nn.gModule({xi,yi},res)

--test like this
result = module:forward({x,y})
print(result)
print(result[1])
print(result[2])

--gives the result
th> print(result)
{
  1 : DoubleTensor - size: 2x5
  2 : DoubleTensor - size: 2x5
}

th> print(result[1])
 1  1  3  3  3
 1  1  3  3  3
[torch.DoubleTensor of size 2x5]

th> print(result[2])
 2  2  4  4  4
 2  2  4  4  4
[torch.DoubleTensor of size 2x5]

关于lua - torch/nn - 按元素连接张量数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37747810/

相关文章:

android - 如何忽略 corona sdk 文本换行中的 tashkeel 字符计数?

lua - 以交互模式打开 Lua 文件

MySQL根据另一列的值联接一列中显示 '0'的两个表

MySQL - 尽管未在连接中引用,但未找到列

pytorch - Torch.cuda.is_available() 不断切换到 False

python - 如何将 PyTorch 张量的每一行中的重复值清零?

lua - 如何使用lua设置bash环境变量

lua - 如何检查一个值是否大于一个数字并小于Lua中的另一个数字?

php - 使用 Eloquent 在 mysql 中左连接后分别获取相同的命名列

torch - 在windows系统中安装torch 1.0.1.post2