python - 语义分割训练时 Keras 损失为 NaN

标签 python keras deep-learning computer-vision semantic-segmentation


所有掩码(mask)图像都是单通道的。这是我的代码:

# --- Configuration -----------------------------------------------------------
image_size = 512  # side length every image and mask is resized to
batch = 4  # batch size used by data_generator
labels = 14  # number of output classes handed to the model
data_directory = "/content/headsegmentation_final/"

# Per-split sample counts: number of directory entries minus one.
# NOTE(review): os.listdir also counts hidden entries that glob("*") skips; the
# "- 1" presumably discounts one such entry. If a directory has none, the
# slicing below silently drops its last file in sorted order — confirm.
# NOTE(review): the Test path is spelled out instead of being derived from
# data_directory; keep the two in sync.
sample_train_images = len(os.listdir(data_directory + "Training/Images/")) - 1
sample_validation_images = len(os.listdir(data_directory + "Validation/Images/")) - 1
test_images = len(os.listdir("/content/headsegmentation_final/Test/")) - 1

# Sorted so that image i and mask i line up, assuming images and masks share
# file-name stems.
t_images, t_masks, v_images, v_masks, ts_images = (
    sorted(glob(os.path.join(data_directory, pattern)))[:limit]
    for pattern, limit in (
        ("Training/Images/*", sample_train_images),
        ("Training/Category_ids/*", sample_train_images),
        ("Validation/Images/*", sample_validation_images),
        ("Validation/Category_ids/*", sample_validation_images),
        ("Test/*", test_images),
    )
)

def image_augmentation(img, random_range):
    """Randomly mirror ``img`` left-right, then rotate it by ``random_range``.

    Args:
        img: Image tensor to augment.
        random_range: Rotation angle forwarded unchanged to
            ``tfa.image.rotate`` (which interprets it in radians).

    Returns:
        The flipped-and-rotated tensor.
    """
    mirrored = tf.image.random_flip_left_right(img)
    rotated = tfa.image.rotate(mirrored, random_range)
    return rotated

def image_process(path, mask=False):
    """Read one image or mask file and resize it to ``image_size`` x ``image_size``.

    Args:
        path: File path (str or scalar string tensor).
        mask: If true, decode ``path`` as a single-channel PNG label mask;
            otherwise decode it as a 3-channel JPEG image.

    Returns:
        For masks: a ``(image_size, image_size, 1)`` float32 tensor whose values
        are the class ids stored in the file.
        For images: a ``(image_size, image_size, 3)`` float32 tensor scaled to
        [-1, 1].
    """
    raw = tf.io.read_file(path)

    if mask:
        img = tf.image.decode_png(raw, channels=1)
        img.set_shape([None, None, 1])
        # Nearest-neighbour resizing keeps every pixel an existing class id.
        # The default bilinear method blends neighbouring ids into values that
        # are not real labels (halfway between class 2 and class 10 gives 6).
        img = tf.image.resize(
            images=img, size=[image_size, image_size], method="nearest"
        )
        # "nearest" keeps the input dtype (uint8); cast to match the float32
        # that bilinear resizing used to return.
        img = tf.cast(img, tf.float32)
    else:
        img = tf.image.decode_jpeg(raw, channels=3)
        img.set_shape([None, None, 3])
        img = tf.image.resize(images=img, size=[image_size, image_size])
        img = img / 127.5 - 1

    # Augmentation is deliberately not applied here:
    # - image_augmentation draws its random flip independently on every call,
    #   so augmenting an image and its mask separately would misalign them.
    # - A Python random.uniform() angle would be evaluated only once, when
    #   Dataset.map traces this function. Use tf.random ops if re-enabling.
    return img

def data_loader(image_list, mask_list):
    """Decode one image path and its mask path into a ``(image, mask)`` pair.

    Presumably mapped over a dataset of (image_path, mask_path) elements, so
    each argument is a single path despite the list-like names.
    """
    return image_process(image_list), image_process(mask_list, mask=True)

def data_generator(image_list, mask_list, batch_size=None):
    """Build a batched ``tf.data`` pipeline of ``(image, mask)`` pairs.

    Args:
        image_list: Image file paths.
        mask_list: Mask file paths, aligned index-by-index with ``image_list``.
        batch_size: Batch size; defaults to the module-level ``batch``.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, masks)`` batches. A trailing
        partial batch is dropped.

    Raises:
        ValueError: If the two path lists differ in length.
    """
    if batch_size is None:
        batch_size = batch
    if len(image_list) != len(mask_list):
        raise ValueError(
            f"Got {len(image_list)} images but {len(mask_list)} masks."
        )

    cihp_dataset = tf.data.Dataset.from_tensor_slices((image_list, mask_list))
    cihp_dataset = cihp_dataset.map(data_loader, num_parallel_calls=tf.data.AUTOTUNE)
    # NOTE(review): the pipeline is not shuffled; consider .shuffle(...) before
    # batching for the training split.
    cihp_dataset = cihp_dataset.batch(batch_size, drop_remainder=True)
    # Overlap input preparation with training steps.
    return cihp_dataset.prefetch(tf.data.AUTOTUNE)

# Batched input pipelines for the training and validation splits.
train_dataset, val_dataset = (
    data_generator(t_images, t_masks),
    data_generator(v_images, v_masks),
)

def block(block_input, filters=256, kernel=3, dilation=1, padding="same", use_bias=False):
    """Apply Conv2D -> BatchNormalization -> ReLU to ``block_input``.

    Args:
        block_input: Input tensor.
        filters: Number of convolution filters.
        kernel: Convolution kernel size.
        dilation: Dilation rate of the convolution.
        padding: Padding mode forwarded to ``Conv2D``.
        use_bias: Whether the convolution has a bias term.

    Returns:
        The activated output tensor.
    """
    x = layers.Conv2D(
        filters,
        kernel_size=kernel,
        dilation_rate=dilation,
        padding=padding,
        use_bias=use_bias,
        kernel_initializer=keras.initializers.HeNormal(),
    )(block_input)
    x = layers.BatchNormalization()(x)
    return tf.nn.relu(x)

def DSP_pooling(dsp_pooling_input):
    """Spatial pyramid pooling head with dilated convolutions.

    Five parallel branches are computed from ``dsp_pooling_input``:
    - an image-level branch (average pool over the whole map, 1x1 conv block,
      bilinear upsample back);
    - a 1x1 conv block;
    - three 3x3 conv blocks with dilation rates 6, 12 and 18.

    The branches are concatenated on the channel axis and fused by a final
    1x1 conv block.
    """
    height, width = dsp_pooling_input.shape[-3], dsp_pooling_input.shape[-2]

    # Image-level branch.
    pooled = layers.AveragePooling2D(pool_size=(height, width))(dsp_pooling_input)
    pooled = block(pooled, kernel=1, use_bias=True)
    image_branch = layers.UpSampling2D(
        size=(height // pooled.shape[1], width // pooled.shape[2]),
        interpolation="bilinear",
    )(pooled)

    # Convolutional branches, created in the same order as before.
    branches = [image_branch, block(dsp_pooling_input, kernel=1, dilation=1)]
    branches.extend(
        block(dsp_pooling_input, kernel=3, dilation=rate) for rate in (6, 12, 18)
    )

    merged = layers.Concatenate(axis=-1)(branches)
    return block(merged, kernel=1)

def DeepLabV3_ResNet50(size, classes):
    """Build a DeepLabV3+-style segmentation model on an ImageNet ResNet50.

    Args:
        size: Height and width of the square RGB input.
        classes: Number of output channels (one logit per class per pixel).

    Returns:
        A ``keras.Model`` producing unnormalised per-pixel logits. The upsampling
        factors use integer division, so the output is ``(size, size, classes)``
        only when ``size`` divides evenly (e.g. the 512 used here).
    """
    inputs = keras.Input(shape=(size, size, 3))
    backbone = keras.applications.ResNet50(
        weights="imagenet", include_top=False, input_tensor=inputs
    )

    # Deep features go through the pyramid-pooling head, then are upsampled
    # to size // 4.
    deep = DSP_pooling(backbone.get_layer("conv4_block6_2_relu").output)
    deep = layers.UpSampling2D(
        size=(size // 4 // deep.shape[1], size // 4 // deep.shape[2]),
        interpolation="bilinear",
    )(deep)

    # Shallow features are projected to 48 channels before fusion.
    shallow = block(
        backbone.get_layer("conv2_block3_2_relu").output, filters=48, kernel=1
    )

    fused = layers.Concatenate(axis=-1)([deep, shallow])
    fused = block(block(fused))
    fused = layers.UpSampling2D(
        size=(size // fused.shape[1], size // fused.shape[2]),
        interpolation="bilinear",
    )(fused)

    logits = layers.Conv2D(classes, kernel_size=(1, 1), padding="same")(fused)
    return keras.Model(inputs=inputs, outputs=logits)

# Segmentation network at the configured input size and class count.
model = DeepLabV3_ResNet50(image_size, labels)

def scheduler(epoch, lr, hold_epochs=10, decay=0.1):
    """Learning-rate schedule for ``tf.keras.callbacks.LearningRateScheduler``.

    Keeps ``lr`` unchanged while ``epoch < hold_epochs``. After that it
    multiplies the current rate by ``exp(-decay)`` each epoch, so the rate
    decays exponentially.

    Args:
        epoch: Epoch index supplied by Keras.
        lr: Current learning rate supplied by Keras.
        hold_epochs: Number of initial epochs with a constant rate.
        decay: Per-epoch exponential decay coefficient after the hold period.

    Returns:
        The learning rate to use for ``epoch``. This is never None;
        LearningRateScheduler rejects non-numeric values.
    """
    import math

    if epoch < hold_epochs:
        return lr
    return lr * math.exp(-decay)

# Masks hold class ids and the model emits raw logits, hence the sparse loss
# with from_logits=True. Every mask value must lie in [0, labels). Per the
# TensorFlow docs, out-of-range ids raise an error on CPU but produce NaN loss
# on GPU, consistent with the NaN losses in the training log.
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer=keras.optimizers.Adam(), loss=loss, metrics=["accuracy"])
print(round(float(model.optimizer.learning_rate.numpy()), 5))  # initial learning rate

callback = tf.keras.callbacks.LearningRateScheduler(scheduler)

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=25,
    callbacks=[callback],
    verbose=1,
)
print(round(float(model.optimizer.learning_rate.numpy()), 5))  # final learning rate


Epoch 1/25
1404/1404 [==============================] - 342s 232ms/step - loss: nan - accuracy: 0.5888 - val_loss: nan - val_accuracy: 0.4956 - lr: 0.0010
Epoch 2/25
1404/1404 [==============================] - 323s 230ms/step - loss: nan - accuracy: 0.5892 - val_loss: nan - val_accuracy: 0.4956 - lr: 0.0010
Epoch 3/25
1404/1404 [==============================] - 323s 230ms/step - loss: nan - accuracy: 0.5892 - val_loss: nan - val_accuracy: 0.4956 - lr: 0.0010


我在使用 DeepLabV3+ 时也遇到过同样的问题。首先,您可以参考这个网站,因为它的代码与您的类似,并且使用的是同样的 CIHP 数据集。


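该问题可能是因为掩码中的实际标签值超出了您设置的类别数量。例如,您在这里将类别/标签数量指定为 14,但掩码中实际出现的标签可能超过 14 个,因此会得到 NaN 损失。我遇到的正是这种情况。SparseCategoricalCrossentropy 要求每个像素的标签值都位于 [0, 类别数) 范围内,也就是说类别数必须大于掩码中的最大标签值。超出范围的标签在 CPU 上会直接报错,在 GPU 上则不会报错,而是产生 NaN 损失。您应该将模型使用的类别数量调整为与掩码数据集中实际存在的标签一致。可以按如下方式检查: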

from skimage import io
import numpy as np
# Check labels for all masks
def check_mask_labels(masks):
    # Create an empty set
    unique_labels_len = set()
    # Iterate over all mask dataset
    for mask in masks:
        # Read mask
        test_mask = io.imread(mask)
        # Find unique labels in the mask
        unique_labels = np.unique(test_mask)
        # Find the total number of unique labels
        len_unique_labels = len(unique_labels)
        # Add to the set

    # Find the maximum label length
    max_label_len = max(unique_labels_len)
    # Convert to list and sort
    unique_labels_len = list(unique_labels_len)
    # Print results
    print(f" Number of labels across all masks: {unique_labels_len} \n Maximum number of masks: {max_label_len}")

    return max_label_len

# Scan the mask files; `masks` is assumed to be the list of mask paths
# (e.g. t_masks + v_masks in the question's code) — confirm.
NUM_CLASSES = check_mask_labels(masks=masks)


所有掩码的标签数量(即各掩码中不同标签值个数的集合):[1, 30, 34, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64]



