python - Tensorflow - 带有 sample_weight 的自定义损失函数

标签 python tensorflow keras loss-function

我正在尝试运行接受 sample_weights 的自定义函数。我正在关注此文档 https://www.tensorflow.org/api_docs/python/tf/keras/losses/Loss .

但是,当我尝试使用以下成本函数时:

class deltaE(Loss):
  def __call__(self, y_true, y_pred, sample_weight):
    errors = tf_get_deltaE2000(y_true * tf_Xtrain_labels_max, y_pred * tf_Xtrain_labels_max)
    errors *= sample_weight
    return tf.math.reduce_mean(errors, axis=-1)

loss_deltaE = deltaE()

我在 Model.fit 方法上遇到了这个错误。

TypeError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:571 train_function  *
        outputs = self.distribute_strategy.run(
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:951 run  **
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2290 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2649 _call_for_each_replica
        return fn(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:543 train_step  **
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/compile_utils.py:411 update_state
        metric_obj.update_state(y_t, y_p)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/metrics_utils.py:90 decorated
        update_op = update_state_fn(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/metrics.py:603 update_state
        matches = self._fn(y_true, y_pred, **self._fn_kwargs)

    TypeError: __call__() missing 1 required positional argument: 'sample_weight'

我使用的生成器可以根据需要生成长度为 3 的元组。我已经检查过了。一切正常。

成本函数也能正常工作。当我使用下面的代码时,模型训练没有问题。

def loss_deltaE(y_true, y_pred):
  errors = tf_get_deltaE2000(y_true * tf_Xtrain_labels_max, y_pred * tf_Xtrain_labels_max)
  return tf.math.reduce_mean(errors, axis=-1)

如果有人有任何线索。我会很感激。提前致谢!

最佳答案

这是一种将附加参数传递给自定义损失函数的解决方法。诀窍在于使用虚假输入,这些输入有助于以正确的方式构建和使用损失

我在回归问题中提供了一个虚拟示例

def mse(y_true, y_pred, sample_weight):

    error = y_true-y_pred

    return K.mean(K.sqrt(error)*sample_weight)


X = np.random.uniform(0,1, (1000,10))
y = np.random.uniform(0,1, 1000)
W = np.random.uniform(1,2, 1000)

inp = Input((10))
true = Input((1))
sample_weight = Input((1))
x = Dense(32, activation='relu')(inp)
out = Dense(1)(x)

m = Model([inp,true, sample_weight], out)
m.add_loss( mse( true, out, sample_weight ) )
m.compile(loss=None, optimizer='adam')
history = m.fit([X, y, W], y, epochs=10)

# final fitted model to compute predictions
final_m = Model(inp, out)

关于python - Tensorflow - 带有 sample_weight 的自定义损失函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62270283/

相关文章:

python - 如何在 dask 中实现 `groupby` 滚动平均值?

python - key 错误 : 'filename' (Pandas)

python - 倒排索引,我可以在其中保存单词的元组及其来源的 id

python - 如何在范围内为 i 向量化 np.dot(vector_a, vector_b[ :, i]) ?

python - 增加对 Python 中 Unicode 的理解 (2.7)

linux - 如何在后台运行 Keras、tensorflow 程序并保留所有日志?

tensorflow - 加载的 MobileNet 模型给出了错误的预测

python - 尝试使用未初始化的变量 -tensorboard

tensorflow - 如何检查keras tensorflow后端是GPU版本还是CPU版本?

python - Keras:如何在自定义损失中获取张量尺寸?