python 量化模板

发布时间 2024-01-04 22:18:55作者: 心比天高xzh

以imagecrop为例,二分类,输出准确率与召回率,可调阈值。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import torchvision
import torchvision.transforms as transforms
import os
import _pickle as cPickle
import argparse
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
import math
import sys
from PIL import Image
from torch.utils.data import Dataset
from torchsummary import summary 
def dequantize(tq, qt):
    """Invert quantize(): scale the quantized values back to the real range."""
    restored = tq / qt
    return restored
# Module-level calibration cache: the largest magnitude observed so far,
# keyed first by tensor role ('input'/'weight'/'bias'/'output'), then by
# layer index. Shared between quantize() and the QAT model.
max_value = {part: {} for part in ('input', 'weight', 'bias', 'output')}
def my_round(x):
    """Round to nearest integer, breaking exact .5 ties upward (toward +inf).

    torch.round alone rounds half-to-even; the mask redirects the exact
    halfway cases to torch.ceil instead.
    """
    halfway = (torch.ceil(x) - x) == 0.5
    return torch.where(halfway, torch.ceil(x), torch.round(x))
def quantize(tf, signed, part, layer=0,bit=8):
    """Quantize tensor `tf` to `bit` bits using a power-of-two scale factor.

    `part` ('input'/'weight'/'bias'/'output') and `layer` index into the
    module-level max_value cache, so the dynamic range is a running maximum
    across calls (simple online calibration).

    Returns (tq, qt): the clamped/rounded quantized tensor and the scale
    factor, such that tf ~= tq / qt.
    """
    # Running max: never let the recorded range shrink between calls.
    tf_max = torch.max(torch.abs(tf))
    if(layer in max_value[part]):
        tf_max = max_value[part][layer] if (max_value[part][layer]>tf_max)  else tf_max
    max_value[part][layer]=tf_max
    top_edge=math.pow(2,bit)-1
    # Ideal scale: map tf_max to the top of the representable range —
    # symmetric +/-(2^bit-2)/2 when signed, [0, 2^bit-1] when unsigned.
    # NOTE(review): tf_max == 0 would make qt infinite — assumed not to
    # occur in practice; confirm inputs are never all-zero.
    if signed:
        qt = (top_edge-1)/2/tf_max
    else:
        qt = top_edge/tf_max
    # Snap the scale down to a power of two (hardware-friendly bit shifts).
    qt = torch.log2(qt)
    qt = qt.floor()
    qt = torch.pow(2,qt)
    tq = tf * qt
    # Re-record the range actually representable at the snapped scale, so
    # later calls calibrate against it.
    if signed:
        max_value[part][layer] = (top_edge-1)/2/qt
    else:
        max_value[part][layer] = top_edge/qt
    # Clamp into the integer range, then round with half-up tie-breaking.
    if signed:
        tq.clamp_(-(top_edge-1)/2,(top_edge-1)/2)
        tq = my_round(tq)
    else:
        tq.clamp_(0, top_edge)
        tq = my_round(tq)
    return tq, qt
class Fakequant(Function):
    """Straight-through fake quantization.

    forward: quantize then immediately dequantize, so the tensor stays in
    float but carries the quantization error.
    backward: identity (straight-through estimator); the four non-tensor
    arguments receive None gradients.

    Fix: the original used the legacy instance-method Function API
    (forward(self, ...)), which modern PyTorch rejects at ``.apply()``
    time — forward/backward must be static methods taking a ctx.
    """

    @staticmethod
    def forward(ctx, tf, signed, part, layer, bit):
        tq, qt = quantize(tf, signed, part, layer, bit)
        return dequantize(tq, qt)

    @staticmethod
    def backward(ctx, grad_output):
        # One slot per forward() input: only `tf` is differentiable.
        return grad_output, None, None, None, None
class imagecrop_data(Dataset):
    """Binary-classification dataset over ``<data_dir>/type`` (label 1) and
    ``<data_dir>/no_type`` (label 0) PNG files named ``1.png``, ``2.png``, ...

    Items are grayscale images with random H/V flips, reshaped to
    (256, 32, 32) for the Net model.
    """

    def __init__(self, data_dir):
        super().__init__()
        self.imgpath_1 = os.path.join(data_dir, 'type')      # positive class
        self.imgpath_0 = os.path.join(data_dir, 'no_type')   # negative class
        # Cache directory sizes once: the original called os.listdir (O(n))
        # on every __getitem__ and twice per __len__. The file set is
        # assumed fixed for the lifetime of the dataset.
        self.num_0 = len(os.listdir(self.imgpath_0))
        self.num_1 = len(os.listdir(self.imgpath_1))
        # Transforms operate on PIL images; ToTensor converts to CHW float.
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        # Indices [0, num_0) map to the negative folder, the rest to the
        # positive folder; files are 1-based on disk.
        if index >= self.num_0:
            label = 1
            index_1 = index - self.num_0
            imgpath = self.imgpath_1
        else:
            label = 0
            index_1 = index
            imgpath = self.imgpath_0
        path = os.path.join(imgpath, str(index_1 + 1) + '.png')
        image = Image.open(path)
        image = self.transform(image)
        # NOTE(review): reshaping to 256 channels implies source images are
        # 512x512 (512*512 = 256*32*32) — confirm against the data on disk.
        image = image.reshape(256, 32, 32)
        label = torch.as_tensor(label, dtype=torch.int64)
        return image, label

    def __len__(self):
        return self.num_0 + self.num_1
class Net(nn.Module):
    """VGG-style binary classifier over (256, 32, 32) inputs.

    Four conv stages (conv-BN-ReLU x2 + 2x2 max-pool) followed by a
    flatten + two linear layers and a softmax over 2 classes. Layers are
    exposed as attributes layer1 .. layer36 and executed in order.
    """

    def __init__(self):
        super(Net, self).__init__()
        stack = []
        # Four identical convolutional stages.
        for _ in range(4):
            stack.extend([
                nn.Conv2d(256, 256, 3, 1, 1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 256, 3, 1, 1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.MaxPool2d((2, 2)),
            ])
        # Classifier head: 2x2 spatial map * 256 channels -> 1024 features.
        stack.extend([
            nn.Flatten(),
            nn.Linear(1024, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 2),
            nn.BatchNorm1d(2),
            nn.ReLU(inplace=True),
            nn.Softmax(dim=1),
        ])
        self.layer_number = len(stack)
        # Register as layer1..layerN and He-initialize the conv weights
        # (ReLU-friendly), matching the original per-attribute layout.
        for idx, module in enumerate(stack, start=1):
            setattr(self, 'layer' + str(idx), module)
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        out = x
        for idx in range(1, self.layer_number + 1):
            out = getattr(self, 'layer' + str(idx))(out)
        return out
class qat_Cifar10Net(Net):
    """Quantization-aware-training wrapper around Net.

    (The class name mentions CIFAR-10, but this script trains it on the
    imagecrop dataset — see main().)

    For every Conv2d/Linear layer it tracks four power-of-two scale
    factors as attributes:
      layerN_qx : input scale
      layerN_qw : weight scale
      layerN_qb : bias scale
      layerN_qy : output scale
    Scales come from the module-level quantize()/max_value machinery.
    """
    def __init__(self):
        super(qat_Cifar10Net, self).__init__()
        # Attach scale-factor slots only to the quantizable layers.
        for layer in range(1,self.layer_number+1):
            module = getattr(self,"layer"+str(layer))
            if(module.__class__.__name__ in ['Conv2d','Linear']):
                setattr(self,'layer'+str(layer)+'_qw',0)
                setattr(self,'layer'+str(layer)+'_qb',0)
                setattr(self,'layer'+str(layer)+'_qy',0)
                setattr(self,'layer'+str(layer)+'_qx',0)
        # Bit widths for weights, activations (inputs/outputs), and biases.
        self.w_bits=6
        self.i_bits=6
        self.b_bits=6
    def adjust_q(self,layer,left_edge,right_edge):
        """Keep the rescale ratio Q = qx*qw/qy of `layer` inside [left_edge, right_edge].

        Acts indirectly: it shrinks the recorded max_value of the layer's
        weights (when Q is too large) or outputs (when Q is too small), so
        the NEXT quantize() call for that tensor picks a different
        power-of-two factor.
        """
        qw=getattr(self,'layer'+str(layer)+'_qw')
        qy=getattr(self,'layer'+str(layer)+'_qy')
        qx=getattr(self,'layer'+str(layer)+'_qx')
        if qx*qw/qy>right_edge:
            qw_new = right_edge*qy/qx
            max_value['weight'][layer] = max_value['weight'][layer]*qw/qw_new
        elif qx*qw/qy<left_edge:
            qy_new = qx*qw/left_edge
            max_value['output'][layer] = max_value['output'][layer]*qy/qy_new

    def forward(self,x):
        """Fake-quantized forward pass used during QAT training.

        Each quantizable layer is evaluated twice: a throwaway pass to
        measure the output scale and run adjust_q(), then the real pass
        whose result feeds the next layer. Only the first layer's input is
        quantized as signed (layer==1).
        """
        for layer in range(1,self.layer_number+1):
            module = getattr(self,"layer"+str(layer))
            if(module.__class__.__name__ in ['Conv2d','Linear']):
                # Measure input scale, then fake-quantize the input.
                _,x_factor=quantize(x, layer==1, 'input', layer,self.i_bits)
                x = Fakequant.apply(x, layer==1, 'input', layer,self.i_bits)
                setattr(self,'layer'+str(layer)+'_qx',x_factor)
                _,w_factor=quantize(module.weight, True,'weight', layer,self.w_bits)
                setattr(self,'layer'+str(layer)+'_qw',w_factor)
                # Dry run: only used to obtain the output scale for adjust_q.
                if(module.__class__.__name__=='Conv2d'):
                    x_copy = F.conv2d(x, Fakequant.apply(module.weight,True,'weight',layer,self.w_bits), \
                               Fakequant.apply(module.bias,True,'bias',layer,self.b_bits), stride=module.stride, \
                               padding=module.padding, dilation=module.dilation)
                else:
                    x_copy = F.linear(x, Fakequant.apply(module.weight, True, 'weight', layer,self.w_bits), \
                             Fakequant.apply(module.bias, True , 'bias', layer,self.b_bits))
                _,y_factor=quantize(x_copy, True, 'output', layer,self.i_bits)
                setattr(self,'layer'+str(layer)+'_qy',y_factor)
                # Clamp Q = qx*qw/qy into [32, 128] via max_value nudging.
                self.adjust_q(layer,32,128)
                # Real pass with the (possibly adjusted) factors.
                _,w_factor=quantize(module.weight, True,'weight', layer,self.w_bits)
                setattr(self,'layer'+str(layer)+'_qw',w_factor)
                if(module.__class__.__name__=='Conv2d'):
                    x = F.conv2d(x, Fakequant.apply(module.weight,True,'weight',layer,self.w_bits), \
                               Fakequant.apply(module.bias,True,'bias',layer,self.b_bits), stride=module.stride, \
                               padding=module.padding, dilation=module.dilation)
                else:
                    x = F.linear(x, Fakequant.apply(module.weight, True, 'weight', layer,self.w_bits), \
                             Fakequant.apply(module.bias, True , 'bias', layer,self.b_bits))
                _,y_factor=quantize(x, True,'output',layer,self.i_bits)
                x = Fakequant.apply(x, True,'output',layer,self.i_bits)
                setattr(self,'layer'+str(layer)+'_qy',y_factor)
            else:
                # BatchNorm / ReLU / MaxPool / Flatten / Softmax run unquantized.
                x = module(x)

        return x
    def save_qat(self):
        """Overwrite weights/biases in place with their quantized integer
        values and record the matching scale factors.

        Call once after training; qat_forward() relies on weights already
        being in the quantized domain.
        """
        for layer in range(1,self.layer_number+1):
            module = getattr(self,"layer"+str(layer))
            if(module.__class__.__name__ in ['Conv2d','Linear']):
                module.weight.data, qw = quantize(module.weight.data, True, 'weight', layer,self.w_bits)
                module.bias.data, qb = quantize(module.bias.data, True, 'bias', layer,self.b_bits)
                setattr(self,'layer'+str(layer)+'_qw',qw)
                setattr(self,'layer'+str(layer)+'_qb',qb)
    def qat_forward(self,x):
        """Inference path simulating the integer pipeline.

        Requires save_qat() to have run first (weights hold quantized
        values, which are divided back by qw here).
        """
        for layer in range(1,self.layer_number+1):
            module = getattr(self,"layer"+str(layer))
            if(module.__class__.__name__ in ['Conv2d','Linear']):
                # Round the input onto the qx grid, then return to real scale.
                x = x * getattr(self,'layer'+str(layer)+'_qx')
                x = my_round(x)
                x = x / getattr(self,'layer'+str(layer)+'_qx')
                # NOTE(review): the bias is rescaled by qx/qb — presumably
                # to align it with the accumulator scale; verify against the
                # deployed integer kernel.
                if(module.__class__.__name__ =='Conv2d'):
                    x = F.conv2d(x, module.weight.data/getattr(self,'layer'+str(layer)+'_qw'), \
                                module.bias.data/getattr(self,'layer'+str(layer)+'_qb')*getattr(self,'layer'+str(layer)+'_qx'), \
                                stride=module.stride, \
                                padding=module.padding, dilation=module.dilation)
                else:
                    x = F.linear(x, module.weight.data/getattr(self,'layer'+str(layer)+'_qw'), \
                               module.bias.data/getattr(self,'layer'+str(layer)+'_qb')*getattr(self,'layer'+str(layer)+'_qx'))
                # Round the output onto the qy grid.
                x = x * getattr(self,'layer'+str(layer)+'_qy')
                x = my_round(x)
                x = x / getattr(self,'layer'+str(layer)+'_qy')
            else:
                x = module(x)
        out = x
        return out
    def display_qat(self,write_f):
        """Dump per-layer scale factors (and the ratio Q) to file `write_f`."""
        print_log = open(write_f,'w')
        for layer in range(1,self.layer_number+1):
            module = getattr(self,"layer"+str(layer))
            if(module.__class__.__name__ in ['Conv2d','Linear']):
                print("=========="+"layer"+str(layer)+"========",file=print_log)
                print("qx:"+str(getattr(self,'layer'+str(layer)+'_qx')),file=print_log)
                print("qy:"+str(getattr(self,'layer'+str(layer)+'_qy')),file=print_log)
                print("qw:"+str(getattr(self,'layer'+str(layer)+'_qw')),file=print_log)
                print("qb:"+str(getattr(self,'layer'+str(layer)+'_qb')),file=print_log)
                print("Q:"+str(getattr(self,'layer'+str(layer)+'_qw')*getattr(self,'layer'+str(layer)+'_qx')/getattr(self,'layer'+str(layer)+'_qy')),file=print_log)
        print_log.close()
    def save_each_layer(self):
        """Save each quantizable layer's weight/bias tensors as .pt files
        under ./each_layer_pt (directory must already exist)."""
        for layer in range(1,self.layer_number+1):
            module = getattr(self,"layer"+str(layer))
            if(module.__class__.__name__ in ['Conv2d','Linear']):
                torch.save(module.weight.data,'./each_layer_pt/layer'+str(layer)+'_weight.pt')
                torch.save(module.bias.data,'./each_layer_pt/layer'+str(layer)+'_bias.pt')
def train_one_epoch(model, optimizer, data_loader, device, epoch, threshold=0.5):
    """Train `model` for one epoch.

    A sample is predicted positive when its class-1 score >= `threshold`.
    Returns (mean loss per step, accuracy). Exits the process on a
    non-finite loss.
    """
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)     # running loss
    accu_num = torch.zeros(1).to(device)      # correctly classified samples
    cor_num = torch.zeros(1).to(device)       # ground-truth positives seen
    pre_cor_num = torch.zeros(1).to(device)   # true positives
    optimizer.zero_grad()
    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        labels = labels.to(device)  # move once instead of on every use
        sample_num += images.shape[0]
        pred = model(images.to(device))
        pred_classes = pred[:, 1] >= threshold
        accu_num += torch.eq(pred_classes, labels).sum()
        cor_num += (labels == 1).sum()
        pre_cor_num += (pred_classes * labels).sum()
        loss = loss_function(pred, labels)
        loss.backward()
        accu_loss += loss.detach()
        # Fix: the original divided by cor_num.item() unconditionally and
        # raised ZeroDivisionError when no positive labels had been seen yet.
        recall = pre_cor_num.item() / cor_num.item() if cor_num.item() > 0 else 0.0
        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.4f}, rec: {:.4f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num,
            recall)
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)
        optimizer.step()
        optimizer.zero_grad()
    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
@torch.no_grad()
def evaluate(model, data_loader, device, epoch, qat=False, threshold=0.5):
    """Evaluate `model` and return (mean loss, accuracy, recall).

    qat=True routes through model.qat_forward (integer-pipeline
    simulation); otherwise the regular forward is used. A sample is
    predicted positive when its class-1 score >= `threshold`.
    """
    loss_function = torch.nn.CrossEntropyLoss()
    model.eval()
    accu_num = torch.zeros(1).to(device)      # correctly classified samples
    accu_loss = torch.zeros(1).to(device)     # running loss
    cor_num = torch.zeros(1).to(device)       # ground-truth positives
    pre_cor_num = torch.zeros(1).to(device)   # true positives
    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        labels = labels.to(device)  # move once instead of on every use
        sample_num += images.shape[0]
        if qat:
            pred = model.qat_forward(images.to(device))
        else:
            pred = model(images.to(device))
        pred_classes = pred[:, 1] >= threshold
        accu_num += torch.eq(pred_classes, labels).sum()
        cor_num += (labels == 1).sum()
        pre_cor_num += (pred_classes * labels).sum()
        loss = loss_function(pred, labels)
        accu_loss += loss
        # Fix: guard against ZeroDivisionError when the batches seen so far
        # contain no positives; also correct the copy-pasted "train" label.
        recall = pre_cor_num.item() / cor_num.item() if cor_num.item() > 0 else 0.0
        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.4f}, rec: {:.4f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num,
            recall)

    recall = pre_cor_num.item() / cor_num.item() if cor_num.item() > 0 else 0.0
    return accu_loss.item() / (step + 1), accu_num.item() / sample_num, recall
def tradeoff(model, trainloader, testloader, device):
    """Sweep the decision threshold from 0.5 down toward 0 in 0.05 steps and
    log train/test accuracy and recall for each value to threshold.txt.

    Fix: use a context manager so the log file is closed even if evaluate()
    raises (the original leaked the handle on error).
    """
    with open('threshold.txt', 'w') as print_log:
        for i in np.arange(0.5, 0, -0.05):
            _, acc, rec = evaluate(model, trainloader, device, 'last', False, i)
            print("train   threshold:{},acc:{:.4f},rec:{:.4f}".format(i, acc, rec), file=print_log)
            _, acc, rec = evaluate(model, testloader, device, 'last', False, i)
            print("test   threshold:{},acc:{:.4f},rec:{:.4f}".format(i, acc, rec), file=print_log)
def main(args):
    """QAT fine-tuning driver: load pre-trained weights, train with fake
    quantization, then freeze quantized weights and evaluate the integer
    pipeline.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    model = qat_Cifar10Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    # Cosine learning-rate decay from lr down to lr * lrf over args.epochs.
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # Fix: map_location keeps GPU-saved checkpoints loadable on CPU-only hosts.
    model.load_state_dict(torch.load(args.weights, map_location=device))
    all_dataset = imagecrop_data('/Dataset/imagecrop')
    # 80/20 random train/test split.
    train_size = int(len(all_dataset) * 0.8)
    test_size = len(all_dataset) - train_size
    trainset, testset = torch.utils.data.random_split(all_dataset, [train_size, test_size])
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=nw)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=nw)
    print('训练数据个数:%d,测试数据个数%d'%(len(trainset),len(testset)))
    # Fix: the checkpoint directory was assumed to pre-exist; torch.save
    # fails with FileNotFoundError otherwise.
    os.makedirs('./qat_epoch_pkl', exist_ok=True)
    for epoch in range(args.epochs):
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=trainloader,
                                                device=device,
                                                epoch=epoch)
        scheduler.step()
        val_loss, val_acc, val_rec = evaluate(model=model,
                                              data_loader=testloader,
                                              device=device,
                                              epoch=epoch)
        torch.save(model.state_dict(), "./qat_epoch_pkl/model-{}-{}.pth".format(epoch, args.lr))
    # Freeze the quantized weights, then evaluate the simulated integer path.
    model.save_qat()
    evaluate(model, testloader, device, "final", True)
    torch.save(model, "final_model.pth")
    model.display_qat("q.txt")
if __name__ == '__main__':
    # CLI options; defaults reproduce the original training recipe.
    parser = argparse.ArgumentParser()
    arg_specs = [
        ('--weights', dict(type=str, default='./trained_noqat_dic_all.pkl',
                           help='initial weights path')),
        ('--epochs', dict(type=int, default=1)),
        ('--batch-size', dict(type=int, default=256)),
        ('--lr', dict(type=float, default=0.01)),
        ('--lrf', dict(type=float, default=0.01)),
        ('--device', dict(default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')),
    ]
    for flag, kwargs in arg_specs:
        parser.add_argument(flag, **kwargs)
    main(parser.parse_args())