在浏览了一些 Stack 问题和 Keras 文档之后,我设法编写了一些代码,尝试评估神经网络相对于其输入的输出的梯度,目的是近似二元函数的简单练习(f(x,y) = x^2+y^2
) 使用分析微分和自动微分之间的差异作为损失。
结合两个问题( Keras custom loss function: Accessing current input pattern 和 Getting gradient of model output w.r.t weights using Keras )的答案,我想出了这个:
import tensorflow as tf
from keras import backend as K
from keras.models import Model
from keras.layers import Dense, Activation, Input
def custom_loss(input_tensor):
outputTensor = model.output
listOfVariableTensors = model.input
gradients = K.gradients(outputTensor, listOfVariableTensors)
sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
evaluated_gradients = sess.run(gradients,feed_dict={model.input:input_tensor})
grad_pred = K.add(evaluated_gradients[0], evaluated_gradients[1])
grad_true = k.add(K.scalar_mul(2, model.input[0][0]), K.scalar_mul(2, model.input[0][1]))
return K.square(K.subtract(grad_pred, grad_true))
input_tensor = Input(shape=(2,))
hidden = Dense(10, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')
这会产生错误:TypeError:提要的值不能是 tf.Tensor 对象。
,因为 feed_dict={model.input:input_tensor}
。我理解这个错误,但我只是不知道如何修复它。
根据我收集的信息,我不能简单地将输入数据传递到损失函数中,它必须是一个张量。我意识到当我调用 input_tensor 时 Keras 会“理解”它。这一切只会让我认为我做事的方式是错误的,试图这样评估梯度。非常感谢一些启发。
最佳答案
我不太明白你为什么想要这个损失函数,但无论如何我都会提供答案。此外,无需评估函数内的梯度(事实上,您将“断开”计算图)。损失函数可以实现如下:
from keras import backend as K
from keras.models import Model
from keras.layers import Dense, Input
def custom_loss(input_tensor, output_tensor):
def loss(y_true, y_pred):
gradients = K.gradients(output_tensor, input_tensor)
grad_pred = K.sum(gradients, axis=-1)
grad_true = K.sum(2*input_tensor, axis=-1)
return K.square(grad_pred - grad_true)
return loss
input_tensor = Input(shape=(2,))
hidden = Dense(10, activation='relu')(input_tensor)
output_tensor = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, output_tensor)
model.compile(loss=custom_loss(input_tensor, output_tensor), optimizer='adam')
关于python - Keras 自定义损失函数内的 TensorFlow session ,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49688134/