python - Fake image creation in a GAN becomes worse after a certain number of epochs

Tags: python pytorch conv-neural-network artificial-intelligence generative-adversarial-network

I am trying to build a GAN model.
Here is my discriminator.py:

import torch.nn as nn
class D(nn.Module):
    feature_maps = 64
    kernel_size = 4
    stride = 2
    padding = 1
    bias = False
    inplace = True

    def __init__(self):
        super(D, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(4, self.feature_maps, self.kernel_size, self.stride, self.padding, bias=self.bias),
            nn.LeakyReLU(0.2, inplace=self.inplace),
            nn.Conv2d(self.feature_maps, self.feature_maps * 2, self.kernel_size, self.stride, self.padding,
                      bias=self.bias),
            nn.BatchNorm2d(self.feature_maps * 2), nn.LeakyReLU(0.2, inplace=self.inplace),
            nn.Conv2d(self.feature_maps * 2, self.feature_maps * (2 * 2), self.kernel_size, self.stride, self.padding,
                      bias=self.bias),
            nn.BatchNorm2d(self.feature_maps * (2 * 2)), nn.LeakyReLU(0.2, inplace=self.inplace),
            nn.Conv2d(self.feature_maps * (2 * 2), self.feature_maps * (2 * 2 * 2), self.kernel_size, self.stride,
                      self.padding, bias=self.bias),
            nn.BatchNorm2d(self.feature_maps * (2 * 2 * 2)), nn.LeakyReLU(0.2, inplace=self.inplace),
            nn.Conv2d(self.feature_maps * (2 * 2 * 2), 1, self.kernel_size, 1, 0, bias=self.bias),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1)
Here is my generator.py:
import torch.nn as nn
class G(nn.Module):
    feature_maps = 512
    kernel_size = 4
    stride = 2
    padding = 1
    bias = False

    def __init__(self, input_vector):
        super(G, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(input_vector, self.feature_maps, self.kernel_size, 1, 0, bias=self.bias),
            nn.BatchNorm2d(self.feature_maps), nn.ReLU(True),
            nn.ConvTranspose2d(self.feature_maps, int(self.feature_maps // 2), self.kernel_size, self.stride, self.padding,
                               bias=self.bias),
            nn.BatchNorm2d(int(self.feature_maps // 2)), nn.ReLU(True),
            nn.ConvTranspose2d(int(self.feature_maps // 2), int((self.feature_maps // 2) // 2), self.kernel_size, self.stride,
                               self.padding,
                               bias=self.bias),
            nn.BatchNorm2d(int((self.feature_maps // 2) // 2)), nn.ReLU(True),
            nn.ConvTranspose2d((int((self.feature_maps // 2) // 2)), int(((self.feature_maps // 2) // 2) // 2), self.kernel_size,
                               self.stride, self.padding,
                               bias=self.bias),
            nn.BatchNorm2d(int((self.feature_maps // 2) // 2) // 2), nn.ReLU(True),
            nn.ConvTranspose2d(int(((self.feature_maps // 2) // 2) // 2), 4, self.kernel_size, self.stride, self.padding,
                               bias=self.bias),
            nn.Tanh()
        )

    def forward(self, input):
        output = self.main(input)
        return output
Here is my gans.py:
# Importing the libraries
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from generator import G
from discriminator import D
import os
from PIL import Image

batchSize = 64  # We set the size of the batch.
imageSize = 64  # We set the size of the generated images (64x64).
input_vector = 100
nb_epochs = 500
# Creating the transformations
transform = transforms.Compose([transforms.Resize((imageSize, imageSize)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5, 0.5))])
# A list of transformations (scaling, tensor conversion, normalization) to apply to the 4-channel (RGBA) input images.


def pil_loader_rgba(path: str) -> Image.Image:
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGBA')


# Loading the dataset
dataset = dset.ImageFolder(root='./data', transform=transform, loader=pil_loader_rgba)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize, shuffle=True,
                                         num_workers=2)  # We use the DataLoader to get the images of the training set batch by batch.


# Defining the weights_init function that takes as input a neural network m and that will initialize all its weights.
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def is_cuda_available():
    return torch.cuda.is_available()


def is_gpu_available():
    if is_cuda_available():
        if int(torch.cuda.device_count()) > 0:
            return True
        return False
    return False


# Create results directory
def create_dir(name):
    if not os.path.exists(name):
        os.makedirs(name)


# Creating the generator
netG = G(input_vector)
netG.apply(weights_init)

# Creating the discriminator
netD = D()
netD.apply(weights_init)

if is_gpu_available():
    netG.cuda()
    netD.cuda()

# Training the DCGANs

criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

generator_model = 'generator_model'
discriminator_model = 'discriminator_model'


def save_model(epoch, model, optimizer, error, filepath, noise=None):
    if os.path.exists(filepath):
        os.remove(filepath)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': error,
        'noise': noise
    }, filepath)


def load_checkpoint(filepath):
    if os.path.exists(filepath):
        return torch.load(filepath)
    return None


def main():
    print("Device name : " + torch.cuda.get_device_name(0))
    for epoch in range(nb_epochs):

        for i, data in enumerate(dataloader, 0):
            checkpointG = load_checkpoint(generator_model)
            checkpointD = load_checkpoint(discriminator_model)
            if checkpointG:
                netG.load_state_dict(checkpointG['model_state_dict'])
                optimizerG.load_state_dict(checkpointG['optimizer_state_dict'])
            if checkpointD:
                netD.load_state_dict(checkpointD['model_state_dict'])
                optimizerD.load_state_dict(checkpointD['optimizer_state_dict'])

            # 1st Step: Updating the weights of the neural network of the discriminator

            netD.zero_grad()

            # Training the discriminator with a real image of the dataset
            real, _ = data
            if is_gpu_available():
                input = Variable(real.cuda()).cuda()
                target = Variable(torch.ones(input.size()[0]).cuda()).cuda()
            else:
                input = Variable(real)
                target = Variable(torch.ones(input.size()[0]))
            output = netD(input)
            errD_real = criterion(output, target)

            # Training the discriminator with a fake image generated by the generator
            if is_gpu_available():
                noise = Variable(torch.randn(input.size()[0], input_vector, 1, 1)).cuda()
                target = Variable(torch.zeros(input.size()[0])).cuda()
            else:
                noise = Variable(torch.randn(input.size()[0], input_vector, 1, 1))
                target = Variable(torch.zeros(input.size()[0]))
            fake = netG(noise)
            output = netD(fake.detach())
            errD_fake = criterion(output, target)

            # Backpropagating the total error
            errD = errD_real + errD_fake
            errD.backward()
            optimizerD.step()

            # 2nd Step: Updating the weights of the neural network of the generator
            netG.zero_grad()
            if is_gpu_available():
                target = Variable(torch.ones(input.size()[0])).cuda()
            else:
                target = Variable(torch.ones(input.size()[0]))
            output = netD(fake)
            errG = criterion(output, target)
            errG.backward()
            optimizerG.step()

            # 3rd Step: Printing the losses and saving the real images and the generated images of the minibatch every 100 steps

            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                  % (epoch, nb_epochs, i, len(dataloader), errD.data, errG.data))
            save_model(epoch, netG, optimizerG, errG, generator_model, noise)
            save_model(epoch, netD, optimizerD, errD, discriminator_model, noise)

            if i % 100 == 0:
                create_dir('results')
                vutils.save_image(real, '%s/real_samples.png' % "./results", normalize=True)
                fake = netG(noise)
                vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch), normalize=True)


if __name__ == "__main__":
    main()
So after a few hours I decided to check my results folder, and I noticed something strange after epoch 39.
The generator had started producing the worst images, even though it had been improving steadily up to epoch 39.
See the screenshot below.
[screenshot of the results folder: generated samples degrade after epoch 39]
Why did the generator suddenly get worse?
I am trying to run 500 epochs, on the assumption that more epochs would give better results.
So I looked at the logs and saw the following:
[40/500][0/157] Loss_D: 0.0141 Loss_G: 5.7559
[40/500][1/157] Loss_D: 0.0438 Loss_G: 5.5805
[40/500][2/157] Loss_D: 0.0161 Loss_G: 6.4947
[40/500][3/157] Loss_D: 0.0138 Loss_G: 7.1711
[40/500][4/157] Loss_D: 0.0547 Loss_G: 4.6262
[40/500][5/157] Loss_D: 0.0295 Loss_G: 4.7831
[40/500][6/157] Loss_D: 0.0103 Loss_G: 6.3700
[40/500][7/157] Loss_D: 0.0276 Loss_G: 5.9162
[40/500][8/157] Loss_D: 0.0205 Loss_G: 6.3571
[40/500][9/157] Loss_D: 0.0139 Loss_G: 6.4961
[40/500][10/157] Loss_D: 0.0117 Loss_G: 6.4371
[40/500][11/157] Loss_D: 0.0057 Loss_G: 6.6858
[40/500][12/157] Loss_D: 0.0203 Loss_G: 5.4308
[40/500][13/157] Loss_D: 0.0078 Loss_G: 6.5749
[40/500][14/157] Loss_D: 0.0115 Loss_G: 6.3202
[40/500][15/157] Loss_D: 0.0187 Loss_G: 6.2258
[40/500][16/157] Loss_D: 0.0052 Loss_G: 6.5253
[40/500][17/157] Loss_D: 0.0158 Loss_G: 5.5672
[40/500][18/157] Loss_D: 0.0156 Loss_G: 5.5416
[40/500][19/157] Loss_D: 0.0306 Loss_G: 5.4550
[40/500][20/157] Loss_D: 0.0077 Loss_G: 6.1985
[40/500][21/157] Loss_D: 0.0158 Loss_G: 5.3092
[40/500][22/157] Loss_D: 0.0167 Loss_G: 5.8395
[40/500][23/157] Loss_D: 0.0119 Loss_G: 6.0849
[40/500][24/157] Loss_D: 0.0104 Loss_G: 6.5493
[40/500][25/157] Loss_D: 0.0182 Loss_G: 5.6758
[40/500][26/157] Loss_D: 0.0145 Loss_G: 5.8336
[40/500][27/157] Loss_D: 0.0050 Loss_G: 6.8472
[40/500][28/157] Loss_D: 0.0080 Loss_G: 6.4894
[40/500][29/157] Loss_D: 0.0186 Loss_G: 5.5563
[40/500][30/157] Loss_D: 0.0143 Loss_G: 6.4144
[40/500][31/157] Loss_D: 0.0377 Loss_G: 5.4557
[40/500][32/157] Loss_D: 0.0540 Loss_G: 4.6034
[40/500][33/157] Loss_D: 0.0200 Loss_G: 5.6417
[40/500][34/157] Loss_D: 0.0189 Loss_G: 5.7760
[40/500][35/157] Loss_D: 0.0197 Loss_G: 6.1732
[40/500][36/157] Loss_D: 0.0093 Loss_G: 6.4046
[40/500][37/157] Loss_D: 0.0281 Loss_G: 5.5217
[40/500][38/157] Loss_D: 0.0410 Loss_G: 5.9157
[40/500][39/157] Loss_D: 0.0667 Loss_G: 5.2522
[40/500][40/157] Loss_D: 0.0530 Loss_G: 5.6412
[40/500][41/157] Loss_D: 0.0315 Loss_G: 5.9325
[40/500][42/157] Loss_D: 0.0097 Loss_G: 6.7819
[40/500][43/157] Loss_D: 0.0157 Loss_G: 5.8630
[40/500][44/157] Loss_D: 0.0382 Loss_G: 5.1942
[40/500][45/157] Loss_D: 0.0331 Loss_G: 5.1490
[40/500][46/157] Loss_D: 0.0362 Loss_G: 5.7026
[40/500][47/157] Loss_D: 0.0237 Loss_G: 5.7493
[40/500][48/157] Loss_D: 0.0227 Loss_G: 5.7636
[40/500][49/157] Loss_D: 0.0230 Loss_G: 5.6500
[40/500][50/157] Loss_D: 0.0329 Loss_G: 5.4542
[40/500][51/157] Loss_D: 0.0306 Loss_G: 5.6473
[40/500][52/157] Loss_D: 0.0254 Loss_G: 5.8464
[40/500][53/157] Loss_D: 0.0402 Loss_G: 5.8609
[40/500][54/157] Loss_D: 0.0242 Loss_G: 5.9952
[40/500][55/157] Loss_D: 0.0400 Loss_G: 5.8378
[40/500][56/157] Loss_D: 0.0302 Loss_G: 5.8990
[40/500][57/157] Loss_D: 0.0239 Loss_G: 5.8134
[40/500][58/157] Loss_D: 0.0348 Loss_G: 5.8109
[40/500][59/157] Loss_D: 0.0361 Loss_G: 5.9011
[40/500][60/157] Loss_D: 0.0418 Loss_G: 5.8825
[40/500][61/157] Loss_D: 0.0501 Loss_G: 6.2302
[40/500][62/157] Loss_D: 0.0184 Loss_G: 6.2755
[40/500][63/157] Loss_D: 0.0273 Loss_G: 5.9655
[40/500][64/157] Loss_D: 0.0250 Loss_G: 5.7513
[40/500][65/157] Loss_D: 0.0298 Loss_G: 6.0434
[40/500][66/157] Loss_D: 0.0299 Loss_G: 6.4280
[40/500][67/157] Loss_D: 0.0205 Loss_G: 6.3743
[40/500][68/157] Loss_D: 0.0173 Loss_G: 6.2749
[40/500][69/157] Loss_D: 0.0199 Loss_G: 6.0541
[40/500][70/157] Loss_D: 0.0309 Loss_G: 6.5044
[40/500][71/157] Loss_D: 0.0177 Loss_G: 6.6093
[40/500][72/157] Loss_D: 0.0363 Loss_G: 7.2993
[40/500][73/157] Loss_D: 0.0093 Loss_G: 7.6995
[40/500][74/157] Loss_D: 0.0087 Loss_G: 7.3493
[40/500][75/157] Loss_D: 0.0540 Loss_G: 8.2688
[40/500][76/157] Loss_D: 0.0172 Loss_G: 8.3312
[40/500][77/157] Loss_D: 0.0086 Loss_G: 7.6863
[40/500][78/157] Loss_D: 0.0232 Loss_G: 7.4930
[40/500][79/157] Loss_D: 0.0175 Loss_G: 7.8834
[40/500][80/157] Loss_D: 0.0109 Loss_G: 9.5329
[40/500][81/157] Loss_D: 0.0093 Loss_G: 7.3253
[40/500][82/157] Loss_D: 0.0674 Loss_G: 10.6709
[40/500][83/157] Loss_D: 0.0010 Loss_G: 10.8321
[40/500][84/157] Loss_D: 0.0083 Loss_G: 8.5728
[40/500][85/157] Loss_D: 0.0124 Loss_G: 6.9085
[40/500][86/157] Loss_D: 0.0181 Loss_G: 7.0867
[40/500][87/157] Loss_D: 0.0130 Loss_G: 7.3527
[40/500][88/157] Loss_D: 0.0189 Loss_G: 7.2494
[40/500][89/157] Loss_D: 0.0302 Loss_G: 8.7555
[40/500][90/157] Loss_D: 0.0147 Loss_G: 7.7668
[40/500][91/157] Loss_D: 0.0325 Loss_G: 7.7779
[40/500][92/157] Loss_D: 0.0257 Loss_G: 8.3955
[40/500][93/157] Loss_D: 0.0113 Loss_G: 8.3687
[40/500][94/157] Loss_D: 0.0124 Loss_G: 7.6081
[40/500][95/157] Loss_D: 0.0088 Loss_G: 7.6012
[40/500][96/157] Loss_D: 0.0241 Loss_G: 7.6573
[40/500][97/157] Loss_D: 0.0522 Loss_G: 10.8114
[40/500][98/157] Loss_D: 0.0071 Loss_G: 11.0529
[40/500][99/157] Loss_D: 0.0043 Loss_G: 8.0707
[40/500][100/157] Loss_D: 0.0141 Loss_G: 7.2864
[40/500][101/157] Loss_D: 0.0234 Loss_G: 7.3585
[40/500][102/157] Loss_D: 0.0148 Loss_G: 7.4577
[40/500][103/157] Loss_D: 0.0190 Loss_G: 8.1904
[40/500][104/157] Loss_D: 0.0201 Loss_G: 8.1518
[40/500][105/157] Loss_D: 0.0220 Loss_G: 9.1069
[40/500][106/157] Loss_D: 0.0108 Loss_G: 9.0069
[40/500][107/157] Loss_D: 0.0044 Loss_G: 8.0970
[40/500][108/157] Loss_D: 0.0076 Loss_G: 7.2699
[40/500][109/157] Loss_D: 0.0052 Loss_G: 7.4036
[40/500][110/157] Loss_D: 0.0167 Loss_G: 7.2742
[40/500][111/157] Loss_D: 0.0032 Loss_G: 7.9825
[40/500][112/157] Loss_D: 0.3462 Loss_G: 32.6314
[40/500][113/157] Loss_D: 0.1704 Loss_G: 40.6010
[40/500][114/157] Loss_D: 0.0065 Loss_G: 44.4607
[40/500][115/157] Loss_D: 0.0142 Loss_G: 43.9761
[40/500][116/157] Loss_D: 0.0160 Loss_G: 45.0376
[40/500][117/157] Loss_D: 0.0042 Loss_G: 45.9534
[40/500][118/157] Loss_D: 0.0061 Loss_G: 45.2998
[40/500][119/157] Loss_D: 0.0023 Loss_G: 45.4654
[40/500][120/157] Loss_D: 0.0033 Loss_G: 44.6643
[40/500][121/157] Loss_D: 0.0042 Loss_G: 44.6020
[40/500][122/157] Loss_D: 0.0002 Loss_G: 44.4807
[40/500][123/157] Loss_D: 0.0004 Loss_G: 44.0402
[40/500][124/157] Loss_D: 0.0055 Loss_G: 43.9188
[40/500][125/157] Loss_D: 0.0021 Loss_G: 43.1988
[40/500][126/157] Loss_D: 0.0008 Loss_G: 41.6770
[40/500][127/157] Loss_D: 0.0001 Loss_G: 40.8719
[40/500][128/157] Loss_D: 0.0009 Loss_G: 40.3803
[40/500][129/157] Loss_D: 0.0023 Loss_G: 39.0143
[40/500][130/157] Loss_D: 0.0254 Loss_G: 39.0317
[40/500][131/157] Loss_D: 0.0008 Loss_G: 37.9451
[40/500][132/157] Loss_D: 0.0253 Loss_G: 37.1046
[40/500][133/157] Loss_D: 0.0046 Loss_G: 36.2807
[40/500][134/157] Loss_D: 0.0025 Loss_G: 35.5878
[40/500][135/157] Loss_D: 0.0011 Loss_G: 33.6500
[40/500][136/157] Loss_D: 0.0061 Loss_G: 33.5011
[40/500][137/157] Loss_D: 0.0015 Loss_G: 30.0363
[40/500][138/157] Loss_D: 0.0019 Loss_G: 31.0197
[40/500][139/157] Loss_D: 0.0027 Loss_G: 28.4693
[40/500][140/157] Loss_D: 0.0189 Loss_G: 27.3072
[40/500][141/157] Loss_D: 0.0051 Loss_G: 26.6637
[40/500][142/157] Loss_D: 0.0077 Loss_G: 24.8390
[40/500][143/157] Loss_D: 0.0123 Loss_G: 23.8334
[40/500][144/157] Loss_D: 0.0014 Loss_G: 23.3755
[40/500][145/157] Loss_D: 0.0036 Loss_G: 19.6341
[40/500][146/157] Loss_D: 0.0025 Loss_G: 18.1076
[40/500][147/157] Loss_D: 0.0029 Loss_G: 16.9415
[40/500][148/157] Loss_D: 0.0028 Loss_G: 16.4647
[40/500][149/157] Loss_D: 0.0048 Loss_G: 14.6184
[40/500][150/157] Loss_D: 0.0074 Loss_G: 13.2544
[40/500][151/157] Loss_D: 0.0053 Loss_G: 13.0052
[40/500][152/157] Loss_D: 0.0070 Loss_G: 11.8815
[40/500][153/157] Loss_D: 0.0078 Loss_G: 12.1657
[40/500][154/157] Loss_D: 0.0094 Loss_G: 10.4259
[40/500][155/157] Loss_D: 0.0073 Loss_G: 9.9345
[40/500][156/157] Loss_D: 0.0082 Loss_G: 9.7609
[41/500][0/157] Loss_D: 0.0079 Loss_G: 9.2920
[41/500][1/157] Loss_D: 0.0134 Loss_G: 8.5241
[41/500][2/157] Loss_D: 0.0156 Loss_G: 8.6983
[41/500][3/157] Loss_D: 0.0250 Loss_G: 8.1148
[41/500][4/157] Loss_D: 0.0160 Loss_G: 8.3324
[41/500][5/157] Loss_D: 0.0187 Loss_G: 7.6281
[41/500][6/157] Loss_D: 0.0191 Loss_G: 7.4707
[41/500][7/157] Loss_D: 0.0092 Loss_G: 8.3976
[41/500][8/157] Loss_D: 0.0118 Loss_G: 7.9800
[41/500][9/157] Loss_D: 0.0126 Loss_G: 7.3999
[41/500][10/157] Loss_D: 0.0165 Loss_G: 7.0854
[41/500][11/157] Loss_D: 0.0095 Loss_G: 7.6392
[41/500][12/157] Loss_D: 0.0079 Loss_G: 7.3862
[41/500][13/157] Loss_D: 0.0181 Loss_G: 7.3812
[41/500][14/157] Loss_D: 0.0168 Loss_G: 6.9518
[41/500][15/157] Loss_D: 0.0094 Loss_G: 7.8525
[41/500][16/157] Loss_D: 0.0165 Loss_G: 7.3024
[41/500][17/157] Loss_D: 0.0029 Loss_G: 8.4487
[41/500][18/157] Loss_D: 0.0169 Loss_G: 7.0449
[41/500][19/157] Loss_D: 0.0167 Loss_G: 7.1307
[41/500][20/157] Loss_D: 0.0255 Loss_G: 6.7970
[41/500][21/157] Loss_D: 0.0154 Loss_G: 6.9745
[41/500][22/157] Loss_D: 0.0110 Loss_G: 6.9925
As you can see, the generator loss (Loss_G) changed drastically.
Any idea why this happens?
Any idea how to overcome such a problem?

Best Answer

GAN training is inherently unstable because two competing models are trained simultaneously and dynamically. Plotting the loss values from your question, the discriminator and generator losses look like this:
[plot: discriminator and generator losses over iterations]
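For what it's worth, a plot like this can be reproduced directly from the printed log. Here is a minimal sketch, assuming the console output was saved to a file named train.log (an assumed name) and has exactly the `[epoch/total][step/total] Loss_D: x Loss_G: y` format shown above; matplotlib is required:

import re
import matplotlib.pyplot as plt

pattern = re.compile(r'Loss_D: ([\d.]+) Loss_G: ([\d.]+)')
loss_d, loss_g = [], []
with open('train.log') as f:  # assumed log file name
    for line in f:
        match = pattern.search(line)
        if match:
            loss_d.append(float(match.group(1)))
            loss_g.append(float(match.group(2)))

plt.plot(loss_d, label='Loss_D')
plt.plot(loss_g, label='Loss_G')
plt.xlabel('iteration')
plt.ylabel('BCE loss')
plt.legend()
plt.show()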
Looking at the losses and the generated images, we can say the training failed to converge. This failure comes from not finding an equilibrium between the discriminator and the generator: the discriminator loss approaches zero while the generator loss rises and becomes unstable, producing garbage images that the discriminator can easily identify as fake.
The discriminator classifies both real data and fake data coming from the generator. The discriminator loss penalizes it for misclassifying a real instance as fake, or a fake instance (created by the generator) as real.
The generator loss is based on the discriminator's classification: the generator is rewarded if it successfully fools the discriminator and penalized otherwise. A GAN is a zero-sum, non-cooperative game: victory belongs either to the discriminator or to the generator, and if one wins, the other loses. Convergence happens at a Nash equilibrium, where one player's actions no longer affect the other. For a deeper look at these challenges, read https://jonathan-hui.medium.com/gan-why-it-is-so-hard-to-train-generative-advisory-networks-819a86b3750b and https://jonathan-hui.medium.com/gan-what-is-wrong-with-the-gan-cost-function-6f594162ce01.
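For reference, this adversarial game is the standard minimax objective from the original GAN paper (Goodfellow et al., 2014); in your code it is approximated by the two BCE terms, errD = errD_real + errD_fake for the discriminator and errG for the generator:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]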
Convergence failure can also happen because of mode collapse and diminishing gradients. In addition to the exploding-gradient fixes Nihal suggested:

  • Try implementing early stopping based on a metric such as the Inception Score, Modified Inception Score, Fréchet Inception Distance, or Wasserstein distance (taken from this paper: https://arxiv.org/pdf/1802.03446.pdf). These measures help identify convergence, so training can stop automatically once the model has converged.
  • It has also been shown that spectral normalization, a particular normalization applied to the convolutional kernels, can greatly help training stability: https://arxiv.org/pdf/1802.05957.pdf
  • Making the discriminator's job harder may help. Adding noise to the real images as well as to the images coming from the generator increases the difficulty of the discriminator's task (see the sketch after this list for this point and the previous one).

  • Increasing the number of iterations does not always improve the model. Beyond a certain point of training stability, more iterations may or may not yield higher-quality images, because of the high variance of the losses. And since GANs are relatively new, the research directions for these challenges remain open and contested.
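To make the spectral-normalization and noise suggestions concrete, here is a minimal sketch, not your exact architecture: it wraps discriminator convolutions in torch.nn.utils.spectral_norm and adds decaying Gaussian "instance noise" to both real and fake batches before the discriminator sees them. The layer sizes, initial_sigma, and the linear decay schedule are illustrative assumptions, not values taken from your code:

import torch
import torch.nn as nn

# Spectral normalization: wrap each discriminator conv
# (an illustrative two-layer stack, not the exact D from the question).
def sn_conv(in_ch, out_ch):
    return nn.utils.spectral_norm(
        nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False))

netD_sn = nn.Sequential(
    sn_conv(4, 64), nn.LeakyReLU(0.2, inplace=True),
    sn_conv(64, 128), nn.LeakyReLU(0.2, inplace=True),
)

# Instance noise: Gaussian noise added to both real and fake images,
# decayed linearly over training; initial_sigma is an assumed hyperparameter.
def add_instance_noise(images, epoch, total_epochs=500, initial_sigma=0.1):
    sigma = initial_sigma * max(0.0, 1.0 - epoch / total_epochs)
    return images + sigma * torch.randn_like(images)

# Usage inside the training loop (sketch):
#   output = netD(add_instance_noise(input, epoch))          # real batch
#   output = netD(add_instance_noise(fake.detach(), epoch))  # fake batch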

A similar question about "python - Fake image creation in a GAN becomes worse after a certain number of epochs" can be found on Stack Overflow: https://stackoverflow.com/questions/68904476/
