python - How do I load a model with a custom loss?

Tags: python tensorflow keras deep-learning tensorflow2.0

How do I load a model whose custom loss subclasses tf.keras.losses.Loss?

I defined ContrastiveLoss by subclassing tf.keras.losses.Loss, as follows:

import tensorflow as tf
from tensorflow.keras.losses import Loss

class ContrastiveLoss(Loss):
    def __init__(self, alpha, square=True, **kwargs):
        super(ContrastiveLoss, self).__init__(**kwargs)
        self.alpha = alpha
        self.square = square

    def get_dists(self, x, y, square):
        dists = tf.subtract(x, y)
        dists = tf.reduce_sum(tf.square(dists), axis=-1)

        if not square:
            zero_mask = tf.cast(tf.equal(dists, 0.0), tf.float32)
            dists = dists + zero_mask * 1e-16
            dists = tf.sqrt(dists)
            nonzero_mask = 1.0 - zero_mask
            dists = dists * nonzero_mask

        return dists

    def call(self, y_true, y_pred):
        # y_true and y_pred have shape (N, #embed) for a mini-batch of N samples;
        # y_true[:, 0] has shape (N,).
        if len(y_true.shape) == 2:
            y_true = y_true[:, 0]

        # positive_mask[i, j] is 1.0 when samples i and j have equal y_true values.
        positive_mask = tf.cast(
            tf.equal(tf.expand_dims(y_true, 0), tf.expand_dims(y_true, 1)), tf.float32)
        negative_mask = tf.subtract(1.0, positive_mask)

        # Pairwise distances between all rows of y_pred.
        all_dists = self.get_dists(
            tf.expand_dims(y_pred, 1), tf.expand_dims(y_pred, 0), self.square)
        positive_loss = tf.multiply(positive_mask, all_dists)
        negative_loss = tf.multiply(
            negative_mask, tf.maximum(tf.subtract(self.alpha, all_dists), 0.0))
        contrastive_loss = tf.add(positive_loss, negative_loss)

        # Average over the pairs whose loss is greater than 1e-16.
        valid_doublet_mask = tf.cast(tf.greater(contrastive_loss, 1e-16), tf.float32)
        num_valid_doublet = tf.reduce_sum(valid_doublet_mask)
        contrastive_loss = tf.reduce_sum(contrastive_loss) / (num_valid_doublet + 1e-16)

        return contrastive_loss

    def get_config(self):
        config = super(ContrastiveLoss, self).get_config()
        config.update({'alpha' : self.alpha, 
                       'square' : self.square})
        return config  

I can train and save a model with it.

However, when I load the model as follows, I get an error.

load_model(model_path, custom_objects={'ContrastiveLoss' : ContrastiveLoss})
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-5-af42cd2404e1> in <module>()
----> 1 load_model(model_path, custom_objects={'ContrastiveLoss' : ContrastiveLoss})

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/saving/save.py in load_model(filepath, custom_objects, compile)
    148   if isinstance(filepath, six.string_types):
    149     loader_impl.parse_saved_model(filepath)
--> 150     return saved_model_load.load(filepath, compile)
    151 
    152   raise IOError(

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/saving/saved_model/load.py in load(path, compile)
     97     if training_config is not None:
     98       model.compile(**saving_utils.compile_args_from_training_config(
---> 99           training_config))
    100   # pylint: disable=protected-access
    101 

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/saving/saving_utils.py in compile_args_from_training_config(training_config, custom_objects)
    232   loss_config = training_config['loss']  # Deserialize loss class.
    233   if isinstance(loss_config, dict) and 'class_name' in loss_config:
--> 234     loss_config = losses.get(loss_config)
    235   loss = nest.map_structure(
    236       lambda obj: custom_objects.get(obj, obj), loss_config)

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/losses.py in get(identifier)
   1184     return deserialize(identifier)
   1185   if isinstance(identifier, dict):
-> 1186     return deserialize(identifier)
   1187   elif callable(identifier):
   1188     return identifier

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/losses.py in deserialize(name, custom_objects)
   1173       module_objects=globals(),
   1174       custom_objects=custom_objects,
-> 1175       printable_module_name='loss function')
   1176 
   1177 

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    290     config = identifier
    291     (cls, cls_config) = class_and_config_for_serialized_keras_object(
--> 292         config, module_objects, custom_objects, printable_module_name)
    293 
    294     if hasattr(cls, 'from_config'):

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
    248     cls = module_objects.get(class_name)
    249     if cls is None:
--> 250       raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
    251 
    252   cls_config = config['config']

ValueError: Unknown loss function: ContrastiveLoss

Oddly, if I use a custom loss "function" instead, load_model(.) raises no error.

With a "subclass" of Loss, as in this case, the error occurs.


Best answer

If what javad suggested:

Have you tried passing an object instead of the class, meaning load_model(model_path, custom_objects={'ContrastiveLoss' : ContrastiveLoss(...)}), where ... holds all the parameters of your loss, such as alpha?

does not work and you only need the model for inference, try loading it without compiling:

tf.keras.models.load_model("<model_path>", compile=False)
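
Loading with compile=False skips deserializing the saved training configuration, which is where the loss lookup in the traceback fails. The sketch below shows that flow end to end and then recompiles with a fresh loss instance in case training should continue. It rests on several assumptions that are not in the question: ScaledMSE is a hypothetical, much smaller stand-in for ContrastiveLoss, the model and data are toy placeholders, and the model is saved to an HDF5 file in a temporary directory only to keep the example portable across TensorFlow versions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class ScaledMSE(tf.keras.losses.Loss):
    """Hypothetical stand-in for ContrastiveLoss: a Loss subclass with one parameter."""

    def __init__(self, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, y_true, y_pred):
        # Per-sample mean squared error, multiplied by a constant factor.
        return self.scale * tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)

    def get_config(self):
        config = super().get_config()
        config.update({"scale": self.scale})
        return config


# Toy data and model (assumptions, chosen only to make the sketch runnable).
x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss=ScaledMSE(scale=2.0))
model.fit(x, y, epochs=1, verbose=0)

# Save the compiled model to a temporary HDF5 file.
model_path = os.path.join(tempfile.mkdtemp(), "model.h5")
model.save(model_path)

# compile=False restores architecture and weights only, so the custom loss
# class is never looked up during loading.
restored = tf.keras.models.load_model(model_path, compile=False)
preds = restored.predict(x, verbose=0)

# To resume training or call evaluate(), compile again with a new loss instance.
restored.compile(optimizer="adam", loss=ScaledMSE(scale=2.0))
loss_value = restored.evaluate(x, y, verbose=0)
```

Because the optimizer state is not restored this way, recompiling starts the optimizer from scratch; for inference alone the recompile step can be dropped.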

Hope this helps.

This question, "python - How do I load a model with a custom loss?", corresponds to a similar question on Stack Overflow: https://stackoverflow.com/questions/59966003/
