class SpaBlock(nn.Module):
    """Spatial-domain residual block.

    Two 3x3 convolutions, each followed by a LeakyReLU, are summed with a 1x1
    convolutional projection of the input. Because the skip path is projected,
    ``in_channel`` and ``out_channel`` may differ. The spatial size is kept
    (3x3 kernels with padding 1, stride 1).

    Args:
        in_channel: Number of channels of the input feature map.
        out_channel: Number of channels of the output feature map.
        relu_slope: Negative slope of the LeakyReLU activations.
    """

    def __init__(self, in_channel, out_channel, relu_slope=0.2):
        super().__init__()
        self.spatialConv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(relu_slope, inplace=False),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(relu_slope, inplace=False),
        )
        # 1x1 projection so the skip path has ``out_channel`` channels.
        self.identity = nn.Conv2d(in_channel, out_channel, 1, 1, 0)

    def forward(self, x):
        """Return ``spatialConv(x) + identity(x)``."""
        return self.spatialConv(x) + self.identity(x)


class FreBlock(nn.Module):
    """Frequency-domain block that refines the amplitude spectrum.

    The input is transformed with ``torch.fft.rfft2`` over its last two
    dimensions. The amplitude spectrum is passed through two 1x1 convolutions
    with a LeakyReLU between them. The phase spectrum is reused unchanged, and
    the recombined spectrum is transformed back with ``torch.fft.irfft2`` at
    the input's original spatial size.

    Args:
        in_channel: Not used to build any layer; kept for signature
            compatibility. The 1x1 convolutions take ``out_channel`` inputs,
            so the input must have ``out_channel`` channels (effectively
            ``in_channel == out_channel``).
        out_channel: Channel count of both the input and the output.
        relu_slope: Negative slope of the LeakyReLU activation.
    """

    def __init__(self, in_channel, out_channel, relu_slope=0.2):
        super().__init__()
        self.fftConv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 1, 1, 0),
            nn.LeakyReLU(relu_slope, inplace=False),
            nn.Conv2d(out_channel, out_channel, 1, 1, 0),
        )

    def forward(self, x):
        """Refine the amplitude of ``x`` in the Fourier domain.

        Returns a real tensor with the same shape as ``x``.
        """
        spatial_size = x.shape[-2:]
        x_fft = torch.fft.rfft2(x, norm='backward')
        x_amp = torch.abs(x_fft)
        x_phase = torch.angle(x_fft)
        # Only the amplitude is learned; the phase is reused as-is.
        enhanced_amp = self.fftConv2(x_amp)
        # ``s`` is required to recover odd widths: rfft2 keeps W // 2 + 1
        # frequency bins, and irfft2 otherwise reconstructs 2 * (W // 2)
        # samples, i.e. W - 1 when W is odd.
        return torch.fft.irfft2(
            enhanced_amp * torch.exp(1j * x_phase),
            s=spatial_size,
            norm='backward',
        )


class ResBlock(nn.Module):
    """Residual block computing ``x + yyBlock(conv1(x))``.

    The residual addition requires the output of ``yyBlock`` to match the
    shape of ``x`` (presumably ``in_channel == out_channel`` — confirm
    against callers).

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels of the inner layers.
        mode: Not used by this block; kept for signature compatibility.
        filter: Stored on ``self.filter``; not read inside this block.
    """

    def __init__(self, in_channel, out_channel, mode, filter=False):
        super().__init__()
        self.conv1 = BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True)
        # NOTE(review): spaBlock and freBlock are not used in forward(). They are
        # kept so that existing checkpoints (state_dict keys) still load; confirm
        # before removing them.
        self.spaBlock = SpaBlock(in_channel, out_channel, relu_slope=0.2)
        self.freBlock = FreBlock(in_channel, out_channel, relu_slope=0.2)
        self.filter = filter
        self.yyBlock = YYBlock(in_channel, out_channel, relu_slope=0.2)

    def forward(self, x):
        """Apply ``conv1`` then ``yyBlock`` and add the input back."""
        out = self.conv1(x)
        out = self.yyBlock(out)
        return out + x