GAN(生成对抗网络,Generative Adversarial Network)

发布时间 2023-10-08 14:00:49作者: 黑逍逍

生成对抗网络(GAN)是一种深度学习模型架构,由生成器(Generator)和判别器(Discriminator)两个神经网络组成。这两个网络之间进行博弈式训练。

  1. 生成器(Generator):生成器是一个神经网络模型,它接收一个随机噪声向量作为输入,并试图生成与训练数据相似的新数据样本。生成器的目标是生成尽可能逼真的假数据。

    生成器模型采用了一系列反卷积(ConvTranspose2d)层和批标准化(BatchNorm2d)层,最终输出一个图像。它的输入是一个潜在向量 nz,通过一系列反卷积操作将这个向量变换为一个与数据集相匹配的图像

  1. 判别器(Discriminator):判别器也是一个神经网络模型,它的任务是接收输入数据(可以是真实数据或生成器生成的假数据)并尝试将其区分为真实数据或假数据。判别器的目标是尽可能准确地区分真伪数据。

    判别器模型采用了一系列卷积(Conv2d)层和批标准化(BatchNorm2d)层,最终输出一个判别分数,用于表示输入图像是否是真实图像。判别器的输入是一个图像,它将通过一系列卷积操作输出一个标量值,用于表示输入图像的真实程度

 

训练过程:

  • 进行多轮的训练迭代,每轮包括以下步骤:

    a. 更新判别器(Discriminator):

    • 将判别器的梯度清零。
    • 从训练数据集中加载真实图像样本,并将它们传递给判别器,计算判别器对真实图像的输出。
    • 计算判别器的判别损失,这个损失衡量了判别器对真实图像的判别能力。
    • 从生成器中生成假的图像样本,将它们传递给判别器(注意:这里不更新生成器的参数)。
    • 计算判别器的判别损失,这个损失衡量了判别器对假的图像的判别能力。
    • 计算总的判别损失,通常是真实图像损失和假的图像损失的和。
    • 反向传播梯度并更新判别器的参数,以提高其判别能力。

    b. 更新生成器(Generator):

    • 将生成器的梯度清零。
    • 生成一批潜在向量(随机噪声),将它们传递给生成器,生成假的图像样本。
    • 将生成的假的图像传递给判别器,计算判别器对假的图像的输出。
    • 计算生成器的损失,这个损失旨在鼓励生成器生成更具迷惑性的图像,以欺骗判别器。
    • 反向传播梯度并更新生成器的参数,以生成更逼真的假图像。
  • 每轮迭代后,可以打印出损失值和一些统计信息,以监视训练的进展。

输入:

输出: