# Published: 2023-05-02 21:20:50 | Author: helloWorldhelloWorld
import torch

from collections import OrderedDict
from transform.data_RGB import get_validation_data2
from torch.utils.data import DataLoader
import os
import argparse
from model.LLFormer import LLFormer
import cv2
from skimage import img_as_ubyte
import utils.image_utils as utils


def load_checkpoint(model, weights, map_location=None):
    """Load ``checkpoint["state_dict"]`` from the file ``weights`` into ``model``.

    The state dict is first loaded unchanged. If the keys do not match (e.g.
    the checkpoint was saved from a wrapper that prefixes every key with
    ``module.``), that prefix is removed from the keys that carry it and the
    load is retried. Keys without the prefix are kept as they are.

    Args:
        model: module whose parameters are overwritten in place.
        weights: path to a file readable by ``torch.load`` holding a dict
            with a ``"state_dict"`` entry.
        map_location: forwarded to ``torch.load``. ``None`` (the default)
            keeps the previous behaviour of restoring tensors onto the
            devices they were saved from.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: if the keys still do not match after the prefix is removed.
    """
    checkpoint = torch.load(weights, map_location=map_location)
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # load_state_dict raises RuntimeError on missing/unexpected keys.
        # Only strip the prefix where it is actually present; blindly
        # slicing off 7 characters would mangle unprefixed keys.
        prefix = "module."
        stripped = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(stripped)

parser = argparse.ArgumentParser("LLFormer")
# NOTE(review): --model is parsed but never read in this script; the network
# weights are loaded from --weights.
parser.add_argument('--model', type=str, default='weights/difficult.pt',
                    help='unused by this script (weights are loaded from --weights)')
parser.add_argument('--test_dir', type=str,
                    default='/media/mmsys/6f1091c9-4ed8-4a10-a03d-2acef144d2e1/SXY/Data/LOL/LOL-v1/eval15',
                    help='directory of the evaluation dataset')
parser.add_argument('--save_path', type=str,
                    default='/media/mmsys/6f1091c9-4ed8-4a10-a03d-2acef144d2e1/SXY/Data/LOL/LOL-v1/eval15/output',
                    help='directory the restored PNG images are written to')
parser.add_argument('--gpu', type=str, default='cuda:0',
                    help='torch device the model is moved to (default: cuda:0)')
parser.add_argument('--weights', type=str,
                    default='checkpoints/LLFormer_LOL/lolv1-models/model_bestPSNR.pth',
                    help='path to the checkpoint file (must contain a "state_dict" entry)')
args = parser.parse_args()

# Evaluation data, iterated one item at a time in file order.
# NOTE(review): the meaning of the {'patch_size': 128} option depends on
# get_validation_data2 — confirm whether it affects evaluation images.
val_dataset = get_validation_data2(args.test_dir, {'patch_size': 128})
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False,
                        num_workers=0, drop_last=False)

# Build the LLFormer architecture matching the lolv1 checkpoints and load its weights.
model = LLFormer(
    inp_channels=3,
    out_channels=3,
    dim=16,
    num_blocks=[2, 4, 8, 16],
    num_refinement_blocks=2,
    heads=[1, 2, 4, 8],
    ffn_expansion_factor=2.66,
    bias=False,
    LayerNorm_type='WithBias',
    attention=True,
    skip=False,
)
model.to(args.gpu)
load_checkpoint(model, args.weights)
model.eval()

def save_img(filepath, img):
    """Write an RGB image array to ``filepath`` using OpenCV.

    OpenCV writes in BGR channel order, so the array is converted from RGB
    first. Presumably ``img`` is an HxWx3 uint8 array — confirm for callers.

    Raises:
        OSError: if ``cv2.imwrite`` reports failure. It signals failure by
            returning ``False`` (e.g. when the target directory does not
            exist) rather than raising, so the failure would otherwise go
            unnoticed.
    """
    if not cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
        raise OSError("cv2.imwrite failed to write {!r}".format(filepath))


print('restoring images......')
def main():
    """Restore every validation image, save it as PNG and print mean PSNR/SSIM.

    Uses the module-level ``val_loader``, ``model`` and ``args``. Each batch
    from ``val_loader`` is indexed as ``(target, input, names)``. Restored
    images are written to ``args.save_path`` as ``<name>.png``.
    """
    # cv2.imwrite does not create directories and fails silently without one.
    os.makedirs(args.save_path, exist_ok=True)

    psnr_val_rgb = []
    ssim_val_rgb = []
    for data_val in val_loader:
        # Send inputs to the device the model was moved to, not a hard-coded one.
        target = data_val[0].to(args.gpu)
        input_ = data_val[1].to(args.gpu)
        image_names = data_val[2]
        if isinstance(image_names, str):
            image_names = [image_names]
        h, w = target.shape[2], target.shape[3]
        with torch.no_grad():
            restored = model(input_)
            # Crop to the target's spatial size (the output may presumably be
            # larger, e.g. from padding — confirm against the model).
            restored = restored[:, :, :h, :w]

        for res, tar in zip(restored, target):
            psnr_val_rgb.append(utils.torchPSNR(res, tar))
            # Score each image individually, like PSNR; re-add the batch dim
            # since the SSIM helper was previously given batched tensors.
            ssim_val_rgb.append(utils.torchSSIM(res.unsqueeze(0), tar.unsqueeze(0)))

        restored = torch.clamp(restored, 0, 1)
        restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
        # Save every image of the batch under its own name.
        for img, name in zip(restored, image_names):
            save_img_name = name + ".png"
            save_img(os.path.join(args.save_path, save_img_name), img_as_ubyte(img))
            print('processing', save_img_name)

    if not psnr_val_rgb:
        # torch.stack([]) would raise; report instead.
        print('no images found in', args.test_dir)
        return

    psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
    ssim_val_rgb = torch.stack(ssim_val_rgb).mean().item()

    print('avg_psnr', psnr_val_rgb)
    print('avg_ssim', ssim_val_rgb)


# Run the evaluation only when executed as a script. Argument parsing, data
# loading and model construction happen at module level and run on import too.
if __name__ == '__main__':
    main()