Cycle GAN:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

发布时间 2023-04-25 21:20:38作者: sinatJ

paper:https://arxiv.org/pdf/1703.10593.pdf [2017]

code 参考:

1 整体架构

整体架构主要由两个 生成器(G 和 F)、两个判别器(Dy 和 Dx)组成。
这里借用了语言翻译领域中循环一致性的思想,即将一个句子从中文翻译到英文,然后再将其翻译回中文,应该得到与原始的中文相同的句子。
所以这篇 paper 的做法是将 x 经由 G 生成得到 \(\hat{y}\),再经由 F 生成得到 \(\hat{x}\),这样看的话,这里其实就有了三个数据域:原始图像 x 所属的数据域 X,G 生成的中间图像 \(\hat{y}\) 所属的数据域 \(\hat{Y}\),以及 F 生成的重建图像 \(\hat{x}\) 的数据域 \(\hat{X}\)
其中域 X 和域\(\hat{X}\)在训练过程中是期望它们的分布尽量接近的(所以其实这里也可以理解为总共有 2 个域),而我们所需要的风格转换后的图像其实就是中间图像 \(\hat{y}\)
整个过程如上图 (a) 所示,两个生成器 G 和 F 循环生成,然后两个判别器 Dx 和 Dy 在各自的域上进行判断。

2 核心创新

提出了循环一致性损失,训练不再需要成对的图像数据,成为后续 GANs 相关论文的重要参考。

3 损失函数

1)传统 GAN 的生成&判别损失

由于有两个生成器(Generator)G 和 F、两个判别器(Discriminator)Dx 和 Dy,所以生成&判别损失其实是有两个部分:

以第一个公式为例,鉴别器 Dy 越大,表明预测结果越接近真实图片(也即生成的结果越接近真实)。
其中鉴别器的损失计算采用 L1 loss。

2)G 和 D 的优化目标

对于 Generator 来说,其目的就是为了使得生成的图像越接近真实图像越好,所以其优化目标是使得 Discriminator 的判别概率越大越好,也即 \(max_{D_Y}\)

而对于 Discriminator 来说,其目的是为了尽量鉴别出由 Generator 生成的非真实图片(对于生成图片,给出低概率),所以其目标是使得对 Generator 生成的图片赋低分,也即 \(min_{G}\)

由于有两个 Generator 和 两个 Discriminator,所以优化目标也有两组:

3)循环一致性损失

循环一致性损失主要作用就是控制在使用非对称样本时,生成结果别跑偏了,所以需要控制 F 的重建结果和原始图片的一致性,其定义如下:

其实就是两个 Generator 组成的两阶段生成的过程中,第二阶段的重建结果与第一阶段输出图片的一致性之和。
循环一致性损失采用的是 MES loss。

4)整体损失函数

就是两组 生成&对抗 损失加上一个 带权重的 循环一致性损失。

4 代码解读&实现

4.1 前置知识(可选)

在直接阅读代码之前,为了保证阅读代码的流畅性,有必要将一些可能引起疑惑的操作函数进行说明,主要包括:

  • nn.ConvTranspose2d()
  • nn.InstanceNorm2d()
  • nn.detach()
  • albumentations 图像增强库
  • ReflectionPad2d

这些函数如果已经知道的可以直接跳过。

1)nn.ConvTranspose2d()

也叫转置卷积、反卷积,和卷积对应,其目的是将低 size 的 feature map 转为 高 size 的 feature map,是图像重建过程中恢复图像原来尺寸常用的操作。

尺寸变换公式:

其实大多数时候,我们在执行卷积/反卷积的时候,只是期望能将 feature size 缩放/放大 为原来的一倍,所以这里可以简化的去记变换规则,即:

当我们希望得到 输入特征图大小/输出特征图大小 = stride 的话,需要 padding = (kernel_size - stride + output_padding )/2,进一步的则 output_padding 应该取值为 stride - 1。

常用的一组参数为:kernel_size=3, stride=2, padding=1, output_padding=1,这样正好使得 feature map 被反卷积上采样为 2 倍原尺寸大小。

更多的解释可以参考:

2)nn.InstanceNorm2d()

又叫实例归一化,其是对每个样本沿着通道方向独立对各个通道进行计算,而批量归一化则是对所有样本沿着batch 的方向对各个通道分别进行计算。

举个例子:当输入特征图形状为 (2,3,256,512),表示有两个 256×512 的特征图,特征图通道数为 3,假设为 RGB 三个通道。

那么实例归一化会依次对样本 1,样本 2 分别计算 R、G、B 三个通道的均值、方差,每次计算其实是对 256×512 个元素值进行计算。

而批量归一化则是对整个批次的样本,对各个通道分别求出均值和方差,每次计算其实是对 2×256×512 个元素值进行计算。

论文图示:

至于为啥风格转换任务中要使用 IN,摘录知乎回答如下:

BN 适用于判别模型中,比如图片分类模型。因为 BN 注重对每个batch进行归一化,从而保证数据分布的一致性,而判别模型的结果正是取决于数据整体分布。但是BN对batchsize的大小比较敏感,由于每次计算均值和方差是在一个batch上,所以如果batchsize太小,则计算的均值、方差不足以代表整个数据分布;

IN适用于生成模型中,比如图片风格迁移。因为图片生成的结果主要依赖于某个图像实例,所以对整个batch归一化不适合图像风格化中,在风格迁移中使用 Instance Normalization 不仅可以加速模型收敛,并且可以保持每个图像实例之间的独立。

参考:

3)nn.detach()

举个例子来说明一下detach有什么用。 如果 A 网络的输出被喂给 B 网络作为输入,如果我们希望在梯度反传(loss.backward())的时候只更新 B 中参数的值,而不更新 A 中的参数值,这时候就可以使用 detach(),代码示例:

...
fake_a = gen_A(domain_b_img)
D_A_fake_prob = disc_A(fake_a.detach())
loss = mse(D_A_fake_prob, label_a_prob)

...
loss.backward()
...

这样在进行 backward 时,disc_A 网络会更新参数,但 gen_A 网络不会。

详见:https://zhuanlan.zhihu.com/p/410199046

4)albumentations

albumentations 是一个第三方的图像增强操作库,其主要特点就是快,也封装了很多常规的图像增强方式(例如翻转、随机裁剪等)。

其使用方式也很简单:

import albumentations as A
from albumentations.pytorch import ToTensorV2

transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={'image0': 'image'}
)

其中需要解释的是 additional_targets 参数,其实就是将多个成对的 image 对象绑定到一起,例如将两个附加图片都绑定到原始图片上(使他们成对):{'image0': 'image', 'image1': 'image'},这样 image0 和 image1 就会执行和 image 相同的 transform 操作。

详见:

5)ReflectionPad2d()

这个填充函数不同全零填充(padding),而是采用输入边界的反射来填充输入张量,说人话就是用图像矩阵中其他位置的像素值来填充(扩充)边界,从而增大图像尺寸。

  1. 填充一层时:m = nn.ReflectionPad2d(1)
    填充顺序是:左、右、上、下
  1. 填充多层时
    示例:nn.ReflectionPad2d((1, 1, 2, 0)) 中,这几个数字表示左右上下分别要填充的层数

之所以使用反射填充,一个主要的原因是如果我们直接使用全零填充,会导致图像产生黑边,影响视觉模型的训练效果,因为黑边其实是个很明显的结构特征。

在视觉代码中通常可以通过如下的写法替代原始 Conv 中默认的全零填充方式:

nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, padding_mode='reflect'),

详细可参考:https://zhuanlan.zhihu.com/p/351958361

6)clamp()

常用 out.clamp(0, 1) 将 out 中各个数的取值范围压缩到 0-1 之间。

7)Automatic mixed precision

Automatic mixed precision(amp),自动混合精度,可以在神经网络推理过程中,针对不同的层,采用不同的数据精度进行计算,从而实现节省显存和加快速度的目的。

自动预示着 Tensor 的 dtype 类型会自动变化,也就是框架按需自动调整 tensor 的 dtype。

更多详见:https://blog.csdn.net/Z2572862506/article/details/128800233

4.2 Code

本来想直接从官方代码入手的,但是官方的封装的有点复杂,对新手不太友好,无法直观的关注 cGAN 核心的逻辑,所以参考了多份代码,下面的内容其实就是参照着其中我感觉比较好的一份实现来写得,原代码作者视频:
https://www.bilibili.com/video/BV1kb4y197PE/?spm_id_from=333.337.search-card.all.click&vd_source=bda72e785d42f592b8a2dc6c2aad2409

4.3 Generator module

Generator 由多层卷积与残差模块堆叠而成,顺序依次为:初始转置卷积 + 2层下采样 + 9 个残差模块 + 2层上采样 + 最后一层卷积。

图像经过 Generator 后,输出与原图像保持相同尺寸。

import torch
import torch.nn as nn


class ConvBlock(nn.Module):

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        """
        :param in_channels:
        :param out_channels:
        :param down: 是否下采样
        :param use_act: 是否使用激活函数
        :param kwargs:
        """
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, padding_mode='reflect', **kwargs)
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
            nn.InstanceNorm2d(out_channels),
            # inplace = False 时,不会修改输入对象的值,而是返回一个新创建的对象,所以打印出对象存储地址不同
            # inplace = True 时,会修改输入对象的值,所以打印出对象存储地址相同
            # inplace = True ,会改变输入数据的值,节省反复申请与释放内存的空间与时间,只是将原来的地址传递,效率更好
            nn.ReLU(inplace=True) if use_act else nn.Identity()     # nn.Identity() 这里其实就是个占位,当不使用激活函数时,表明什么都不做
        )

    def forward(self, x):
        return self.conv(x)


class ResidualBlock(nn.Module):

    def __init__(self, channels):
        """
        这里的 channel、ks=3,pad=1,保证了输入数据和输出数据的维度不会发生改变,只是单纯的做 residual
        :param channels:
        """
        super().__init__()

        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.block(x)


class Generator(nn.Module):

    def __init__(self, img_channels, num_features=64, num_residuals=9):
        """
        Generator 经过 下采样、残差连接、上采样,其输出尺寸和输入尺寸是一致的(当然也只有这样,才能使用 l1 计算循环一致性损失)
        :param img_channels: 输入图像通道数,默认为 3
        :param num_features: Generator 编码时的图像尺寸基数,后面会基于该基数转换尺寸
        :param num_residuals: 堆叠的残差模块个数
        """

        super().__init__()

        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode='reflect'),
            nn.ReLU(inplace=True)
        )
        # 两层卷积下采样
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
                ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
            ]
        )
        # num_residuals 层堆叠的残差模块
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(num_features*4) for _ in range(num_residuals)]
        )
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
                ConvBlock(num_features*2, num_features*1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
            ]
        )

        self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode='reflect')

    def forward(self, x):
        x = self.initial(x)
        for layer in self.down_blocks:
            x = layer(x)
        x = self.residual_blocks(x)
        for layer in self.up_blocks:
            x = layer(x)
        return torch.tanh(self.last(x))


def test():
    img_channels = 3
    img_size = 256
    x = torch.randn((2, img_channels, img_size, img_size))
    gen = Generator(img_channels, 9)
    print(gen)
    print('-' * 90)
    y = gen(x)
    print(y.shape)      # torch.Size([2, 3, 256, 256])


if __name__ == '__main__':
    test()

4.3 Discriminator module

Discriminator 由多层卷积组成,原论文中是直接将图片输入转为标量的概率输出,但是这里实现并没有转为标量,而是直接作为 vector 输出,之后在计算 loss 时采用 ones_like 与真实值对应上(感觉也行吧)。

import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):

    def __init__(self, in_channles, out_channels, stride):
        super().__init__()

        self.conv = nn.Sequential(
            # 转置卷积
            nn.Conv2d(in_channles, out_channels, kernel_size=4, stride=stride, padding=1, padding_mode='reflect'),
            # 实例归一化
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)


class Discriminator(nn.Module):

    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.LeakyReLU(0.2)
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(
                Block(in_channels, feature, stride=1 if feature == features[-1] else 2)
            )
            in_channels = feature
        layers.append(nn.Conv2d(in_channels, out_channels=1, kernel_size=4, padding=1, padding_mode='reflect'))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial(x)
        return torch.sigmoid(self.model(x))


def test():
    x = torch.randn((1, 3, 256, 256))
    model = Discriminator(in_channels=3)
    preds = model(x)
    print(model)
    print('-' * 90)
    print(preds.shape)      # torch.Size([1, 1, 30, 30])


if __name__ == '__main__':
    test()

4.4 Dataset Loader module

在介绍数据集加载逻辑之前,先把 config.py 和 utils.py 中的一些功能函数和参数说明贴一下:

utils.py:

import numpy as np
import os
import random
import torch


def save_checkpoint(model, optimizer, filename='my_checkpoint.pth.tar'):
    print('=> Saving checkpoint')
    checkpoint = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print('=> Loading checkpoint')
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    # if wen don't do this then it will just have learning rate of old checkpoint and it will lead to many hours of debugging
    for param_group in optimizer.param_group:
        param_group['lr'] = lr


def seed_everything(seed=42):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

config.py:

import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TRAIN_DIR = 'data/train'
VAL_DIR = 'data/val'
BATCH_SIZE = 1
LEARNINGG_RATE = 2e-4
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 2
NUM_EPOCHS = 200
LOAD_MODEL = False      # 初始训练时指定为 False
SAVE_MODEL = True
CHECKPOINT_GEN_A = 'checkpoints/gen_A.pth.tar'
CHECKPOINT_GEN_B = 'checkpoints/gen_B.pth.tar'
CHECKPOINT_DISCRIMINATOR_A = 'checkpoints/d_A.pth.tar'
CHECKPOINT_DISCRIMINATOR_B = 'checkpoints/d_B.pth.tar'

transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        # A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={'image0': 'image'}
)

数据加载逻辑 dataset.py

import torch
from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np


class ABDataset(Dataset):

    def __init__(self, root_path_domain_a, root_path_domain_b, transform=None):
        self.root_path_domain_a = root_path_domain_a
        self.root_path_domain_b = root_path_domain_b
        self.transform = transform

        self.domain_a_images = os.listdir(root_path_domain_a)
        self.domain_b_images = os.listdir(root_path_domain_b)
        self.length_dataset = max(len(self.domain_a_images), len(self.domain_b_images))   # 1000, 1500
        self.domain_a_len = len(self.domain_a_images)
        self.domain_b_len = len(self.domain_b_images)

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        """
        示例:
            domain_a 对应 horse
            domain_b 对应 zebra
        :param index:
        :return:
        """
        domain_a_img = self.domain_a_images[index % self.domain_a_len]
        domain_b_img = self.domain_b_images[index % self.domain_b_len]

        domain_a_path = os.path.join(self.root_path_domain_a, domain_a_img)
        domain_b_path = os.path.join(self.root_path_domain_b, domain_b_img)

        domain_a_img = np.array(Image.open(domain_a_path).convert('RGB'))
        domain_b_img = np.array(Image.open(domain_b_path).convert('RGB'))

        if self.transform:
            augmentations = self.transform(image=domain_a_img, image0=domain_b_img)
            domain_a_img = augmentations['image']
            domain_b_img = augmentations['image0']

        return domain_a_img, domain_b_img

4.5 Trian

前置模块写完了,下面定义训练逻辑。在写之前,先梳理一下流程。

  1. 获取两个域的图片 domain_a_img、domain_b_img
  2. 优化两个 Discriminator,目标是 \(min_{fake\_img}\)\(max_{real\_img}\)
    a. 利用 gen_A、gen_B 生成 fake_a、fake_b
    b. 过 disc_A、disc_B 获得判别概率
    c. 计算 disc_A、disc_B 的 loss
  3. 优化两个 Generator,目标是 \(max_{fake\_img}\)\(max_{rec\_img}\)
    a. 利用 gen_A、gen_B 生成 rec_b、rec_a
    b. 过 disc_A、disc_B 获得判别概率
    c. 计算 gen_A、gen_B 的 loss

代码逻辑如下 train.py

import torch
from dataset import ABDataset
import sys
from utils import save_checkpoint, load_checkpoint
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
from discriminator_model import Discriminator
from generator_model import Generator
import config


def train_fn(disc_A, disc_B, gen_B, gen_A, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
    loop = tqdm(loader, leave=True)

    for idx, (domain_a_img, domain_b_img) in enumerate(loop):
        domain_a_img = domain_a_img.to(config.DEVICE)
        domain_b_img = domain_b_img.to(config.DEVICE)

        # Train Discriminator A and B
        with torch.cuda.amp.autocast():
            fake_a = gen_A(domain_b_img)
            D_A_real_prob = disc_A(domain_a_img)
            D_A_fake_prob = disc_A(fake_a.detach())  # 注意这里使用 detach(),使得更新 disc_A 的时候不更新 gen_A
            D_A_real_loss = mse(D_A_real_prob, torch.ones_like(D_A_real_prob))
            D_A_fake_loss = mse(D_A_fake_prob, torch.zeros_like(D_A_fake_prob))
            D_A_loss = D_A_real_loss + D_A_fake_loss

            fake_b = gen_B(domain_a_img)
            D_B_real_prob = disc_B(domain_b_img)
            D_B_fake_prob = disc_B(fake_b.detach())
            D_B_real_loss = mse(D_B_real_prob, torch.ones_like(D_B_real_prob))
            D_B_fake_loss = mse(D_B_fake_prob, torch.zeros_like(D_B_fake_prob))
            D_B_loss = D_B_real_loss + D_B_fake_loss

            # put it togethor
            D_loss = (D_A_loss + D_B_loss) / 2

        # 注意这里使用了 amp 的话,与往常通用的写法有一点不一样了
        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generator A and B
        with torch.cuda.amp.autocast():
            # adversarial loss for both generators
            # 下面两句在 Discriminator 中已经执行过了,这里有必要再执行一次吗?
            D_A_fake_prob = disc_A(fake_a)
            D_B_fake_prob = disc_B(fake_b)
            # 对于 Gs,其目标是使得 fake img 在 disc 看来,概率越接近真实越好
            loss_G_A = mse(D_A_fake_prob, torch.ones_like(D_A_fake_prob))
            loss_G_B = mse(D_B_fake_prob, torch.ones_like(D_B_fake_prob))

            # cycle consistency loss
            cycle_b = gen_B(fake_a)
            cycle_a = gen_A(fake_b)
            cycle_b_loss = l1(domain_b_img, cycle_b)
            cycle_a_loss = l1(domain_a_img, cycle_a)

            # identity loss
            # 这个原论文中并有提到这个损失,config 配置文件中配置的权重为 0,所以实际上也并没有使用
            identity_b = gen_B(domain_b_img)
            identity_a = gen_A(domain_a_img)
            identity_b_loss = l1(domain_b_img, identity_b)
            identity_a_loss = l1(domain_a_img, identity_a)

            # add all togethor
            G_loss = (
                loss_G_A
                + loss_G_B
                + cycle_a_loss * config.LAMBDA_CYCLE
                + cycle_b_loss * config.LAMBDA_CYCLE
                + identity_a_loss * config.LAMBDA_IDENTITY
                + identity_b_loss * config.LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 100 == 0:
            # 在读入图片时进行过 norm,所以这里保存时需要 denorm
            save_image(domain_a_img * 0.5 + 0.5, f"saved_images/a_{idx}.png")
            save_image(domain_b_img * 0.5 + 0.5, f"saved_images/b_{idx}.png")
            save_image(fake_a * 0.5 + 0.5, f"saved_images/fake_a_{idx}.png")    # fake_a 由 gen_B 生成
            save_image(fake_b * 0.5 + 0.5, f"saved_images/fake_b_{idx}.png")    # fake_b 由 gen_A 生成


def main():
    disc_A = Discriminator(in_channels=3).to(config.DEVICE)
    disc_B = Discriminator(in_channels=3).to(config.DEVICE)

    gen_A = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
    gen_B = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)

    # 将两个 Discriminator 放在一起优化
    opt_disc = optim.Adam(
        list(disc_A.parameters()) + list(disc_B.parameters()),
        lr=config.LEARNINGG_RATE,
        betas=(0.5, 0.999)
    )

    # 将两个 Generator 放在一起优化
    opt_gen = optim.Adam(
        list(gen_A.parameters()) + list(gen_B.parameters()),
        lr=config.LEARNINGG_RATE,
        betas=(0.5, 0.999),
    )

    l1 = nn.L1Loss()
    mse = nn.MSELoss()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN_A, gen_A, opt_gen, config.LEARNINGG_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_GEN_B, gen_B, opt_gen, config.LEARNINGG_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISCRIMINATOR_A, disc_A, opt_disc, config.LEARNINGG_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISCRIMINATOR_B, disc_B, opt_disc, config.LEARNINGG_RATE,
        )

    dataset = ABDataset(
        root_path_domain_a=config.TRAIN_DIR+'/horses', root_path_domain_b=config.TRAIN_DIR+'/zebras', transform=config.transforms
    )

    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True
    )

	# amp
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(config.NUM_EPOCHS):
        train_fn(disc_A, disc_B, gen_B, gen_A, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler)

        if config.SAVE_MODEL:
            save_checkpoint(gen_A, opt_gen, filename=config.CHECKPOINT_DISCRIMINATOR_A)
            save_checkpoint(gen_B, opt_gen, filename=config.CHECKPOINT_DISCRIMINATOR_B)
            save_checkpoint(disc_A, opt_disc, filename=config.CHECKPOINT_DISCRIMINATOR_A)
            save_checkpoint(disc_B, opt_disc, filename=config.CHECKPOINT_DISCRIMINATOR_B)

if __name__ == '__main__':
    main()

4.6 复现结果

完整的训练是 200 个 epoch,时间太长了,这里展示下 10 个 epoch 时的结果,尽管效果还不好,但是网络确实学习着去将 horse 和 zebra 互相转换。