"""Swap/blend Fourier amplitude between two images while keeping the phase.

Reads two RGB images and combines them in frequency space. The amplitude
spectrum of image 1 is replaced by a blend of both images' amplitudes; image 1's
phase is kept. The result is shown next to the two source images.
"""
import os
from math import sqrt  # noqa: F401  (kept from the original file; currently unused)

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torch import nn  # noqa: F401  (kept from the original file; currently unused)
from torchvision import transforms  # noqa: F401  (kept; currently unused)
from torchvision import utils as vutils  # noqa: F401  (kept; currently unused)

# Tolerate duplicate Intel OpenMP runtimes being loaded in one process
# (otherwise the process may abort with "OMP: Error #15").
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def mix_amplitude(src: torch.Tensor, ref: torch.Tensor, lam: float = 1.0) -> torch.Tensor:
    """Blend the Fourier amplitude of *ref* into *src*, keeping *src*'s phase.

    The FFT is taken over the last two (spatial) axes, so ``(C, H, W)`` tensors
    as produced by ``TF.to_tensor`` work, as do batched ``(N, C, H, W)`` inputs.
    The new amplitude is ``lam * |F(ref)| + (1 - lam) * |F(src)|``. It is
    recombined with the phase of ``F(src)`` and transformed back.

    Args:
        src: Image tensor whose phase (structure) is preserved.
        ref: Image tensor whose amplitude (style/colour statistics) is injected.
            Must have exactly the same shape as *src*.
        lam: Blend weight; 1.0 takes the amplitude entirely from *ref*, 0.0
            reproduces *src* (up to FFT round-off).

    Returns:
        Real tensor with the shape of *src*, clamped to ``[0, 1]``.

    Raises:
        ValueError: if *src* and *ref* have different shapes.
    """
    if src.shape != ref.shape:
        raise ValueError(
            f"src and ref must have the same shape, got {tuple(src.shape)} "
            f"and {tuple(ref.shape)}"
        )
    spatial = (-2, -1)
    src_fft = torch.fft.fftn(src, dim=spatial)
    ref_fft = torch.fft.fftn(ref, dim=spatial)
    # The blend is elementwise, so no fftshift/ifftshift round-trip is needed.
    # A centred spectrum only matters for region-based (e.g. low-frequency-only)
    # mixing.
    amplitude = lam * torch.abs(ref_fft) + (1 - lam) * torch.abs(src_fft)
    mixed = amplitude * torch.exp(1j * torch.angle(src_fft))
    recon = torch.real(torch.fft.ifftn(mixed, dim=spatial))
    return torch.clamp(recon, 0, 1)


def to_uint8_image(tensor: torch.Tensor) -> np.ndarray:
    """Convert a ``(C, H, W)`` tensor in ``[0, 1]`` to an ``(H, W, C)`` uint8 array."""
    # Round rather than truncate so values are not biased downward by up to 1.
    return torch.round(tensor * 255.0).permute(1, 2, 0).numpy().astype(np.uint8)


def _load_rgb(path: str) -> Image.Image:
    """Open *path* and return an RGB copy, closing the underlying file."""
    with Image.open(path) as im:
        return im.convert('RGB')


def main(src_path: str = 'img/00001.png', ref_path: str = 'img/5.png', lam: float = 1.0) -> None:
    """Mix the amplitude of *ref_path* into *src_path* and display the result.

    Args:
        src_path: Image whose phase is kept (shown as "Image 1").
        ref_path: Image whose amplitude is injected (shown as "Image 2").
        lam: Amplitude blend weight passed to :func:`mix_amplitude`; e.g.
            ``np.random.uniform(0, 1.0)`` for random-strength augmentation.
    """
    # Read both images.
    img1 = _load_rgb(src_path)
    img2 = _load_rgb(ref_path)
    if img2.size != img1.size:
        # The amplitude blend is elementwise, so both spectra must have the same
        # H x W. PIL's size is (width, height), which is what resize() expects.
        img2 = img2.resize(img1.size)

    recon = to_uint8_image(mix_amplitude(TF.to_tensor(img1), TF.to_tensor(img2), lam))

    # Show the original images and the reconstructed image.
    panels = (
        (221, img1, 'Image 1'),
        (222, img2, 'Image 2'),
        (223, recon, 'Reconstruct Image 1'),
    )
    for position, image, title in panels:
        plt.subplot(position)
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')
    # To save the figure, call plt.savefig('mix', bbox_inches='tight') here;
    # after plt.show() the figure is usually already closed/empty.
    plt.show()


if __name__ == '__main__':
    main()