python - Optimizer.apply_gradients 不会更新 TF 2.0 中的权重

标签 python tensorflow tensorflow2.0

我在 TensorFlow 2.0 中有一个 UNet 实现。该代码运行良好,没有任何错误,但损失值(current_loss)保持不变。

首先,是包含模型的 UNet 类:

class UNet(object):
    """U-Net style encoder/decoder assembled directly from ``tf.nn`` ops.

    The network owns a flat list of convolution kernels (``self.weights``)
    and applies them in a fixed order inside ``__call__``: five encoder
    stages (two convolutions each, max-pooling between stages), four decoder
    stages (upsample, concatenate the matching encoder output, two transposed
    convolutions) and a final 1x1 convolution followed by a sigmoid.

    The decoder passes literal ``output_shape`` values, so as written the
    model only accepts inputs of shape ``[1, 128, 128, 3]``.
    """

    def __init__(self):

        def get_weight(shape, name):
            """Create a trainable float32 variable sampled from ``tf.random.normal``."""
            return tf.Variable(tf.random.normal(shape), name=name, trainable=True, dtype=tf.float32)

        # Kernel shapes, in the order they are consumed by ``__call__``.
        # Encoder kernels (``tf.nn.conv2d``) are [h, w, in_channels, out_channels];
        # decoder kernels (``tf.nn.conv2d_transpose``) are
        # [h, w, out_channels, in_channels].
        shapes = [
            # Encoder stage 1
            [3, 3, 3, 16],
            [3, 3, 16, 16],

            # Encoder stage 2
            [3, 3, 16, 32],
            [3, 3, 32, 32],

            # Encoder stage 3
            [3, 3, 32, 64],
            [3, 3, 64, 64],

            # Encoder stage 4
            [3, 3, 64, 128],
            [3, 3, 128, 128],

            # Bottleneck
            [3, 3, 128, 256],
            [3, 3, 256, 256],

            # Decoder stage 1: input is 256 (upsampled) + 128 (skip) = 384 channels
            [3, 3, 128, 384],
            [3, 3, 128, 128],

            # Decoder stage 2: input is 128 + 64 = 192 channels
            [3, 3, 64, 192],
            [3, 3, 64, 64],

            # Decoder stage 3: input is 64 + 32 = 96 channels
            [3, 3, 32, 96],
            [3, 3, 32, 32],

            # Decoder stage 4: input is 32 + 16 = 48 channels
            [3, 3, 16, 48],
            [3, 3, 16, 16],

            # Final 1x1 convolution down to a single output channel
            [1, 1, 16, 1],
        ]

        # One variable per kernel, named weight0 .. weight18.
        self.weights = [
            get_weight(shape, 'weight{}'.format(index))
            for index, shape in enumerate(shapes)
        ]
        # Padding mode shared by every (transposed) convolution in the model.
        self.padding = 'SAME'

    def conv2d_down(self, inputs, filters, stride_size):
        """Apply a 2-D convolution followed by leaky ReLU (alpha=0.2).

        Args:
            inputs: Input feature map.
            filters: Convolution kernel passed to ``tf.nn.conv2d``.
            stride_size: Stride used for both spatial dimensions.

        Returns:
            The activated convolution output.
        """
        out = tf.nn.conv2d(inputs, filters, strides=[1, stride_size, stride_size, 1], padding=self.padding)
        return tf.nn.leaky_relu(out, alpha=0.2)

    def maxpool_down(self, inputs, pool_size, stride_size):
        """Downsample ``inputs`` with 'VALID' max-pooling.

        Args:
            inputs: Input feature map.
            pool_size: Side length of the square pooling window.
            stride_size: Stride used for both spatial dimensions.

        Returns:
            The pooled feature map.
        """
        return tf.nn.max_pool2d(inputs, ksize=[1, pool_size, pool_size, 1], padding='VALID',
                                strides=[1, stride_size, stride_size, 1])

    def conv2d_up(self, inputs, filters, stride_size, output_shape):
        """Apply a transposed 2-D convolution followed by leaky ReLU (alpha=0.2).

        Args:
            inputs: Input feature map.
            filters: Kernel passed to ``tf.nn.conv2d_transpose``.
            stride_size: Stride used for both spatial dimensions.
            output_shape: Full shape of the tensor to produce.

        Returns:
            The activated transposed-convolution output.
        """
        out = tf.nn.conv2d_transpose(inputs, filters, output_shape=output_shape,
                                     strides=[1, stride_size, stride_size, 1], padding=self.padding)
        return tf.nn.leaky_relu(out, alpha=0.2)

    def maxpool_up(self, inputs, size):
        """Upsample ``inputs`` by ``size`` using nearest-neighbour resizing.

        Only dimension 1 of ``inputs`` is read and the result is used for both
        output dimensions, so the feature map is treated as square.

        Args:
            inputs: Input feature map.
            size: Scale factor applied to the spatial size.

        Returns:
            The resized feature map.
        """
        in_dimen = tf.shape(inputs)[1]
        out_dimen = tf.cast(tf.round(in_dimen * size), dtype=tf.int32)
        return tf.image.resize(inputs, [out_dimen, out_dimen], method='nearest')

    def __call__(self, x):
        """Run the forward pass.

        Args:
            x: Input image batch; the literal decoder output shapes below
                require shape ``[1, 128, 128, 3]``.

        Returns:
            Sigmoid activations with a single output channel.
        """
        # ---- Encoder ----
        c1 = self.conv2d_down(x, self.weights[0], stride_size=1)
        c1 = self.conv2d_down(c1, self.weights[1], stride_size=1)
        p1 = self.maxpool_down(c1, pool_size=2, stride_size=2)

        c2 = self.conv2d_down(p1, self.weights[2], stride_size=1)
        c2 = self.conv2d_down(c2, self.weights[3], stride_size=1)
        p2 = self.maxpool_down(c2, pool_size=2, stride_size=2)

        c3 = self.conv2d_down(p2, self.weights[4], stride_size=1)
        c3 = self.conv2d_down(c3, self.weights[5], stride_size=1)
        p3 = self.maxpool_down(c3, pool_size=2, stride_size=2)

        c4 = self.conv2d_down(p3, self.weights[6], stride_size=1)
        c4 = self.conv2d_down(c4, self.weights[7], stride_size=1)
        p4 = self.maxpool_down(c4, pool_size=2, stride_size=2)

        # ---- Bottleneck ----
        c5 = self.conv2d_down(p4, self.weights[8], stride_size=1)
        c5 = self.conv2d_down(c5, self.weights[9], stride_size=1)

        # ---- Decoder: upsample, concatenate the skip connection, convolve ----
        p5 = self.maxpool_up(c5, 2)
        concat_1 = tf.concat([p5, c4], axis=-1)
        c6 = self.conv2d_up(concat_1, self.weights[10], stride_size=1, output_shape=[1, 16, 16, 128])
        c6 = self.conv2d_up(c6, self.weights[11], stride_size=1, output_shape=[1, 16, 16, 128])

        p6 = self.maxpool_up(c6, 2)
        concat_2 = tf.concat([p6, c3], axis=-1)
        c7 = self.conv2d_up(concat_2, self.weights[12], stride_size=1, output_shape=[1, 32, 32, 64])
        c7 = self.conv2d_up(c7, self.weights[13], stride_size=1, output_shape=[1, 32, 32, 64])

        p7 = self.maxpool_up(c7, 2)
        concat_3 = tf.concat([p7, c2], axis=-1)
        c8 = self.conv2d_up(concat_3, self.weights[14], stride_size=1, output_shape=[1, 64, 64, 32])
        c8 = self.conv2d_up(c8, self.weights[15], stride_size=1, output_shape=[1, 64, 64, 32])

        p8 = self.maxpool_up(c8, 2)
        concat_4 = tf.concat([p8, c1], axis=-1)
        c9 = self.conv2d_up(concat_4, self.weights[16], stride_size=1, output_shape=[1, 128, 128, 16])
        c9 = self.conv2d_up(c9, self.weights[17], stride_size=1, output_shape=[1, 128, 128, 16])

        # Use the instance's padding mode; a bare ``padding`` name is not
        # defined in this scope and would raise NameError.
        output = tf.nn.conv2d(c9, self.weights[18], strides=[1, 1, 1, 1], padding=self.padding)
        outputs = tf.nn.sigmoid(output)
        return outputs

我已经定义了模型的损失函数和优化器。然后在循环中使用 train 方法来更新权重。

# Loss object and optimizer shared by every call to train().
loss = tf.losses.BinaryCrossentropy()
optimizer = tf.optimizers.Adam(learning_rate=1e-4)

def train(model, inputs, outputs):
    """Run one optimisation step on ``model`` and print the resulting loss.

    Args:
        model: Callable exposing a ``weights`` list of trainable variables.
        inputs: Batch fed to ``model``.
        outputs: Target values the prediction is compared against.
    """
    with tf.GradientTape() as tape:
        # Keras losses are called as loss(y_true, y_pred): target first,
        # prediction second.
        current_loss = loss(outputs, model(inputs))
    grads = tape.gradient(current_loss, model.weights)
    # Pair each gradient with the variable it belongs to; the variables live
    # on the model, not in a module-level ``weights`` name.
    optimizer.apply_gradients(zip(grads, model.weights))
    print(current_loss)

# Random stand-in data: one 128x128 three-channel input and a matching
# single-channel target.
# NOTE(review): the target is drawn from a normal distribution, so its values
# are not limited to [0, 1] -- confirm this is intended for a binary
# cross-entropy loss.
sample_image = tf.Variable(tf.random.normal([1, 128, 128, 3]))
target_image = tf.Variable(tf.random.normal([1, 128, 128, 1]))

model = UNet()
# Repeat the same training step 1000 times on the fixed sample.
for _ in range(1000):
    train(model, sample_image, target_image)

输出是,

tf.Tensor(3.9185588, shape=(), dtype=float32)
tf.Tensor(3.9185588, shape=(), dtype=float32)
tf.Tensor(3.9185588, shape=(), dtype=float32)
tf.Tensor(3.9185588, shape=(), dtype=float32)
tf.Tensor(3.9185588, shape=(), dtype=float32)

等等。 optimizer.apply_gradients 没有更新权重。

How can I correctly update the model.weights parameter? I suspect that the problem lies in some function that is not differentiable (I am not sure, though).

最佳答案

所以,我自己解决了这个问题。问题是我使用 tf.random.normal 来初始化权重。当我使用 tf.initializers.glorot_uniform 时,损失开始减少。

# Glorot uniform initializer shared by every weight tensor.
initializer = tf.initializers.glorot_uniform()


def get_weight(shape, name):
    """Return a trainable float32 variable of ``shape`` filled by ``initializer``."""
    initial_value = initializer(shape)
    return tf.Variable(initial_value, name=name, trainable=True, dtype=tf.float32)

关于python - Optimizer.apply_gradients 不会更新 TF 2.0 中的权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58568566/

相关文章:

python - python中的模型训练和Golang中的模型运行,模型导入过程中的问题

python - ModuleNotFoundError : No module named 'tensorflow_federated.python.research'

python - azure - 列出 python 中的容器

Python - 将按钮动态添加到 PyQt 中的布局

tensorflow - 如何在 tensorflow 中为非分类对象创建一个类?

python - 无法在python、windows 10 64位中导入tensorflow

python - TensorFlow 优化器中的 _get_hyper 和 _set_hyper 是什么?

python - 如何在opencv中保持图像窗口的纵横比?

python - Spokeo 的 API?抓取 Spokeo

python - tensorflow session.run() 方法如何知道占位符变量的名称?