"""Evaluate a trained FFA network on a paired test set.

For every sample yielded by ``test_Dataset`` the script runs the network,
prints the per-image SSIM and PSNR against the target, saves the prediction
to the output folder, and finally prints the mean SSIM/PSNR over the set.
"""
# NOTE: the import order is kept as in the original file on purpose:
# `from models import *` may bind arbitrary names, and the imports that
# follow it must keep taking precedence over anything it pulls in.
import os
import argparse
import numpy as np
from PIL import Image
from models import *
from metrics import psnr,ssim
import torch
import torch.nn as nn
import torchvision.transforms as tfs
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from data_utils import test_Dataset

# NOTE(review): this module-level name shadows the builtin abs(). It is not
# used by the code below; it is kept only so that anything relying on the
# name keeps working. Do not call abs() as a function in this module.
abs = os.getcwd() + '/'

# Defaults reproduce the paths this script was originally hard-wired to.
DEFAULT_DATA_DIR = '/media/mmsys/6f1091c9-4ed8-4a10-a03d-2acef144d2e1/SXY/Data/LOL/LOL-v1/test/'
DEFAULT_OUTPUT_DIR = '/media/mmsys/6f1091c9-4ed8-4a10-a03d-2acef144d2e1/SXY/Data/LOL/LOL-v1/output_FFA/'
DEFAULT_MODEL_PATH = 'trained_models/v1_FFA/model_bestPSNR.pth'


def tensorShow(tensors, titles=('haze',)):
    """Display image tensors side by side in a 2x2 matplotlib figure.

    Args:
        tensors: sequence of CPU image tensors accepted by ``make_grid``.
        titles: one title per tensor; tensors without a matching title
            (and titles without a matching tensor) are ignored.

    The subplot index is ``221 + i``, so at most four panels fit in the
    figure. Blocks until the matplotlib window is closed (``plt.show()``).
    """
    fig = plt.figure()
    for i, (tensor, tit) in enumerate(zip(tensors, titles)):
        img = make_grid(tensor)
        npimg = img.numpy()
        ax = fig.add_subplot(221 + i)
        # make_grid returns channels-first data; imshow wants channels-last.
        ax.imshow(np.transpose(npimg, (1, 2, 0)))
        ax.set_title(tit)
    plt.show()


parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='its', help='its or ots')
# The option string used to be '--xing ' (trailing space), which produced an
# attribute name that cannot be read as `opt.xing`. `--xing value` on the
# command line parses exactly as before.
parser.add_argument('--xing', type=str, default='test_imgs', help='Test imgs folder')
parser.add_argument('--data_dir', type=str, default=DEFAULT_DATA_DIR,
                    help='Folder handed to test_Dataset')
parser.add_argument('--output_dir', type=str, default=DEFAULT_OUTPUT_DIR,
                    help='Folder the predicted images are written to')
parser.add_argument('--model_path', type=str, default=DEFAULT_MODEL_PATH,
                    help='Checkpoint file holding the weights under the "model" key')
opt = parser.parse_args()

dataset = opt.task
gps = 3
blocks = 19

# os.path.join(path, '') guarantees a trailing separator, which is the form
# the original hard-coded path had when it was passed to test_Dataset.
data_dir = os.path.join(opt.data_dir, '')
# shuffle=False: evaluation order does not affect the metrics, and a fixed
# order keeps the per-image log reproducible between runs.
test_loader = DataLoader(dataset=test_Dataset(data_dir), batch_size=1, shuffle=False)

output_dir = opt.output_dir
# Creates missing parent folders too and is safe if the folder already exists.
os.makedirs(output_dir, exist_ok=True)

model_dir = opt.model_path

# Prefer the second GPU (the original hard-coded choice), but fall back so the
# checkpoint can still be loaded on single-GPU and CPU-only machines.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

ckp = torch.load(model_dir, map_location=device)
net = FFA(gps=gps, blocks=blocks)
# The weights are loaded into the bare model (no nn.DataParallel wrapper), so
# the checkpoint keys presumably carry no 'module.' prefix -- verify if
# load_state_dict reports missing/unexpected keys.
net.load_state_dict(ckp['model'])
# Previously the model stayed on the CPU even though `device` named a GPU.
net.to(device)
net.eval()

ssims = []
psnrs = []
with torch.no_grad():
    for low, target, img_name in test_loader:
        # Run on `device`, then bring the result back to the CPU so the
        # metrics and save_image see CPU tensors, as they did originally.
        # Assumes the dataset yields tensors for input and target.
        pred = net(low.to(device)).cpu()
        ssim1 = ssim(pred, target).item()
        psnr1 = psnr(pred, target)
        print(f'image:{img_name[0]} |ssim:{ssim1:.4f}| psnr:{psnr1:.4f}')
        ssims.append(ssim1)
        psnrs.append(psnr1)
        # NOTE(review): normalize=True min-max rescales the image before it
        # is written, so the saved file is not pixel-identical to the tensor
        # the metrics above were computed on -- confirm this is intended.
        vutils.save_image(pred, os.path.join(output_dir, img_name[0]), normalize=True)

print(f'\nssim:{np.mean(ssims):.4f}| psnr:{np.mean(psnrs):.4f}')