论文解读(VAT)《Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning》

发布时间 2023-04-22 21:08:40作者: VX账号X466550

论文信息

论文标题:Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning
论文作者:Takeru Miyato, S. Maeda, Masanori Koyama, S. Ishii
论文来源:2020 ECCV
论文地址:download 
论文代码:download
视屏讲解:click

1 前言

  提出问题:在半监督域适应任务中,作者针对目标域中标签样本和部分无标签,目标域标签样本会偏向源域这一现象,从而形成较大的域内差异。

  我们提出了一种基于虚拟对抗损失的新正则化方法:给定输入的条件标签分布的局部平滑度的新度量。 虚拟对抗损失被定义为每个输入数据点周围的条件标签分布对局部扰动的鲁棒性。 与对抗训练不同,我们的方法定义了没有标签信息的对抗方向,因此适用于半监督学习。 因为我们平滑模型的方向只是“虚拟”对抗,所以我们称我们的方法为虚拟对抗训练(VAT)。 增值税的计算成本相对较低。 对于神经网络,虚拟对抗损失的近似梯度可以通过不超过两对前向和反向传播来计算。 在我们的实验中,我们将 VAT 应用于多个基准数据集上的监督和半监督学习任务。 通过基于熵最小化原理对算法进行简单增强,我们的 VAT 在 SVHN 和 CIFAR-10 上实现了半监督学习任务的最先进性能。

2 介绍

2.1 问题定义

  We begin this section with a set of notations. Let  $x \in R^{I}$  and  $y \in Q$  respectively denote an input vector and an output label, where  $I$  is the input dimension and  $Q$  is the space of all labels. Additionally, we denote the output distribution parameterized by  $\theta$  as  $p(y \mid x, \theta)$ . We use  $\hat{\theta}$  to denote the vector of the model parameters at a specific iteration step of the training process. We use  $\mathcal{D}_{l}=\left\{x_{l}^{(n)}, y_{l}^{(n)} \mid n=1, \ldots, N_{l}\right\}$  to denote a labeled dataset, and  $\mathcal{D}_{u l}=\left\{x_{u l}^{(m)} \mid m=1, \ldots, N_{u l}\right\}$  to denote an unlabeled dataset. We train the model  $p(y \mid x, \theta)$  using  $\mathcal{D}_{l}$  and  $\mathcal{D}_{u l}$ .

2.2 对抗训练

  对抗训练:

    $\begin{array}{l}L_{\mathrm{adv}}\left(x_{l}, \theta\right):=D\left[q\left(y \mid x_{l}\right), p\left(y \mid x_{l}+r_{\mathrm{adv}}, \theta\right)\right] \quad\quad\quad(1)\\\text { where } r_{\mathrm{adv}}:=\underset{r ;\|r\| \leq \epsilon}{\arg \max } D\left[q\left(y \mid x_{l}\right), p\left(y \mid x_{l}+r, \theta\right)\right]\quad\quad\quad(2)\end{array}$

  其中,$D\left[q, p^{\prime}\right]$ 是非负函数,用域测量分布 $p$ 和 分布 $q$ 之间的差异。例子,$D$ 可以是交叉熵 $D\left[p, p^{\prime}\right]=-\sum_{i} p_{i} \log p_{i}^{\prime}$

  公式的含义很明显,就是想在 adversarial direction 上找到一个模小于 $\epsilon$ 的扰动,但是优化这个损失是一件困难的事情,一般情况下我们很难找到这样一个精确的 $r$ ,所以可以采用如下的线性估计的方法,找到最近似的扰动:

    $r_{\mathrm{adv}} \approx \epsilon \frac{g}{\|g\|_{2}}, \text { where } g=\nabla_{x_{l}} D\left[h\left(y ; y_{l}\right), p\left(y \mid x_{l}, \theta\right)\right]$

  当范数为 $L_{\infty}$ 时,对抗性扰动可以近似为

    $r_{\text {adv }} \approx \epsilon \operatorname{sign}(g)$

3 方法

3.1 虚拟对抗训练

  对抗训练是一种成功的方法,适用于许多监督问题。 但是,并非始终提供完整的标签信息。 让 $x_{*}$ 代表 $x_{l}$ 或 $x_{u l}$。虚拟对抗训练目标函数如下:

    $\begin{array}{l}D\left[q\left(y \mid x_{*}\right), p\left(y \mid x_{*}+r_{\mathrm{qadv}}, \theta\right)\right] \\\text { where } r_{\mathrm{qadv}}:=\underset{r ;\|r\| \leq \epsilon}{\arg \max } D\left[q\left(y \mid x_{*}\right), p\left(y \mid x_{*}+r, \theta\right)\right]\end{array}$

  在本研究中,使用当前估计值 $p(y \mid x, \hat{\theta})$ 代替 $q(y \mid x)$。 通过这种折中,得到了 $Eq.2$ 的再现:

    $\begin{array}{l}\operatorname{LDS}\left(x_{*}, \theta\right):=D\left[p\left(y \mid x_{*}, \hat{\theta}\right), p\left(y \mid x_{*}+r_{\mathrm{vadv}}, \theta\right)\right] \\r_{\mathrm{vadv}}:=\underset{r ;\|r\|_{2} \leq \epsilon}{\arg \max } D\left[p\left(y \mid x_{*}, \hat{\theta}\right), p\left(y \mid x_{*}+r\right)\right]\end{array}$

  损失 $\operatorname{LDS}(x, \theta)$ 可以被认为是当前模型在每个输入数据点 $x$ 处的局部平滑度的负度量,它的减少将使模型在每个数据点处变得平滑。在本研究中提出的正则化项是所有输入数据点的 $\operatorname{LDS}\left(x_{*}, \theta\right)$ 的平均值:

    $\mathcal{R}_{\mathrm{vadv}}\left(\mathcal{D}_{l}, \mathcal{D}_{u l}, \theta\right):=\frac{1}{N_{l}+N_{u l}} \sum_{x_{*} \in \mathcal{D}_{l}, \mathcal{D}_{u l}} \operatorname{LDS}\left(x_{*}, \theta\right)$

  完整的目标函数:

    $\ell\left(\mathcal{D}_{l}, \theta\right)+\alpha \mathcal{R}_{\mathrm{vadv}}\left(\mathcal{D}_{l}, \mathcal{D}_{u l}, \theta\right)$

  其中,$\ell\left(\mathcal{D}_{l}, \theta\right)$ 是标记数据集的负对数似然。 VAT 是一种使用正则化器 $\mathcal{R}_{\text {vadv }}$ 的训练方法。

  VAT 的一个显着优点是只有两个标量值超参数:

    • 对抗方向的范数约束 $\epsilon>0$ ;
    • 控制负对数似然之间的相对平衡的正则化系数 $\alpha>0$ 和正则化器 $\mathcal{R}_{\mathrm{vadv}}$;

  实验:

  

  上图直观的显示了 VAT 在半监督任务上的表现的举例,可以看到第二行第二列,在一开始模型迭代伦次较少的情况下,有大量的无标签数据(那些大量的灰色点)会有较高的 LDS(深蓝色),这是因为一开始的模型对相同类别的数据点预测了不同的标签(见同列第一行),VAT 会给予这些 LDS 较高数据点更大的压力,来迫使模型让数据点间的边界平滑。

代码:

class VATLoss(nn.Module):

    def __init__(self, xi=10.0, eps=1.0, ip=1):
        """VAT loss
        :param xi: hyperparameter of VAT (default: 10.0)
        :param eps: hyperparameter of VAT (default: 1.0)
        :param ip: iteration times of computing adv noise (default: 1)
        """
        super(VATLoss, self).__init__()
        self.xi = xi  #10.0
        self.eps = eps   #1.0
        self.ip = ip  #1

    def forward(self, model, x):
        with torch.no_grad():
            pred = F.softmax(model(x), dim=1)  #torch.Size([32, 10])

        # prepare random unit tensor
        d = torch.rand(x.shape).sub(0.5).to(x.device)  #torch.Size([32, 3, 32, 32])
        d = _l2_normalize(d)

        with _disable_tracking_bn_stats(model):
            # calc adversarial direction
            for _ in range(self.ip):
                d.requires_grad_()
                pred_hat = model(x + self.xi * d)
                logp_hat = F.log_softmax(pred_hat, dim=1)
                adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean')
                adv_distance.backward()
                d = _l2_normalize(d.grad)
                model.zero_grad()
    
            # calc LDS
            r_adv = d * self.eps  #r_adv.requires_grad = False
            pred_hat = model(x + r_adv)  # x + r_adv .requires_grad = False
            logp_hat = F.log_softmax(pred_hat, dim=1)
            lds = F.kl_div(logp_hat, pred, reduction='batchmean')

        return lds
VAT loss
def train(args, model, device, data_iterators, optimizer):
    model.train()
    for i in tqdm(range(args.iters)):
        if i % args.log_interval == 0:
            ce_losses = utils.AverageMeter()
            vat_losses = utils.AverageMeter()
            prec1 = utils.AverageMeter()
        
        x_l, y_l = next(data_iterators['labeled'])
        x_ul, _ = next(data_iterators['unlabeled'])
        x_l, y_l = x_l.to(device), y_l.to(device)
        x_ul = x_ul.to(device)

        optimizer.zero_grad()
        vat_loss = VATLoss(xi=args.xi, eps=args.eps, ip=args.ip)  # 10.0    1.0   1
        cross_entropy = nn.CrossEntropyLoss()

        lds = vat_loss(model, x_ul)
        output = model(x_l)
        classification_loss = cross_entropy(output, y_l)
        loss = classification_loss + args.alpha * lds
        loss.backward()
        optimizer.step()
train_epcoh