python - Gradient clipping appears to choke on None

Tags: python machine-learning tensorflow

I am trying to add gradient clipping to my graph. I used the approach recommended here: How to effectively apply gradient clipping in tensor flow?

    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    if gradient_clipping:
        gradients = optimizer.compute_gradients(loss)
        clipped_gradients = [(tf.clip_by_value(grad, -1, 1), var) for grad, var in gradients]
        opt = optimizer.apply_gradients(clipped_gradients, global_step=global_step)
    else:
        opt = optimizer.minimize(loss, global_step=global_step)

But when I turn on gradient clipping, I get the following stack trace:

<ipython-input-19-be0dcc63725e> in <listcomp>(.0)
     61         if gradient_clipping:
     62             gradients = optimizer.compute_gradients(loss)
---> 63             clipped_gradients = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gradients]
     64             opt = optimizer.apply_gradients(clipped_gradients, global_step=global_step)
     65         else:

/home/armence/mlsandbox/venv/lib/python3.4/site-packages/tensorflow/python/ops/clip_ops.py in clip_by_value(t, clip_value_min, clip_value_max, name)
     51   with ops.op_scope([t, clip_value_min, clip_value_max], name,
     52                    "clip_by_value") as name:
---> 53     t = ops.convert_to_tensor(t, name="t")
     54 
     55     # Go through list of tensors, for each value in each tensor clip

/home/armence/mlsandbox/venv/lib/python3.4/site-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, as_ref)
    619     for base_type, conversion_func in funcs_at_priority:
    620       if isinstance(value, base_type):
--> 621         ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
    622         if ret is NotImplemented:
    623           continue

/home/armence/mlsandbox/venv/lib/python3.4/site-packages/tensorflow/python/framework/constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref)
    178                                          as_ref=False):
    179   _ = as_ref
--> 180   return constant(v, dtype=dtype, name=name)
    181 
    182 

/home/armence/mlsandbox/venv/lib/python3.4/site-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name)
    161   tensor_value = attr_value_pb2.AttrValue()
    162   tensor_value.tensor.CopyFrom(
--> 163       tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))
    164   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    165   const_tensor = g.create_op(

/home/armence/mlsandbox/venv/lib/python3.4/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape)
    344   else:
    345     if values is None:
--> 346       raise ValueError("None values not supported.")
    347     # if dtype is provided, forces numpy array to be the type
    348     # provided if possible.

ValueError: None values not supported.

How can I fix this?

Best answer

One option that seems to work is:

    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    if gradient_clipping:
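        # compute_gradients returns a list of (gradient, variable) pairs;
        # a gradient may be None.
        gradients = optimizer.compute_gradients(loss)

        def ClipIfNotNone(grad):
            # Pass None through unchanged; tf.clip_by_value cannot convert it.
            if grad is None:
                return grad
            return tf.clip_by_value(grad, -1, 1)
        clipped_gradients = [(ClipIfNotNone(grad), var) for grad, var in gradients]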
        opt = optimizer.apply_gradients(clipped_gradients, global_step=global_step)
    else:
        opt = optimizer.minimize(loss, global_step=global_step)

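It looks like compute_gradients returns None rather than a zero tensor for a variable whose gradient would be zero (for example, a variable the loss does not depend on), and tf.clip_by_value does not accept None. So don't pass None to it, and keep the None entries as they are.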
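The None-preserving pattern can be sketched without TensorFlow, which makes the behavior easy to check in isolation. This is only an illustration under stated assumptions: `clip_grads_and_vars` and `clip_scalar` are hypothetical helper names, and plain floats and strings stand in for gradient tensors and variables.

```python
def clip_grads_and_vars(grads_and_vars, clip_fn):
    """Apply clip_fn to every gradient that is not None; pass None through."""
    return [
        (clip_fn(grad) if grad is not None else None, var)
        for grad, var in grads_and_vars
    ]


def clip_scalar(value, low=-1.0, high=1.0):
    # Plain-Python stand-in for tf.clip_by_value, for illustration only.
    return max(low, min(high, value))


# Hypothetical (gradient, variable name) pairs; "unused_bias" has no gradient.
pairs = [(2.5, "w"), (None, "unused_bias"), (-0.25, "b")]
clipped = clip_grads_and_vars(pairs, clip_scalar)
print(clipped)  # [(1.0, 'w'), (None, 'unused_bias'), (-0.25, 'b')]
```

In the graph code above, the same helper would presumably be called with `lambda g: tf.clip_by_value(g, -1., 1.)` as `clip_fn` and the output of `optimizer.compute_gradients(loss)` as the pairs.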

Regarding "python - Gradient clipping appears to choke on None", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/39295136/
