2222222222222

# Published: 2023-05-25 15:57:49 | Author: helloWorldhelloWorld
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms.functional as TF
from torchvision import utils as vutils

from torchvision import transforms
from math import sqrt
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"


# 读取两张图像
img1 = Image.open('img/00001.png').convert('RGB')
img2 = Image.open('img/5.png').convert('RGB')

inp_img = TF.to_tensor(img1)
cro_img = TF.to_tensor(img2)
# # fusion inp and cro by fft
# lam = 1
# inp_img_fft = torch.fft.fftn(inp_img, dim=[1, 2])
# cro_img_fft = torch.fft.fftn(cro_img, dim=[1, 2])
# inp_img_fft_abs, inp_img_fft_pha = torch.abs(inp_img_fft), torch.angle(inp_img_fft)
# cro_img_fft_abs, cro_img_fft_pha = torch.abs(cro_img_fft), torch.angle(cro_img_fft)
# inp_img_fft_abs = torch.fft.fftshift(inp_img_fft_abs, dim=[1, 2])
# cro_img_fft_abs = torch.fft.fftshift(cro_img_fft_abs, dim=[1, 2])
#
# inp_img_fft_abs = lam * cro_img_fft_abs + (1 - lam) * inp_img_fft_abs
#
# inp_img_fft_abs = torch.fft.ifftn(inp_img_fft_abs, dim=[1, 2])
# inp_cor_img = inp_img_fft_abs * (torch.exp(1j * inp_img_fft_pha))
# inp_cor_img = torch.real(torch.fft.ifftn(inp_cor_img, dim=[1, 2]))
# inp_cor_img = torch.clamp(inp_cor_img, 0, 1) * 255.0
# print(inp_cor_img.shape)
# vutils.save_image(inp_cor_img, '2-inputs1.jpg', normalize=True)

lam = 1 # np.random.uniform(0, 1.0)
img1_fft = torch.fft.fftn(inp_img, dim=[1, 2])
img2_fft = torch.fft.fftn(cro_img, dim=[1, 2])
img1_abs, img1_pha = torch.abs(img1_fft), torch.angle(img1_fft)
img2_abs, img2_pha = torch.abs(img2_fft), torch.angle(img2_fft)
img1_abs = torch.fft.fftshift(img1_abs, dim=[1, 2])
img2_abs = torch.fft.fftshift(img2_abs, dim=[1, 2])
img1_abs = lam * img2_abs + (1 - lam) * img1_abs

img1_abs = torch.fft.ifftshift(img1_abs, dim=[1, 2])
img21 = img1_abs * (torch.exp(1j * img1_pha))
img21 = torch.real(torch.fft.ifftn(img21, dim=[1, 2]))

img21 = torch.clamp(img21, 0, 1) * 255.0
img21 = img21.permute(1, 2, 0).numpy().astype(np.uint8)
# 展示原始图像和重构图像

plt.subplot(223), plt.imshow(img21), plt.title('Reconstruct Image 1')
plt.axis('off')

plt.show()
# plt.savefig('mix', bbox_inches='tight')