"""Evaluate a trained LLFormer checkpoint on a paired validation set.

Restores every input image, writes the result as PNG into ``--save_path`` and
prints the mean PSNR / SSIM against the ground-truth targets.
"""
import argparse
import os
from collections import OrderedDict

import cv2
import torch
from skimage import img_as_ubyte
from torch.utils.data import DataLoader

import utils.image_utils as utils
from model.LLFormer import LLFormer
from transform.data_RGB import get_validation_data2


def load_checkpoint(model, weights, map_location=None):
    """Load ``checkpoint["state_dict"]`` from the file *weights* into *model*.

    If the direct load fails, keys carrying a ``module.`` prefix (e.g. from a
    checkpoint saved off a ``DataParallel``-wrapped model) are stripped of that
    prefix and the load is retried.

    Args:
        model: module whose parameters are populated in place.
        weights: path to a checkpoint file containing a ``"state_dict"`` entry.
        map_location: forwarded to ``torch.load``. ``None`` keeps the previous
            behaviour (tensors restored onto the device they were saved from);
            pass the target device to load onto CPU or a different GPU.

    Raises:
        KeyError: the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: the keys still do not match after stripping ``module.``.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(weights, map_location=map_location)
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Strip the prefix only where it is actually present instead of
        # unconditionally dropping the first 7 characters of every key.
        prefix = "module."
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(new_state_dict)


parser = argparse.ArgumentParser("LLFormer")
parser.add_argument('--model', type=str, default='weights/difficult.pt',
                    help='unused; kept for command-line compatibility (see --weights)')
parser.add_argument('--test_dir', type=str,
                    default='/media/mmsys/6f1091c9-4ed8-4a10-a03d-2acef144d2e1/SXY/Data/LOL/LOL-v1/eval15',
                    help='directory containing the evaluation data')
parser.add_argument('--save_path', type=str,
                    default='/media/mmsys/6f1091c9-4ed8-4a10-a03d-2acef144d2e1/SXY/Data/LOL/LOL-v1/eval15/output',
                    help='directory the restored images are written to (created if missing)')
parser.add_argument('--gpu', type=str, default='cuda:0',
                    help='torch device string used for the model and data, e.g. cuda:0 or cpu')
parser.add_argument('--weights', default='checkpoints/LLFormer_LOL/lolv1-models/model_bestPSNR.pth',
                    type=str, help='Path to weights')
args = parser.parse_args()

val_dataset = get_validation_data2(args.test_dir, {'patch_size': 128})
# batch_size=1: main() saves only the first image of each batch and joins the
# batch's names into one file name, so larger batches are not supported.
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False,
                        num_workers=0, drop_last=False)

# Load corresponding lolv1-models architecture and weights
model = LLFormer(inp_channels=3, out_channels=3, dim=16, num_blocks=[2, 4, 8, 16],
                 num_refinement_blocks=2, heads=[1, 2, 4, 8], ffn_expansion_factor=2.66,
                 bias=False, LayerNorm_type='WithBias', attention=True, skip=False)
model.to(args.gpu)
# Map checkpoint tensors onto the chosen device so --gpu cpu / other GPUs work.
load_checkpoint(model, args.weights, map_location=args.gpu)
model.eval()


def save_img(filepath, img):
    """Write an RGB image array to *filepath* (cv2 expects BGR channel order).

    Raises:
        OSError: ``cv2.imwrite`` reported failure (it returns False rather
            than raising, e.g. when the directory does not exist).
    """
    if not cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
        raise OSError(f"cv2.imwrite failed for {filepath!r}")


print('restoring images......')


def main():
    """Restore every validation image, save it, and print mean PSNR / SSIM.

    Each loader item is ``(target, input, names)``. Tensors are assumed to be
    NCHW (the crop and ``permute(0, 2, 3, 1)`` below rely on it) -- TODO confirm
    against ``get_validation_data2``.
    """
    # cv2.imwrite fails silently on a missing directory; make sure it exists.
    os.makedirs(args.save_path, exist_ok=True)
    psnr_val_rgb = []
    ssim_val_rgb = []
    for data_val in val_loader:
        # Use the same device the model was moved to (was hard-coded 'cuda:0').
        target = data_val[0].to(args.gpu)
        input_ = data_val[1].to(args.gpu)
        image_name = data_val[2]
        h, w = target.shape[2], target.shape[3]
        with torch.no_grad():
            restored = model(input_)
        # Crop the output back to the target's spatial size.
        restored = restored[:, :, :h, :w]
        for res, tar in zip(restored, target):
            psnr_val_rgb.append(utils.torchPSNR(res, tar))
            # NOTE(review): SSIM is computed on the whole batch, not per
            # sample; equivalent only because batch_size=1.
            ssim_val_rgb.append(utils.torchSSIM(restored, target))
        restored = torch.clamp(restored, 0, 1)
        restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
        restored = img_as_ubyte(restored[0])
        save_img_name = "".join(image_name) + ".png"
        save_img(os.path.join(args.save_path, save_img_name), restored)
        print('processing', save_img_name)

    # torch.stack([]) raises, so report an empty dataset explicitly.
    if not psnr_val_rgb:
        print('no images found in', args.test_dir)
        return
    psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
    ssim_val_rgb = torch.stack(ssim_val_rgb).mean().item()
    print('avg_psnr', psnr_val_rgb)
    print('avg_ssim', ssim_val_rgb)


if __name__ == '__main__':
    main()