tensorflow - tensorflow 中的 stop_gradient

标签 tensorflow tensorflow-gradient

我想知道 tf.stop_gradient 是否停止给定操作的梯度计算,或者停止其输入 tf.variable 的更新?我有以下问题 - 在 MNIST 中的前向路径计算期间,我想对权重执行一组操作(假设 W 到 W*),然后对输入进行 matmul。但是,我想从后向路径中排除这些操作。我只想在反向传播训练期间计算 dE/dW。我编写的代码阻止 W 更新。你能帮我理解为什么吗?如果这些是变量,我知道我应该将它们的可训练属性设置为 false,但这些是对权重的操作。如果 stop_gradient 不能用于此目的,那么如何构建两个图,一个用于前向路径,另一个用于反向传播?

def build_layer(inputs, fmap, nscope,layer_size1,layer_size2, faulty_training):  
  with tf.name_scope(nscope): 
    if (faulty_training):
      ## trainable weight
      weights_i = tf.Variable(tf.truncated_normal([layer_size1, layer_size2],stddev=1.0 / math.sqrt(float(layer_size1))),name='weights_i')
      ## Operations on weight whose gradient should not be computed during backpropagation
      weights_fx_t = tf.multiply(268435456.0,weights_i)
      weight_fx_t = tf.stop_gradient(weights_fx_t)
      weights_fx = tf.cast(weights_fx_t,tf.int32)
      weight_fx = tf.stop_gradient(weights_fx)
      weights_fx_fault = tf.bitwise.bitwise_xor(weights_fx,fmap)
      weight_fx_fault = tf.stop_gradient(weights_fx_fault)
      weights_fl = tf.cast(weights_fx_fault, tf.float32)
      weight_fl = tf.stop_gradient(weights_fl)
      weights = tf.stop_gradient(tf.multiply((1.0/268435456.0),weights_fl))
      ##### end transformation
    else:
      weights = tf.Variable(tf.truncated_normal([layer_size1, layer_size2],stddev=1.0 / math.sqrt(float(layer_size1))),name='weights')


    biases = tf.Variable(tf.zeros([layer_size2]), name='biases')
    hidden = tf.nn.relu(tf.matmul(inputs, weights) + biases)
    return weights,hidden

我正在使用 tensorflow 梯度下降优化器进行训练。

optimizer = tf.train.GradientDescentOptimizer(learning_rate) 
global_step = tf.Variable(0, name='global_step', trainable=False) 
train_op = optimizer.minimize(loss, global_step=global_step)

最佳答案

停止梯度将阻止反向传播继续经过图中的该节点。除了经过梯度停止处的weights_fx_t 的路径之外,您的代码没有任何从weights_i 到损失的路径。这就是导致weights_i 在训练期间不更新的原因。您不需要在每个步骤之后放置 stop_gradient 。仅使用一次就会停止那里的反向传播。

如果 stop_gradient 没有执行您想要的操作,那么您可以通过执行 tf.gradients 来获取渐变,并且您可以使用 编写自己的更新操作tf.分配。这将允许您根据需要更改渐变。

关于tensorflow - tensorflow 中的 stop_gradient,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50221783/

相关文章:

tensorflow - 为什么我必须在 tensorflow 中为线性回归打乱输入数据

python - 使用 tensorflow ,如何平均多个批处理的参数梯度值并使用该平均值进行更新?

python - 我可以在不应用输入的情况下获得张量相对于输入的梯度吗?

python - 离线安装时找不到tensorboard

python - 导入错误 : cannot import name 'transpose_shape'

tensorflow - 权重矩阵提供什么信息?

python - 使用 Tensorflow Eager Execution 的 OpenAI 梯度检查点

python - 如何用张量板监控keras中的梯度消失和爆炸?

tensorflow - 像在 Pytorch 中一样在 Tensorflow 中屏蔽零填充嵌入(并返回零梯度)