111

发布时间 2023-05-28 10:01:27 作者: helloWorldhelloWorld
import os,argparse
import numpy as np
from PIL import Image
from models import *
from metrics import psnr,ssim
import torch
import torch.nn as nn
import torchvision.transforms as tfs 
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from data_utils import test_Dataset
# Current working directory with a trailing '/'.
# NOTE(review): this rebinds the name `abs`, hiding the builtin abs() for the
# rest of this module. Nothing in the visible part of the file reads it —
# consider renaming (e.g. to a *_dir name) once callers are confirmed.
abs=os.getcwd()+'/'
def tensorShow(tensors,titles=('haze',)):
    """Display image tensors side by side in a single matplotlib figure.

    Every tensor is turned into an image grid with ``make_grid`` and drawn
    in its own subplot under the matching title; ``plt.show()`` is called
    once all subplots are filled.

    Args:
        tensors: Iterable of image tensors accepted by ``make_grid``.
        titles: One title per tensor. Tensors and titles are paired with
            ``zip``, so surplus entries on either side are ignored.
    """
    pairs = list(zip(tensors, titles))
    # Two columns and at least two rows: up to four images keep the 2x2
    # layout, and further rows are added instead of failing on a fifth image.
    rows = max(2, (len(pairs) + 1) // 2)
    fig = plt.figure()
    for position, (tensor, tit) in enumerate(pairs, start=1):
        img = make_grid(tensor)
        # .numpy() raises for CUDA tensors and for tensors that require
        # grad, so detach and copy to host memory first.
        npimg = img.detach().cpu().numpy()
        ax = fig.add_subplot(rows, 2, position)
        # Move axis 0 (channels in make_grid's output) to the end for imshow.
        ax.imshow(np.transpose(npimg, (1, 2, 0)))
        ax.set_title(tit)
    plt.show()

# Command-line options for the evaluation script.
parser=argparse.ArgumentParser()
parser.add_argument('--task',type=str,default='its',help='its or ots')
# The flag was previously registered as '--xing ' (with a trailing space),
# which produced an attribute name containing a space and a flag that could
# only be typed via quoting or prefix abbreviation. Register it cleanly so it
# is reachable as `--xing` on the command line and as `opt.xing` in code.
parser.add_argument('--xing',type=str,default='test_imgs',help='Test imgs folder')
opt=parser.parse_args()
dataset=opt.task
# Constructor arguments passed to the FFA network when it is built below.
gps=3
blocks=19

# Evaluation data, one image per batch. Shuffling is disabled so that the
# per-image log comes out in the same order on every run.
test_loader=DataLoader(dataset=test_Dataset('/media/mmsys/6f1091c9-4ed8-4a10-a03d-2acef144d2e1/SXY/Data/LOL/LOL-v1/test/'),batch_size=1,shuffle=False)


output_dir='/media/mmsys/6f1091c9-4ed8-4a10-a03d-2acef144d2e1/SXY/Data/LOL/LOL-v1/output_FFA/'

# makedirs(exist_ok=True) also creates missing parent directories and does
# not fail if the directory already exists.
os.makedirs(output_dir, exist_ok=True)
model_dir='trained_models/v1_FFA/model_bestPSNR.pth'
# Keep the second GPU as the preferred device, but fall back to the first GPU
# or the CPU so that loading the checkpoint does not crash on other machines.
if torch.cuda.device_count() > 1:
    device='cuda:1'
elif torch.cuda.is_available():
    device='cuda:0'
else:
    device='cpu'
ckp=torch.load(model_dir,map_location=device)
net=FFA(gps=gps,blocks=blocks)
# NOTE: if the checkpoint was saved from an nn.DataParallel model, its keys
# carry a 'module.' prefix; wrap `net` in nn.DataParallel before loading then.
# The weights are read from the 'model' entry of the checkpoint.
net.load_state_dict(ckp['model'])
# map_location only places the checkpoint tensors; the network itself has to
# be moved explicitly, otherwise inference silently runs on the CPU.
net.to(device)
net.eval()
ssims=[]
psnrs=[]
with torch.no_grad():
    for data in test_loader:
        # Each batch unpacks into (input image, reference image, file names).
        inputs,target,img_name = data
        # Run the network on `device`, then bring the result back to the CPU
        # so metrics and saving operate on host tensors next to `target`.
        pred = net(inputs.to(device)).cpu()
        ssim1 = ssim(pred, target).item()
        psnr1 = psnr(pred, target)
        print(f'image:{img_name[0]} |ssim:{ssim1:.4f}| psnr:{psnr1:.4f}')
        ssims.append(ssim1)
        psnrs.append(psnr1)
        # NOTE(review): normalize=True rescales every saved image to its own
        # min/max range, so the files on disk can differ in brightness and
        # contrast from the tensors the metrics were computed on — confirm
        # this is intended (the alternative is saving pred.clamp(0, 1)).
        vutils.save_image(pred, os.path.join(output_dir, img_name[0]), normalize=True)

print(f'\nssim:{np.mean(ssims):.4f}| psnr:{np.mean(psnrs):.4f}')