deep-learning - 在训练过程中如何修改人工神经元之间的连接?

标签 deep-learning neural-network artificial-intelligence

背景

在特斯拉自治日,Andrej Karpathy 发表了演讲。在这次演讲中,他说,在开始时,人工神经网络中人工神经元之间的连接是随机初始化的。因此,网络进行随机预测。

后来他说信息通过网络向前流动,然后通过反向传播算法,信息通过网络向后流动。当信息向后流动时,神经元之间的连接被修改。反过来,网络做出越来越好的预测。

此外,我还从网上的一些讲座中了解到,神经元之间的连接(权重)存储在矩阵中。

问题

所以我的问题是连接这两个事实的数学公式或方程式是什么?

也就是说,这些矩阵是如何通过反向传播算法修改的?

最佳答案

这是个好问题。

假设您有一个带有单个隐藏层的简单神经网络

a1 = g(x W1)
a2 = g(a1 W2)
y = g(a2)

在哪里a1是第一层的激活或输出,a2是隐藏层的激活,y是输出。 W1,W2是权重矩阵。 g()是一些非线性“激活函数”,用于使神经网络编码更复杂的关系。另请注意,为了引入偏差元素(线性回归中的 b),按照惯例,我们将 a 的第一个元素设为等于1 : ak[0]=1 .

现在计算 y就是我们所说的“前向流”或“前向传播”。请注意,为了计算 y你只需要输入数据 x .这也称为推理。

但是,为了让网络做出有效的预测,我们必须更新权重矩阵。有多种不同的网络更新方式,但梯度下降是一种常用的方法。

为了使用梯度下降,我们需要某种衡量网络表现如何的指标,或者更准确地说,它的表现有多差。我们定义了一个称为误差函数的函数,它取决于网络的输出以及真实答案。

def error(y, y'):
   # approximately MSE. 
   return sum((y-y').^2)

现在为了更新权重,我们想计算出每个权重对错误分数的贡献有多大,并以一种使贡献更小的方式更新它。这就是梯度下降的用武之地

W -= alpha * diff(error(y,y'), W)

在哪里alpha是一些小的正实数值,diff(f,x)f 的导数关于 x ,或者在这种情况下,相对于权重矩阵的误差。

但是我们如何计算导数呢?这就是反向传播的用武之地。Tensorflow 和 Pytorch 等软件包不是分析地寻找导数,而是使用一种称为自动微分的东西来自动化该过程。

这个想法是递归地应用链式法则,直到您将复杂的误差函数分解为您已手动计算其导数的原子函数之和。由于这是从最后一层开始并通过网络向后流动,所以它被称为反向传播,尽管我认为它有点误导。

回想一下链式规则是

(f(g(x)))' = f'(g(x)) g'(x)

一旦你破坏了复合函数 f(g(x))向下,如果您不知道如何计算 f 的导数,您可以再次应用链式法则或 g .


总而言之,你问了

  1. 所以我的问题是连接这两个事实的数学公式或方程式是什么?

信息在推理过程中“向前流动”(前向传播),因为您从第一层的输入开始,到最后一层的输出结束。

当从最后一层开始分解导数时,信息“向后流动”,通过网络向后移动到第一层。

  1. 也就是说,这些矩阵是如何通过反向传播算法修改的?

使用称为梯度下降的算法修改矩阵,该算法尝试迭代地使权重矩阵相对于某些误差函数更好一些。为了做到这一点,它需要误差函数的梯度,为此我们使用自动微分,我们递归地分解从最后一层开始向后移动到第一层的导数,因此称为反向传播。

关于deep-learning - 在训练过程中如何修改人工神经元之间的连接?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59316632/

相关文章:

python - 如何在 Keras 中仅获取序列模型的最后一个输出?

python-3.x - 无效参数错误 : logits and labels must have the same first dimension seq2seq Tensorflow

java - Deeplearning4j 神经网络配置

python - 你能帮我在 pyBrain 中线性激活我的简单分类器神经网络吗?

amazon-web-services - 尽管输入了准确的话语,AWS Lex 仍匹配错误的意图

python - MultiheadAttention 中的 attn_output_weights

machine-learning - 准确率有所提高,但在许多时期内保持不变

tensorflow - 带有 train_on_batch 的优化器?

python - 为什么我的 Deep Q Network 没有掌握一个简单的 Gridworld (Tensorflow)? (如何评估 Deep-Q-Net)

python - tensorflow 2 :NotImplementedError: numpy() is only available when eager execution is enabled