machine-learning - 调用向后时 nn.CDivTable 抛出错误是否有正当理由?

标签 machine-learning lua neural-network torch

我最近开始使用 Torch 框架和 Lua 脚本语言来研究神经网络。我已经掌握了线性网络的基础知识,所以我尝试了一些更复杂但足够简单的东西:

这个想法是,我有 3 个输入,我必须选择前两个,将它们相除,然后将结果转发到线性模块。所以,我制作了这个小脚本:

require "nn";
require "optim";

local N = 3;

local input = torch.Tensor{
    {1, 2, 3},
    {9, 20, 20},
    {9, 300, 1},
};

local output = torch.Tensor(N);
for i=1, N do
    output[i] = 1;
end

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());

local criterion = nn.BCECriterion();
local params, gradParams = ratioPerceptron:getParameters();
local optimState = {learningRate = 0.01};

local maxIteration = 100000;
for i=1, maxIteration do
    local function f(params)
        gradParams:zero();

        local outputs = ratioPerceptron:forward(input);
        local loss = criterion:forward(outputs, output);
        local dloss_doutputs = criterion:backward(outputs, output);
        ratioPerceptron:backward(input, dloss_doutputs);

        return loss, gradParams;
    end

    optim.sgd(f, params, optimState);
end

当在训练期间调用向后并出现错误时,此操作会失败:

CDivTable.lua:21: both torch.LongStorage and (null) have no addition operator

但是,如果我从顺序模块中删除 CDivTable,并将 nn.Reshape 和 nn.Linear 更改为二维输入(因为我们删除了 CDivTable,它划分两维输入以产生一维输出),如下所示:

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.Reshape(N, 2));
ratioPerceptron:add(nn.Linear(2, 1));
ratioPerceptron:add(nn.Sigmoid());

训练完成,没有错误...是否有其他方法来划分两个选定的输入并将结果转发到线性模块?

最佳答案

模块CDivTable将一个表作为输入,并将第一个表的元素除以第二个表的元素。在这里,您可以将单个输入作为网络的输入,而不是两个输入的表。我相信这就是为什么你会出现 null 错误的原因。 Torch 无法理解您的输入(由两个向量组成)应被视为两个向量的表。它只能看到大小为 2x3 的张量!因此,您必须告诉 Torch 根据输入创建一个表。因此,您可以使用模块 SplitTable(dim) 将输入沿维度 dim 拆分为表。

在窄模块后面插入此行 ratioPerceptron:add(nn.SplitTable(1)):

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.SplitTable(1))
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());

此外,当您遇到此类错误时,我建议您通过放置 print 语句来查看网络计算的内容:插入一行 print(ratioPerceptron:forward(input)) 在添加会产生错误的模块的行之前。

关于machine-learning - 调用向后时 nn.CDivTable 抛出错误是否有正当理由?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43680417/

相关文章:

machine-learning - K-NN 算法如何在 Rapidminer 中以相同距离工作?

python - Scikit-Learn 在 RFECV() 中手动指定 .max_features - 有多少特征得到排名?

lua - 获取表中的最大值

machine-learning - 神经网络中的激活函数

python - 无效参数错误 : input_1_1:0 is both fed and fetched

python - 如何使用 scikit 的 Surprise 进行预测?

python - sklearn.mixture.DPGMM : Unexpected results

parsing - 如何处理 lua 中的未知初始化函数?

c++ - 了解静态链接嵌入式lua环境中lua扩展dll的构建/加载

python - Keras:如何在自定义损失中获取张量尺寸?