label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1)) label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1) pred_fft3 = torch.fft.fft2(pred_img, dim=(-2,-1)) pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1) f3 = criterion(pred_fft3, label_fft3) loss_fft = f3