# 以imagecrop为例,二分类,输出准确率与召回率,可调阈值。
# (Example on the imagecrop dataset: binary classification, reports accuracy
#  and recall, with an adjustable decision threshold.)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import torchvision
import torchvision.transforms as transforms
import os
import _pickle as cPickle
import argparse
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
import math
import sys
from PIL import Image
from torch.utils.data import Dataset
from torchsummary import summary
def dequantize(tq, qt):
    """Map a quantized tensor back to real scale: tf ~= tq / qt."""
    scale = qt
    return tq / scale
# Running per-(part, layer) maxima shared by quantize(); each part maps a
# layer index to the largest representable magnitude seen so far.
max_value = {part: {} for part in ('input', 'weight', 'bias', 'output')}
def my_round(x):
    """Round to the nearest integer, sending exact halves toward +infinity.

    torch.round uses banker's rounding (half-to-even); this variant forces
    values of the form k + 0.5 up via ceil instead, matching the behavior
    expected by the fixed-point hardware model.
    """
    up = torch.ceil(x)
    # 1.0 where x is exactly halfway between two integers, else 0.0.
    is_half = (up - x == 0.5).float()
    return up * is_half + torch.round(x) * (1.0 - is_half)
def quantize(tf, signed, part, layer=0, bit=8):
    """Quantize `tf` to `bit` bits using a power-of-two scale factor.

    A running maximum per (part, layer) is kept in the module-global
    `max_value`, so the scale can only widen over successive calls.
    Returns (tq, qt): the clamped, rounded integer-valued tensor and the
    scale it was multiplied by.
    """
    # Fold the new absolute maximum into the running one for this slot.
    tf_max = torch.max(torch.abs(tf))
    if layer in max_value[part]:
        stored = max_value[part][layer]
        if stored > tf_max:
            tf_max = stored
    max_value[part][layer] = tf_max
    top_edge = math.pow(2, bit) - 1
    # Largest representable positive code: symmetric range when signed.
    limit = (top_edge - 1) / 2 if signed else top_edge
    # Snap the scale down to a power of two (shift-friendly in hardware).
    qt = torch.pow(2, torch.floor(torch.log2(limit / tf_max)))
    tq = tf * qt
    # Record the real-valued maximum actually representable at this scale.
    max_value[part][layer] = limit / qt
    lower = -limit if signed else 0
    tq.clamp_(lower, limit)
    return my_round(tq), qt
class Fakequant(Function):
    """Straight-through fake quantization: quantize+dequantize in forward,
    identity gradient in backward.

    Rewritten in the new-style (static method) autograd API: the legacy
    instance-method form is deprecated and raises at runtime in modern
    PyTorch when invoked through ``Fakequant.apply``.
    """

    @staticmethod
    def forward(ctx, tf, signed, part, layer, bit):
        # Simulate quantization error while staying in float.
        tq, qt = quantize(tf, signed, part, layer, bit)
        return dequantize(tq, qt)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: gradient passes unchanged to `tf`;
        # the four config arguments receive no gradient.
        return grad_output, None, None, None, None
class imagecrop_data(Dataset):
    """Binary image dataset laid out as <data_dir>/type (label 1) and
    <data_dir>/no_type (label 0), with files named 1.png, 2.png, ...

    Each item is (image, label): a grayscale tensor reshaped to
    (256, 32, 32) and an int64 scalar label tensor.
    """

    def __init__(self, data_dir):
        super().__init__()
        self.imgpath_1 = os.path.join(data_dir, 'type')      # positive class
        self.imgpath_0 = os.path.join(data_dir, 'no_type')   # negative class
        # Cache the per-class counts once: the original re-ran os.listdir()
        # (a full directory scan) on every __getitem__ and __len__ call.
        # Assumes the directories do not change during training.
        self.len_0 = len(os.listdir(self.imgpath_0))
        self.len_1 = len(os.listdir(self.imgpath_1))
        # Transforms operate on PIL images; ToTensor converts to [0, 1].
        self.transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.Grayscale(num_output_channels=1),
                transforms.ToTensor(),
            ])

    def __getitem__(self, index):
        # Indices [0, len_0) map to the negative class, the rest to positive.
        if index >= self.len_0:
            label = 1
            index_1 = index - self.len_0
            imgpath = self.imgpath_1
        else:
            label = 0
            index_1 = index
            imgpath = self.imgpath_0
        name = str(index_1 + 1) + '.png'  # file names are 1-based
        path = os.path.join(imgpath, name)
        # Context manager guarantees the underlying file handle is closed.
        with Image.open(path) as img:
            image = self.transform(img)
        # NOTE(review): assumes every source image contains 256*32*32 values
        # (e.g. 512x512 grayscale) -- confirm against the dataset on disk.
        image = image.reshape(256, 32, 32)
        label = torch.as_tensor(label, dtype=torch.int64)
        return image, label

    def __len__(self):
        return self.len_0 + self.len_1
class Net(nn.Module):
    """VGG-style CNN for 2-class classification of (N, 256, 32, 32) input.

    Four blocks of (Conv-BN-ReLU x2 + 2x2 max-pool) shrink the 32x32 map
    to 2x2, so the flattened feature size is 256*2*2 = 1024, followed by a
    two-layer classifier head.  Layers are exposed as attributes
    layer1..layer36 so subclasses (qat_Cifar10Net) can walk them
    generically with getattr.
    """
    def __init__(self):
        super(Net,self).__init__()
        # Block 1: 32x32 -> 16x16.
        self.layer1 = nn.Conv2d(256,256,3,1,1)
        self.layer2 = nn.BatchNorm2d(256)
        self.layer3 = nn.ReLU(inplace=True)
        self.layer4 = nn.Conv2d(256,256,3,1,1)
        self.layer5 = nn.BatchNorm2d(256)
        self.layer6 = nn.ReLU(inplace=True)
        self.layer7 = nn.MaxPool2d((2,2))
        # Block 2: 16x16 -> 8x8.
        self.layer8 = nn.Conv2d(256,256,3,1,1)
        self.layer9 = nn.BatchNorm2d(256)
        self.layer10= nn.ReLU(inplace=True)
        self.layer11= nn.Conv2d(256,256,3,1,1)
        self.layer12= nn.BatchNorm2d(256)
        self.layer13= nn.ReLU(inplace=True)
        self.layer14= nn.MaxPool2d((2,2))
        # Block 3: 8x8 -> 4x4.
        self.layer15= nn.Conv2d(256,256,3,1,1)
        self.layer16= nn.BatchNorm2d(256)
        self.layer17= nn.ReLU(inplace=True)
        self.layer18= nn.Conv2d(256,256,3,1,1)
        self.layer19= nn.BatchNorm2d(256)
        self.layer20= nn.ReLU(inplace=True)
        self.layer21= nn.MaxPool2d((2,2))
        # Block 4: 4x4 -> 2x2.
        self.layer22= nn.Conv2d(256,256,3,1,1)
        self.layer23= nn.BatchNorm2d(256)
        self.layer24= nn.ReLU(inplace=True)
        self.layer25= nn.Conv2d(256,256,3,1,1)
        self.layer26= nn.BatchNorm2d(256)
        self.layer27= nn.ReLU(inplace=True)
        self.layer28= nn.MaxPool2d((2,2))
        # Classifier head: 1024 -> 256 -> 2.
        self.layer29= nn.Flatten()
        self.layer30= nn.Linear(1024,256)
        self.layer31= nn.BatchNorm1d(256)
        self.layer32= nn.ReLU(inplace=True)
        self.layer33= nn.Linear(256,2)
        self.layer34= nn.BatchNorm1d(2)
        self.layer35= nn.ReLU(inplace=True)
        # NOTE(review): BN+ReLU+Softmax on the logits feed into
        # CrossEntropyLoss during training, which applies log-softmax
        # itself -- confirm this double normalization is intentional.
        self.layer36= nn.Softmax(dim=1)
        self.layer_number = 36
        # He initialization for the conv layers (ReLU nonlinearity).
        for layer in range(1, self.layer_number+1):
            module = getattr(self, 'layer'+str(layer))
            if isinstance(module, nn.Conv2d):# or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
                #nn.init.xavier_uniform_(module.weight, gain=nn.init.calculate_gain('relu'))
    def forward(self,x):
        """Run x through layer1..layer36 in order; returns an (N, 2) tensor."""
        for layer in range(1,self.layer_number+1):
            module = getattr(self,"layer"+str(layer))
            x = module(x)
        out = x
        return out
class qat_Cifar10Net(Net):
    """Quantization-aware-training wrapper around Net.

    For every Conv2d/Linear layer it tracks power-of-two scale factors for
    the input (qx), weight (qw), bias (qb) and output (qy), using the
    module-global quantize()/Fakequant helpers.  forward() trains with
    fake quantization; save_qat()/qat_forward() switch the model to an
    integer-style inference simulation.
    """
    def __init__(self):
        super(qat_Cifar10Net, self).__init__()
        # Create per-layer scale-factor slots (layer<k>_qw/_qb/_qy/_qx)
        # for each quantizable layer; 0 means "not yet computed".
        for layer in range(1,self.layer_number+1):
            module = getattr(self,"layer"+str(layer))
            if(module.__class__.__name__ in ['Conv2d','Linear']):
                setattr(self,'layer'+str(layer)+'_qw',0)
                setattr(self,'layer'+str(layer)+'_qb',0)
                setattr(self,'layer'+str(layer)+'_qy',0)
                setattr(self,'layer'+str(layer)+'_qx',0)
        # Bit widths for weights, activations (input/output) and bias.
        self.w_bits=6
        self.i_bits=6
        self.b_bits=6
    def adjust_q(self,layer,left_edge,right_edge):
        """Keep the requantization ratio qx*qw/qy of `layer` within
        [left_edge, right_edge] by rescaling the recorded maxima.

        Only max_value is modified; the new qw/qy take effect on the next
        quantize() call for this layer (forward() re-quantizes right after).
        """
        qw=getattr(self,'layer'+str(layer)+'_qw')
        qy=getattr(self,'layer'+str(layer)+'_qy')
        qx=getattr(self,'layer'+str(layer)+'_qx')
        if qx*qw/qy>right_edge:
            # Ratio too large: enlarge the weight range so qw shrinks.
            qw_new = right_edge*qy/qx
            max_value['weight'][layer] = max_value['weight'][layer]*qw/qw_new
        elif qx*qw/qy<left_edge:
            # Ratio too small: enlarge the output range so qy shrinks.
            qy_new = qx*qw/left_edge
            max_value['output'][layer] = max_value['output'][layer]*qy/qy_new
    def forward(self,x):
        """Fake-quantized forward pass used during QAT training.

        Conv/Linear layers run twice: a dry run to measure the output
        scale, adjust_q() to constrain qx*qw/qy to [32, 128], then the
        real pass with the updated scales.  All other layers run as-is.
        """
        for layer in range(1,self.layer_number+1):
            module = getattr(self,"layer"+str(layer))
            if(module.__class__.__name__ in ['Conv2d','Linear']):
                # Input is signed only for layer 1; later inputs follow a
                # ReLU and are non-negative.
                _,x_factor=quantize(x, layer==1, 'input', layer,self.i_bits)
                x = Fakequant.apply(x, layer==1, 'input', layer,self.i_bits)
                setattr(self,'layer'+str(layer)+'_qx',x_factor)
                _,w_factor=quantize(module.weight, True,'weight', layer,self.w_bits)
                setattr(self,'layer'+str(layer)+'_qw',w_factor)
                # Dry run: compute the layer output only to observe qy.
                if(module.__class__.__name__=='Conv2d'):
                    x_copy = F.conv2d(x, Fakequant.apply(module.weight,True,'weight',layer,self.w_bits), \
                        Fakequant.apply(module.bias,True,'bias',layer,self.b_bits), stride=module.stride, \
                        padding=module.padding, dilation=module.dilation)
                else:
                    x_copy = F.linear(x, Fakequant.apply(module.weight, True, 'weight', layer,self.w_bits), \
                        Fakequant.apply(module.bias, True , 'bias', layer,self.b_bits))
                _,y_factor=quantize(x_copy, True, 'output', layer,self.i_bits)
                setattr(self,'layer'+str(layer)+'_qy',y_factor)
                # Constrain the hardware requantization ratio to [32, 128].
                self.adjust_q(layer,32,128)
                # Real pass with the (possibly) adjusted scales.
                _,w_factor=quantize(module.weight, True,'weight', layer,self.w_bits)
                setattr(self,'layer'+str(layer)+'_qw',w_factor)
                if(module.__class__.__name__=='Conv2d'):
                    x = F.conv2d(x, Fakequant.apply(module.weight,True,'weight',layer,self.w_bits), \
                        Fakequant.apply(module.bias,True,'bias',layer,self.b_bits), stride=module.stride, \
                        padding=module.padding, dilation=module.dilation)
                else:
                    x = F.linear(x, Fakequant.apply(module.weight, True, 'weight', layer,self.w_bits), \
                        Fakequant.apply(module.bias, True , 'bias', layer,self.b_bits))
                _,y_factor=quantize(x, True,'output',layer,self.i_bits)
                x = Fakequant.apply(x, True,'output',layer,self.i_bits)
                setattr(self,'layer'+str(layer)+'_qy',y_factor)
            else:
                x = module(x)
        return x
    def save_qat(self):
        """Destructively replace Conv/Linear weights and biases with their
        quantized integer codes, recording qw/qb.  Call once, after
        training and before qat_forward()."""
        for layer in range(1,self.layer_number+1):
            module = getattr(self,"layer"+str(layer))
            if(module.__class__.__name__ in ['Conv2d','Linear']):
                module.weight.data, qw = quantize(module.weight.data, True, 'weight', layer,self.w_bits)
                module.bias.data, qb = quantize(module.bias.data, True, 'bias', layer,self.b_bits)
                setattr(self,'layer'+str(layer)+'_qw',qw)
                setattr(self,'layer'+str(layer)+'_qb',qb)
    def qat_forward(self,x):
        """Integer-style inference simulation; requires save_qat() first.

        Inputs and outputs are scaled by qx/qy, rounded, and rescaled to
        emulate fixed-point arithmetic; stored integer weight codes are
        divided by qw to recover real-valued weights.
        NOTE(review): the bias is rescaled by both 1/qb and qx here --
        presumably matching the hardware accumulator scale; confirm
        against the deployment spec.
        """
        for layer in range(1,self.layer_number+1):
            module = getattr(self,"layer"+str(layer))
            if(module.__class__.__name__ in ['Conv2d','Linear']):
                # Quantize the input to its integer grid, then back to float.
                x = x * getattr(self,'layer'+str(layer)+'_qx')
                x = my_round(x)
                x = x / getattr(self,'layer'+str(layer)+'_qx')
                if(module.__class__.__name__ =='Conv2d'):
                    x = F.conv2d(x, module.weight.data/getattr(self,'layer'+str(layer)+'_qw'), \
                        module.bias.data/getattr(self,'layer'+str(layer)+'_qb')*getattr(self,'layer'+str(layer)+'_qx'), \
                        stride=module.stride, \
                        padding=module.padding, dilation=module.dilation)
                else:
                    x = F.linear(x, module.weight.data/getattr(self,'layer'+str(layer)+'_qw'), \
                        module.bias.data/getattr(self,'layer'+str(layer)+'_qb')*getattr(self,'layer'+str(layer)+'_qx'))
                # Quantize the output the same way.
                x = x * getattr(self,'layer'+str(layer)+'_qy')
                x = my_round(x)
                x = x / getattr(self,'layer'+str(layer)+'_qy')
            else:
                x = module(x)
        out = x
        return out
    def display_qat(self,write_f):
        """Dump per-layer scale factors (qx/qy/qw/qb and the combined
        ratio Q = qw*qx/qy) to the text file `write_f`."""
        print_log = open(write_f,'w')
        for layer in range(1,self.layer_number+1):
            module = getattr(self,"layer"+str(layer))
            if(module.__class__.__name__ in ['Conv2d','Linear']):
                print("=========="+"layer"+str(layer)+"========",file=print_log)
                print("qx:"+str(getattr(self,'layer'+str(layer)+'_qx')),file=print_log)
                print("qy:"+str(getattr(self,'layer'+str(layer)+'_qy')),file=print_log)
                print("qw:"+str(getattr(self,'layer'+str(layer)+'_qw')),file=print_log)
                print("qb:"+str(getattr(self,'layer'+str(layer)+'_qb')),file=print_log)
                print("Q:"+str(getattr(self,'layer'+str(layer)+'_qw')*getattr(self,'layer'+str(layer)+'_qx')/getattr(self,'layer'+str(layer)+'_qy')),file=print_log)
        print_log.close()
    def save_each_layer(self):
        """Save each Conv/Linear weight and bias tensor to
        ./each_layer_pt/ (directory must already exist)."""
        for layer in range(1,self.layer_number+1):
            module = getattr(self,"layer"+str(layer))
            if(module.__class__.__name__ in ['Conv2d','Linear']):
                torch.save(module.weight.data,'./each_layer_pt/layer'+str(layer)+'_weight.pt')
                torch.save(module.bias.data,'./each_layer_pt/layer'+str(layer)+'_bias.pt')
def train_one_epoch(model, optimizer, data_loader, device, epoch, threshold=0.5):
    """Train `model` for one epoch; returns (mean loss, accuracy).

    Predictions are thresholded on the positive-class score pred[:, 1];
    the progress bar shows running loss, accuracy and recall.
    """
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)     # cumulative loss
    accu_num = torch.zeros(1).to(device)      # correctly predicted samples
    cor_num = torch.zeros(1).to(device)       # actual positive samples
    pre_cor_num = torch.zeros(1).to(device)   # true positives
    optimizer.zero_grad()
    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        labels = labels.to(device)  # hoisted: was converted 3x per step
        sample_num += images.shape[0]
        pred = model(images.to(device))
        # Positive prediction when the class-1 score reaches the threshold.
        pred_classes = pred[:, 1] >= threshold
        accu_num += torch.eq(pred_classes, labels).sum()
        cor_num += (labels == 1).sum()
        pre_cor_num += (pred_classes * labels).sum()
        loss = loss_function(pred, labels)
        loss.backward()
        accu_loss += loss.detach()
        # Guard: the original divided by cor_num unconditionally and raised
        # ZeroDivisionError when no positive sample had been seen yet.
        recall = pre_cor_num.item() / cor_num.item() if cor_num.item() > 0 else 0.0
        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.4f}, rec: {:.4f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num,
            recall)
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)
        optimizer.step()
        optimizer.zero_grad()
    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
@torch.no_grad()
def evaluate(model, data_loader, device, epoch, qat=False, threshold=0.5):
    """Evaluate `model`; returns (mean loss, accuracy, recall).

    When `qat` is True, runs the integer-simulation path
    model.qat_forward() instead of the regular forward.
    """
    loss_function = torch.nn.CrossEntropyLoss()
    model.eval()
    accu_num = torch.zeros(1).to(device)      # correctly predicted samples
    accu_loss = torch.zeros(1).to(device)     # cumulative loss
    cor_num = torch.zeros(1).to(device)       # actual positive samples
    pre_cor_num = torch.zeros(1).to(device)   # true positives
    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        labels = labels.to(device)  # hoisted: was converted 3x per step
        sample_num += images.shape[0]
        if qat:
            images = images.to(device)
            pred = model.qat_forward(images)
        else:
            pred = model(images.to(device))
        pred_classes = pred[:, 1] >= threshold
        accu_num += torch.eq(pred_classes, labels).sum()
        cor_num += (labels == 1).sum()
        pre_cor_num += (pred_classes * labels).sum()
        loss = loss_function(pred, labels)
        accu_loss += loss
        # Guard: avoid ZeroDivisionError before any positive label is seen.
        recall = pre_cor_num.item() / cor_num.item() if cor_num.item() > 0 else 0.0
        # Fixed copy-paste bug: this loop is validation, not training.
        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.4f}, rec: {:.4f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num,
            recall)
    recall = pre_cor_num.item() / cor_num.item() if cor_num.item() > 0 else 0.0
    return accu_loss.item() / (step + 1), accu_num.item() / sample_num, recall
def tradeoff(model, trainloader, testloader, device):
    """Sweep the decision threshold from 0.5 down toward 0 in 0.05 steps,
    logging train/test accuracy and recall to threshold.txt.
    """
    # `with` guarantees the log file is closed even if evaluate() raises
    # (the original left the handle open on error).
    with open('threshold.txt', 'w') as print_log:
        for i in np.arange(0.5, 0, -0.05):
            _, acc, rec = evaluate(model, trainloader, device, 'last', False, i)
            print("train threshold:{},acc:{:.4f},rec:{:.4f}".format(i, acc, rec), file=print_log)
            _, acc, rec = evaluate(model, testloader, device, 'last', False, i)
            print("test threshold:{},acc:{:.4f},rec:{:.4f}".format(i, acc, rec), file=print_log)
def main(args):
    """QAT fine-tuning entry point: load pretrained weights, train with
    fake quantization, then freeze (save_qat) and evaluate the integer
    simulation."""
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    model = qat_Cifar10Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    # Cosine learning-rate decay from lr down to lr*lrf over args.epochs.
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    model.load_state_dict(torch.load(args.weights))
    all_dataset = imagecrop_data('/Dataset/imagecrop')
    # 80/20 random train/test split.
    train_size = int(len(all_dataset) * 0.8)
    test_size = len(all_dataset) - train_size
    trainset, testset = torch.utils.data.random_split(all_dataset, [train_size, test_size])
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=nw)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=nw)
    print('训练数据个数:%d,测试数据个数%d'%(len(trainset),len(testset)))
    # Robustness fix: torch.save below failed when the checkpoint
    # directory did not already exist.
    os.makedirs('./qat_epoch_pkl', exist_ok=True)
    for epoch in range(args.epochs):
        # Unpack the returns explicitly (they were previously bound whole
        # to a single misleadingly named variable).
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=trainloader,
                                                device=device,
                                                epoch=epoch)
        scheduler.step()
        val_loss, val_acc, val_rec = evaluate(model=model,
                                              data_loader=testloader,
                                              device=device,
                                              epoch=epoch)
        torch.save(model.state_dict(), "./qat_epoch_pkl/model-{}-{}.pth".format(epoch, args.lr))
    # Freeze quantized weights, then check the integer-simulation accuracy.
    model.save_qat()
    evaluate(model, testloader, device, "final", True)
    torch.save(model, "final_model.pth")
    model.display_qat("q.txt")
if __name__ == '__main__':
    # Command-line interface for the QAT fine-tuning run.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--weights', type=str, default='./trained_noqat_dic_all.pkl',
                            help='initial weights path')
    arg_parser.add_argument('--epochs', type=int, default=1)
    arg_parser.add_argument('--batch-size', type=int, default=256)
    arg_parser.add_argument('--lr', type=float, default=0.01)
    arg_parser.add_argument('--lrf', type=float, default=0.01)
    arg_parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
    main(arg_parser.parse_args())