2

发布时间 2023-06-09 09:04:52作者: helloWorldhelloWorld
label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1))
            label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1)
            pred_fft3 = torch.fft.fft2(pred_img, dim=(-2,-1))
            pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1)
            f3 = criterion(pred_fft3, label_fft3)
            loss_fft = f3