# 222
#
# Published: 2023-12-27 19:09:03  Author: helloWorldhelloWorld
import argparse
import logging
import os.path
import sys
import time
from collections import OrderedDict
import torchvision.utils as tvutils

import numpy as np
import torch
from IPython import embed
import lpips
from torchvision import utils as vutils

import options as option
from models import create_model

# sys.path.insert(0, "../../")
import utils as util
from data import create_dataloader, create_dataset
from data.util import bgr2ycbcr
from utils.metrics import *

#### options
# Command line: a single "-opt" flag pointing at the YAML test configuration.
parser = argparse.ArgumentParser()
parser.add_argument("-opt", type=str, help="Path to options YAML file.", default='options/test/refusion.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)

# Convert the parsed options so that lookups of missing keys presumably yield
# None instead of raising KeyError (the rest of this script relies on that,
# e.g. for opt["savepath"]) — confirm against option.dict_to_nonedict.
opt = option.dict_to_nonedict(opt)

#### mkdir and logger
# Create the folders listed under opt["path"]. The experiments root and any
# pretrained-model / resume entries are skipped (presumably because they are
# inputs or already-existing locations — confirm against the YAML config).
_dirs_to_create = (
    dir_path
    for dir_key, dir_path in opt["path"].items()
    if dir_key != "experiments_root"
    and not ("pretrain_model" in dir_key or "resume" in dir_key)
)
util.mkdirs(_dirs_to_create)

# Configure the shared "base" logger at INFO level. Per the flags it writes to
# the screen and to a file under opt["path"]["log"]; the exact file naming
# (based on "test_" + experiment name) is decided by util.setup_logger.
log_root = opt["path"]["log"]
log_phase = "test_" + opt["name"]
util.setup_logger("base", log_root, log_phase, level=logging.INFO, screen=True, tofile=True)
logger = logging.getLogger("base")
# Record the fully resolved configuration at the start of the log.
logger.info(option.dict2str(opt))

#### Create test dataset and dataloader
def _build_test_loader(ds_opt):
    """Create the dataset and dataloader for one test phase and log its size.

    Args:
        ds_opt: the per-phase entry from opt["datasets"].

    Returns:
        The dataloader produced by create_dataloader for that dataset.
    """
    ds = create_dataset(ds_opt)
    loader = create_dataloader(ds, ds_opt)
    logger.info("Number of test images in [{:s}]: {:d}".format(ds_opt["name"], len(ds)))
    return loader


# Phases are visited in sorted key order, so the loader order is deterministic.
test_loaders = [_build_test_loader(ds_opt) for _, ds_opt in sorted(opt["datasets"].items())]

# load pretrained model by default
model = create_model(opt)
device = model.device

# Build the IR-SDE sampler from the "sde" config section and attach the
# network (model.model) it uses during sampling.
sde_opt = opt["sde"]
sde = util.IRSDE(
    max_sigma=sde_opt["max_sigma"],
    T=sde_opt["T"],
    schedule=sde_opt["schedule"],
    eps=sde_opt["eps"],
    device=device,
)
sde.set_model(model.model)

# NOTE(review): lpips_fn is not referenced anywhere else in this script; it is
# kept so nothing else changes, but loading the AlexNet LPIPS weights costs
# time and device memory. Drop it or add an LPIPS metric.
lpips_fn = lpips.LPIPS(net='alex').to(device)

# `scale` is not used below either. Read it defensively: if `opt` returns None
# for missing keys (as option.dict_to_nonedict presumably arranges), an absent
# "degradation" section would make `opt['degradation']['scale']` raise
# TypeError (None['scale']) and abort the whole evaluation.
scale = (opt['degradation'] or {}).get('scale')

# Evaluate every test set: restore each LQ image with the SDE sampler,
# optionally save the output, and score it against GT with SSIM/PSNR (from
# utils.metrics) when ground truth is available.
for test_loader in test_loaders:
    test_set_name = test_loader.dataset.opt["name"]
    logger.info("\nTesting [{:s}]...".format(test_set_name))
    test_start_time = time.time()
    dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name)
    util.mkdir(dataset_dir)

    # Whether ground truth exists is a per-dataset property; decide it once.
    need_GT = test_loader.dataset.opt["dataroot_GT"] is not None

    # Saved images go to opt["savepath"] when configured, otherwise to this
    # dataset's results folder. The folder is created up front: a missing
    # savepath (None) used to crash os.path.join, and a non-existent one made
    # save_image fail.
    save_img = opt['save_img']
    save_dir = None
    if save_img:
        save_dir = opt["savepath"] or dataset_dir
        util.mkdir(save_dir)

    ssims_1, psnrs_1 = [], []

    for i, test_data in enumerate(test_loader):
        img_path = test_data["GT_path"][0] if need_GT else test_data["LQ_path"][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        #### input dataset_LQ
        LQ = test_data["LQ"]
        # GT may be absent for datasets without ground truth.
        # NOTE(review): confirm model.feed_data accepts GT=None in that case.
        GT = test_data.get("GT")
        # The sampler starts from a noised version of the LQ input.
        noisy_state = sde.noise_state(LQ)

        model.feed_data(noisy_state, LQ, GT)
        # The restored image is read back through get_current_visuals() below.
        model.test(sde, save_states=True)

        visuals = model.get_current_visuals()
        SR_img = visuals["Output"]

        if save_img:
            save_img_path = os.path.join(save_dir, img_name + ".png")
            print(save_img_path)
            # NOTE(review): normalize=True min-max rescales each image before
            # saving, which stretches contrast. If the output is already in
            # [0, 1], normalize=False would save it faithfully — confirm.
            vutils.save_image(SR_img.float(), save_img_path, normalize=True)

        if not need_GT:
            # No reference image: nothing to score against.
            print(f'\n {img_name} iter processing:{i + 1}', end='', flush=True)
            continue

        GT = visuals["GT"]
        per_ssim_1 = ssim(SR_img, GT).item()
        per_psnr_1 = psnr(SR_img, GT)
        ssims_1.append(per_ssim_1)
        psnrs_1.append(per_psnr_1)
        print(f'\n {img_name} iter processing:{i + 1}   psnr:{per_psnr_1:.4f}  ssim:{per_ssim_1:.4f}', end='', flush=True)

    elapsed = time.time() - test_start_time
    if psnrs_1:
        avg_ssim_1 = np.mean(ssims_1)
        avg_psnr_1 = np.mean(psnrs_1)
        print(f'\navg_psnr:{avg_psnr_1:.4f}  avg_ssim:{avg_ssim_1:.4f} ', flush=True)
        # Also persist the summary to the log file configured above.
        logger.info(
            "[%s] avg_psnr: %.4f  avg_ssim: %.4f  (%d images, %.1fs)",
            test_set_name, avg_psnr_1, avg_ssim_1, len(psnrs_1), elapsed,
        )
    else:
        # np.mean([]) would warn and report nan; say explicitly that nothing was scored.
        print(flush=True)
        logger.info("[%s] no images were scored (%.1fs)", test_set_name, elapsed)