# 11111111111
#
# Published: 2023-05-25 15:22:30 | Author: helloWorldhelloWorld
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from math import sqrt
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"


# Load the two input images.
# Both are converted to 3-channel RGB so the tensors built from them share
# a channel count no matter how each PNG is stored (palette, greyscale,
# RGBA); mismatched channels would break the element-wise amplitude mix.
# convert() returns an independent, fully loaded copy, so the file handles
# can be released as soon as each `with` block exits.
with Image.open('img/low/1.png') as _src:
    img1 = _src.convert('RGB')
with Image.open('img/low/22.png') as _src:
    img2 = _src.convert('RGB')
# Convert to tensors in [N, C, H, W] layout
# transform = transforms.Compose([
#     transforms.Resize((256, 256)),
#     transforms.CenterCrop((224, 224)),
#     transforms.ToTensor()
# ])

# if img1.size != img2.size:
#     new_size = min(img1.size, img2.size)
#     transform = transforms.Compose([
#         transforms.Resize(new_size),
#         transforms.CenterCrop((224, 224)),
#         transforms.ToTensor()
#     ])
# else:
#     transform = transforms.Compose([
#         transforms.Resize((256, 256)),
#         transforms.CenterCrop((224, 224)),
#         transforms.ToTensor()
#     ])
#
#
# img1 = transform(img1).unsqueeze(0)  # 添加批次维(N=1)
# img2 = transform(img2).unsqueeze(0)  # 添加批次维(N=1)

# Choose the resize target.
# When the images differ in size, both are resized to the smaller one
# (tuple comparison, i.e. the image with the smaller width wins); otherwise
# a fixed 256x256 is used.
# PIL reports size as (width, height) while torchvision's Resize interprets
# a sequence as (height, width), so the tuple has to be swapped before use.
if img1.size != img2.size:
    width, height = min(img1.size, img2.size)
    new_size = (height, width)
else:
    new_size = (256, 256)

# One shared pipeline: resize, centre-crop to 224x224, convert to a tensor.
transform = transforms.Compose([
    transforms.Resize(new_size),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
])

# unsqueeze(0) prepends the batch axis (N=1).
img1 = transform(img1).unsqueeze(0)
img2 = transform(img2).unsqueeze(0)

# assert img1.size() == img2.size()
# _, c, h, w = img1.size()
# h_crop = int(h * sqrt(1.0))
# w_crop = int(w * sqrt(1.0))
# print(h_crop)
# print(w_crop)
# h_start = h // 2 - h_crop // 2
# print(h_start)
# w_start = w // 2 - w_crop // 2
# print(w_start)

# Blend weight for the amplitude spectra: with lam = 1 image 1 takes
# image 2's amplitude entirely while keeping its own phase.
lam = 1  # np.random.uniform(0, 1.0)

# The 2-D transforms run over axes 2 and 3 of the batched tensors.
spatial_dims = [2, 3]
img1_fft = torch.fft.fft2(img1, dim=spatial_dims)
img2_fft = torch.fft.fft2(img2, dim=spatial_dims)

# Split each spectrum into amplitude and phase.
img1_abs = torch.abs(img1_fft)
img1_pha = torch.angle(img1_fft)
img2_abs = torch.abs(img2_fft)
img2_pha = torch.angle(img2_fft)

# Centre the zero-frequency bin before blending.  The original kept a
# disabled variant that blended only a central window of the shifted
# spectra (and a symmetric update of img2_abs); with the whole-spectrum
# blend below, the shift / inverse-shift pair is just a round trip.
img1_abs = torch.fft.fftshift(img1_abs, dim=spatial_dims)
img2_abs = torch.fft.fftshift(img2_abs, dim=spatial_dims)

img1_abs = lam * img2_abs + (1 - lam) * img1_abs

# Undo the shift so the amplitude lines up with the unshifted phase again.
img1_abs = torch.fft.ifftshift(img1_abs, dim=spatial_dims)

# Recombine the blended amplitude with image 1's phase and go back to the
# spatial domain, keeping only the real part.
mixed_spectrum = img1_abs * torch.exp(1j * img1_pha)
img21 = torch.fft.ifft2(mixed_spectrum, dim=spatial_dims).real

# Clamp to [0, 1], scale to [0, 255], drop the batch axis and move the
# channel axis last so the result is an H x W x C uint8 array.
img21 = torch.clamp(img21, 0, 1) * 255.0
img21 = img21.squeeze(0).permute(1, 2, 0)
img21 = img21.numpy().astype(np.uint8)
# Show the two original images and the reconstructed image in a 2x2 grid
# (the fourth cell is left empty).
plt.subplot(221)
plt.imshow(img1[0].permute(1, 2, 0))  # channel axis moved last for imshow
plt.title('Original Image 1')
plt.axis('off')

plt.subplot(222)
plt.imshow(img2[0].permute(1, 2, 0))
plt.title('Original Image 2')
plt.axis('off')

plt.subplot(223)
plt.imshow(img21)
plt.title('Reconstruct Image 1')
plt.axis('off')

# Save BEFORE show(): once a blocking show() returns, the figure has been
# closed and unregistered from pyplot, so a later savefig() would write a
# new, empty figure instead of the plot.
plt.savefig('mix', bbox_inches='tight')
plt.show()