class PSNR(nn.Module):
    """Peak signal-to-noise ratio (in dB) between two tensors.

    The score is ``10 * log10(1 / MSE)``, i.e. the inputs are scored as if
    their peak value were 1.0.

    NOTE(review): ``max_val`` only populates the ``max_val`` buffer; the
    score itself never reads that buffer. With the default ``max_val=0`` the
    buffer holds ``-inf`` (log of zero). Both buffers are kept exactly as
    they were so existing state dicts still load -- confirm whether the peak
    value was meant to enter the formula.

    Args:
        max_val: Nominal peak signal value; stored in the ``max_val`` buffer
            as ``20 * log10(max_val)``.
    """

    def __init__(self, max_val=0):
        super().__init__()
        base10 = torch.log(torch.tensor(10.0))
        max_val = torch.tensor(max_val).float()
        self.register_buffer('base10', base10)
        # ln(x) / ln(10) == log10(x), so this stores 20 * log10(max_val).
        self.register_buffer('max_val', 20 * torch.log(max_val) / base10)

    def forward(self, a, b):
        """Return the PSNR of ``a`` against ``b`` as a 0-dim tensor.

        Implemented as ``forward`` (not ``__call__``) so that ``module(a, b)``
        goes through ``nn.Module``'s normal call path and forward hooks run.

        Args:
            a: First tensor; cast to float before comparison.
            b: Second tensor; cast to float before comparison.

        Returns:
            ``10 * log10(1 / mse)`` where ``mse`` is the mean squared
            difference over all elements. When the inputs are identical
            (``mse == 0``) a zero tensor is returned instead of ``inf``.
        """
        mse = torch.mean((a.float() - b.float()) ** 2)
        if mse == 0:
            # Keep the historical sentinel value (0) for identical inputs,
            # but return it as a tensor on the same device/dtype as the
            # regular result so callers always get one type back.
            return torch.zeros_like(mse)
        return 10 * torch.log10(1.0 / mse)
2
发布时间 2023-04-23 21:52:10 作者: helloWorldhelloWorld