def getHighLowFre(image): f = torch.fft.fft2(image) # 计算频率 freqs = torch.fft.fftfreq(image.shape[-1]) # print(freqs) # 设定阈值,用于分离高频和低频信息 threshold = 0.1 # 创建掩码,用于分离高频和低频信息 mask = (freqs.abs() < threshold).float().to(args.device) print(mask) # 应用掩码,分离高频和低频信息 low_freq = torch.fft.ifft2(f * mask) print(low_freq) high_freq = image - low_freq print(high_freq) return high_freq, low_freq