04vgg剪枝

发布时间 2023-07-05 15:37:40作者: DemonSlayer

数据集介绍

Cifar10数据集是经典的图像分类数据。首先下载数据并制作成DatasetDataLoader

  1. DatasetDataset是一个抽象类,它定义了数据的存储和访问方法。它的主要任务是加载和预处理数据。用户可以从此类派生出自定义的数据集类,以处理特定类型的数据(如图像、文本等)。
  2. DataLoaderDataLoader是一个可以对Dataset进行包装的类,它提供了数据的批处理、打乱和并行加载等功能。这对于训练大规模深度学习模型非常有用,因为这样可以使得模型在训练过程中更高效地获取数据。
import os
import torch
import shutil
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
def save_checkpoint(state, is_best, filepath):
    torch.save(state, os.path.join(filepath, 'checkpoint.pth'))
    if is_best:
        shutil.copyfile(os.path.join(filepath, 'checkpoint.pth'), os.path.join(filepath, 'model_best.pth'))
def get_training_dataloader(batch_size=16, **kwargs):
    train_dataset = datasets.CIFAR10(root='./data.cifar10', train=True, download=True)
    #除以255是为了将均值标准化到[0,1]的范围
    mean = train_dataset.data.mean(axis=(0,1,2)) / 255
    std  = train_dataset.data.std(axis=(0,1,2))  / 255
    transform_train = transforms.Compose([
        #填充4个像素,通常在随机裁剪前做
        transforms.Pad(4),
        #随机裁剪成32*32的大小
        transforms.RandomCrop(32),
        #随机水平翻转,50%概率
        transforms.RandomHorizontalFlip(),
        #转化为torch.Tensor的操作,并且会自动将数据的范围从[0, 255]归一化到[0.0, 1.0]。
        transforms.ToTensor(),
        #标准化操作,使用给定的均值(mean)和标准差(std)来对图像数据进行标准化
        transforms.Normalize(mean, std)
    ])
    train_dataset.transform = transform_train
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
    return train_loader
def get_test_dataloader(batch_size=16, **kwargs):
    test_dataset = datasets.CIFAR10(root='./data.cifar10', train=False, download=True)
    mean = test_dataset.data.mean(axis=(0,1,2)) / 255
    std  = test_dataset.data.std(axis=(0,1,2))  / 255
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_dataset.transform = transform_test
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, **kwargs)
    return test_loader

构建vgg网络

import torch.nn as nn
defaultcfg = {
    11 : [64,     'M', 128,      'M', 256, 256,           'M', 512, 512,           'M', 512, 512          ],
    13 : [64, 64, 'M', 128, 128, 'M', 256, 256,           'M', 512, 512,           'M', 512, 512          ],
    16 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256,      'M', 512, 512, 512,      'M', 512, 512, 512     ],
    19 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
}
class VGG(nn.Module):
    def __init__(self, num_classes=10, depth=11, cfg=None):
        super().__init__()
        if cfg is None:
            cfg = defaultcfg[depth]

        self.feature = self.make_layers(cfg)
        self.classifier = nn.Linear(cfg[-1], num_classes)
    def make_layers(self, cfg):
        layers = []
        in_channels = 3
        for l in cfg:
            if l == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, l, kernel_size=3, padding=1, bias=False)
                layers += [conv2d, nn.BatchNorm2d(l), nn.ReLU(inplace=True)]

                in_channels = l
        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.feature(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y
if __name__ == '__main__':
    net = VGG()

BN层复现

Batch normalize的公式,减去均值除以标准差

\[\hat{x}^{(k)} = \frac{x^{(k)} - E[x^{(k)}]}{\sqrt{Var[x^{(k)}]}} \]

Batch normalize的作用

  • 加快收敛、提升精度:对输入进行归一化,从而使得优化更加容易
  • 减少过拟合:可以减少方差的偏移
  • 可以使得神经网络使用更高的学习率:BN 使得神经网络更加稳定,从而可以使用更大的学习率,加速训练过程
  • 甚至可以减少 Dropout 的使用:因为 BN 可以减少过拟合,所以有了 BN,可以减少其他正则化技术的使用

在深度学习中,"momentum"是一种常用的优化策略,用于加速模型的训练过程。它的主要思想是引入一种物理上的“动量”概念,让模型在更新权重时不仅考虑当前梯度,也考虑过去的梯度,从而避免陷入局部最优解和减少训练过程中的震荡。


上图是一个bn层的实现,最后一步是对标准化后的结果进行 "恢复",这两个参数交给神经网络去学习,后续在对BN层做剪枝的时候,是用l1正则化来对gamma进行稀疏

假设有如下loss函数,事实上无论是怎么样的loss我们就用L(w)来表示关于权重的损失函数

\[L(w) = \frac{1}{N} *\sum\limits^{N}_{i=1}(y_i - w^Tx_i)^2 \]

L1 正则化(Lasso 回归)

  • 加上 L1 正则项(Lasso 回归):

    \[C||w||_1 \]

  • 损失函数:

    \[L_{L1}(w)= L(w) + \lambda|w| \]


可以看到L1正则化会使得参数很快变为0,那么久方便我们去剪枝

L2 正则化(岭回归)

  • 加上 L2 正则项(岭回归):

    \[C||w||^2_2 \]

  • 损失函数:

    \[L_{L2}(w) = L(w)+ \lambda w^2 \]


L2正则化在接近0的时候越来越慢。

训练VGG

代码如下:

m.weight.grad.data.add_(args.s*torch.sign(m.weight.data))实际上是在进行梯度下降的步骤中直接对权重应用L1正则化。由于梯度下降的步骤是根据损失函数的导数(也就是梯度)来更新权重,L1正则化的导数是权重的符号(因为绝对值函数在正值上的导数为1,在负值上的导数为-1)。这就是为什么代码中使用torch.sign(m.weight.data)

实际上,这一步并没有直接添加绝对值,而是添加了绝对值函数的导数,这样在进行梯度下降更新时,就已经包含了L1正则化的效果。当权重为正时,它的梯度(也就是导数)就会减小(因为添加了-1),使得权重在更新时向0移动;当权重为负时,它的梯度就会增大(因为添加了1),同样使得权重在更新时向0移动。这就是L1正则化鼓励权重稀疏的原理。

m.weight.grad.data.add_() 它实现了对变量的原地加法操作。在这里,m.weight 是模型的权重,grad 是这些权重的梯度(即,损失函数关于这些权重的导数),data 是这些梯度的具体数值,add_() 是一个原地(in-place)操作,它直接在原有的数据上进行加法操作,而不是创建一个新的数据副本。

import os
import torch
import argparse
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from models.vgg import VGG
from utils import get_training_dataloader, get_test_dataloader, save_checkpoint
def parse_opt():
    parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR training')
    parser.add_argument('--dataset', type=str, default='cifar100', help='training dataset (default: cifar100)')
    parser.add_argument('--sparsity-regularization', '-sr', dest='sr', action='store_true', help='train with channel sparsity regularization')
    parser.add_argument('--s', type=float, default=0.0001, help='scale sparse rate (default: 0.0001)')
    parser.add_argument('--refine', default='', type=str, metavar='PATH', help='path to the pruned model to be fine tuned')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=256, metavar='N', help='input batch size for testing (default: 256)')
    parser.add_argument('--epochs', type=int, default=160, metavar='N', help='number of epochs to train (default: 160)')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
    parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N', help='how many batches to wait before logging training status')
    parser.add_argument('--save', default='./logs', type=str, metavar='PATH', help='path to save prune model (default: current directory)')
    parser.add_argument('--arch', default='vgg', type=str,  help='architecture to use')
    parser.add_argument('--depth', default=19, type=int, help='depth of the neural network')

    args = parser.parse_args()
    return args
def updateBN():
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(args.s*torch.sign(m.weight.data))


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        if args.sr:
            updateBN()
        #optimizer.step() 执行的是一步梯度下降(或者其他优化算法)的更新。它使用存储在参数的 .grad 属性中的梯度信息来更新参数的值。调用 optimizer.step() 之后,所有的梯度会被清零,因此在下一轮迭代之前,你需要再次计算新的梯度
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), 
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader), 
                loss.item()))


def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            #keepdim=True表示保持原张量的维度,并在第一维度上求最大值
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, 
        len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

if __name__ == '__main__':
    args = parse_opt()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    #这里kwargs的存在让代码更加灵活,如果kwargs是{'num_workers': 1, 'pin_memory': True},那么这个函数调用实际上等同于:
    #train_loader = get_training_dataloader(batch_size=args.batch_size, num_workers=1, pin_memory=True)
	#test_loader  = get_test_dataloader(batch_size=args.test_batch_size, num_workers=1, pin_memory=True)
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        train_loader = get_training_dataloader(batch_size=args.batch_size, **kwargs)
        test_loader  = get_test_dataloader(batch_size=args.test_batch_size, **kwargs)
    if args.refine:
        checkpoint = torch.load(args.refine)
        model = VGG(depth=args.depth, cfg=checkpoint['cfg'])
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model = VGG(depth=args.depth)
    if args.cuda:
        model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # Load the checkpoint file
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
                .format(args.resume, checkpoint['epoch'], best_prec1))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            
    best_prec1 = 0.
    for epoch in range(args.start_epoch, args.epochs):
        #如果当前周期是总周期数的50%或75%,则将学习率减小10倍。这是一种常见的学习率调度策略,可以在训练过程中动态调整学习率。
        if epoch in [args.epochs*0.5, args.epochs*0.75]:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
        train(epoch)
        prec1 = test()
        is_best = prec1 > best_prec1

        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best, filepath=args.save)
    print("Best accuracy: "+str(best_prec1))

剪枝

复现iccv 2017年论文,对BN层gamma做L1正则化的操作

考虑一个问题,深度学习模型包含非常多的层和参数,在这里面有没有一些没有价值的特征和相关连接呢?又该如何去判断一些特征和连接是否有价值呢?答案:在BN层的缩放因子(gamma)上做L1正则化

一些优点如下

  • 不需要对现有 CNN 架构进行任何更改
  • 使用 L1 正则化将 BN 缩放因子的值推向零
    • 使我们能够识别不重要的通道(或神经元),因为每个缩放因子对应于特定的卷积通道(或全连接层的神经元)
    • 这有利于在接下来的步骤中进行通道级剪枝
  • 附加的正则化项很少会损害性能。不仅如此,在某些情况下,它会导致更高的泛化精度
  • 剪枝不重要的通道有时可能会暂时降低性能,但这个效应可以通过接下来的修剪网络的微调来弥补
  • 剪枝后,由此得到的较窄的网络在模型大小、运行时内存和计算操作方面比初始的宽网络更加紧凑。上述过程可以重复几次,得到一个多通道网络瘦身方案,从而实现更加紧凑的网络

    第一步就是普通的训练,获得原模型的训练权重,可以拿来作为一个对比的基准,也可以拿来进行稀疏训练

第二步是进行稀疏训练,在BN层上加一个简单的梯度

第三步根据一些规则去掉一些参数和神经元

第四步是微调fine-tune

对模型进行剪枝主要针对有参数的层:conv2d、BN2d、Linear,而pool2d只用来做下采样,没有可学习的参数,不用处理,代码如下

import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from models.vgg import VGG
from utils import get_test_dataloader

def parse_opt():
    parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
    parser.add_argument('--dataset', type=str, default='cifar100', help='training dataset (default: cifar10)')
    parser.add_argument('--test-batch-size', type=int, default=256, metavar='N', help='input batch size for testing (default: 256)')
    parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument('--depth', type=int, default=19, help='depth of the vgg')
    parser.add_argument('--percent', type=float, default=0.5, help='scale sparse rate (default: 0.5)')
    parser.add_argument('--model', default='', type=str, metavar='PATH', help='path to the model (default: none)')
    parser.add_argument('--save', default='logs/', type=str, metavar='PATH', help='path to save pruned model (default: none)')
    args = parser.parse_args()
    return args

def test(model):
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    
    if args.dataset == 'cifar10':
        test_loader = get_test_dataloader(batch_size=args.test_batch_size, **kwargs)
    else:
        raise ValueError("No valid dataset is given.")
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1] 
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    accuracy = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), accuracy))
    return accuracy / 100.



if __name__ == '__main__':
    args = parse_opt()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    model = VGG(depth=args.depth)
    if args.cuda:
        model.cuda()
    if args.model:
        if os.path.isfile(args.model):
            print("=> loading checkpoint '{}'".format(args.model))
            checkpoint = torch.load(args.model)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
                .format(args.model, checkpoint['epoch'], best_prec1))
        else:
            print("=> no checkpoint found at '{}'".format(args.model))
            
    print(model)
    total = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0]
    bn = torch.zeros(total)
    index = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            size = m.weight.data.shape[0]
            bn[index:(index+size)] = m.weight.data.abs().clone()
            index += size
    #返回排序后的列表y,i的每个元素表示对应位置的元素在原始bn张量中的位置
    #例如,如果我们有一个张量bn = torch.tensor([1.5, 0.5, 2.0, 1.0]),那么torch.sort(bn)将返回y = torch.tensor([0.5, 1.0, 1.5, 2.0])和i = torch.tensor([1, 3, 0, 2])。这表示原始张量中的最小元素是0.5,位于第1位置,其次是1.0,位于第3位置,以此类推。
    y, i = torch.sort(bn)
    #剪的个数
    thre_index = int(total * args.percent)
    #阈值
    thre = y[thre_index]
    
    pruned = 0
    cfg = []
    cfg_mask = []
    #每个模块m和它的索引k
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            weight_copy = m.weight.data.abs().clone()
            #例如,如果我们有一个张量a = torch.tensor([1, 2, 3, 4, 5]),我们可以使用.gt(3)来检查a中的每个元素是否大于3
            mask = weight_copy.gt(thre).float().cuda()
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            cfg.append(int(torch.sum(mask)))
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
                format(k, mask.shape[0], int(torch.sum(mask))))
        elif isinstance(m, nn.MaxPool2d):
            cfg.append('M')
    pruned_ratio = pruned/total
    print('Pre-processing Successful!')
    acc = test(model)
    
    print(cfg)
    newmodel = VGG(cfg=cfg)
    if args.cuda:
        newmodel.cuda()
    num_parameters = sum([param.nelement() for param in newmodel.parameters()])
    savepath = os.path.join(args.save, "prune.txt")
    with open(savepath, "w") as fp:
        fp.write("Configuration: \n"+str(cfg)+"\n")
        fp.write("Number of parameters: "+str(num_parameters)+"\n")
        fp.write("Test accuracy: "+str(acc))

    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask[layer_id_in_cfg]
    
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        
        if isinstance(m0, nn.BatchNorm2d):
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            if idx1.size == 1:
                idx1 = np.resize(idx1,(1,))
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
        elif isinstance(m0, nn.Conv2d):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            
        elif isinstance(m0, nn.Linear):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            m1.weight.data = m0.weight.data[:, idx0].clone()
            m1.bias.data   = m0.bias.data.clone()

    torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth'))

    print(newmodel)
    model = newmodel
    test(model)