2

发布时间 2024-01-08 11:45:22 作者: helloWorldhelloWorld
class SpaBlock(nn.Module):
    """Spatial-domain residual block.

    The main branch is conv3x3 -> LeakyReLU -> conv3x3 -> LeakyReLU. A 1x1
    convolution projects the input to ``out_channel`` channels, and the two
    branches are summed.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels produced by the block.
        relu_slope: negative slope used by both LeakyReLU activations.
    """

    def __init__(self, in_channel, out_channel, relu_slope=0.2):
        super().__init__()
        # The child order (conv, act, conv, act) fixes the state_dict keys
        # spatialConv.0.* and spatialConv.2.*; keep it stable for checkpoints.
        # kernel 3 / stride 1 / padding 1 preserves the spatial size.
        self.spatialConv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(negative_slope=relu_slope, inplace=False),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(negative_slope=relu_slope, inplace=False),
        )
        # A 1x1 projection shortcut lets the residual sum work even when
        # in_channel != out_channel.
        self.identity = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``spatialConv(x) + identity(x)``."""
        main_branch = self.spatialConv(x)
        shortcut = self.identity(x)
        return main_branch + shortcut


class FreBlock(nn.Module):
    """Frequency-domain block that refines the amplitude spectrum.

    The input goes through a real 2-D FFT. Its amplitude is passed through a
    small 1x1-conv network (``fftConv2``) and recombined with the input's
    original phase. An inverse real FFT then restores the input's spatial
    size.

    Args:
        in_channel: not used by this block (the convolutions are built from
            ``out_channel`` only).
        out_channel: channel count of both the input and the output; the 1x1
            convolutions map ``out_channel -> out_channel``.
        relu_slope: negative slope of the LeakyReLU between the convolutions.
    """

    def __init__(self, in_channel, out_channel, relu_slope=0.2):
        super(FreBlock, self).__init__()

        self.fftConv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 1, 1, 0),
            nn.LeakyReLU(relu_slope, inplace=False),
            nn.Conv2d(out_channel, out_channel, 1, 1, 0),
        )

    def forward(self, x):
        """Refine the amplitude spectrum of ``x``, keeping its phase.

        Args:
            x: real tensor whose channel dimension must equal ``out_channel``
                (assumed NCHW layout -- TODO confirm with callers).

        Returns:
            Real tensor with the same shape as ``x``.
        """
        # Remember the spatial size. Without ``s``, irfft2 assumes an even
        # last dimension and returns width W-1 when the input width W is odd.
        spatial_size = x.shape[-2:]

        x_fft = torch.fft.rfft2(x, norm='backward')
        x_amp = torch.abs(x_fft)
        x_phase = torch.angle(x_fft)

        # Only the amplitude is refined; the phase is reused unchanged.
        # NOTE(review): fftConv2 is not applied to the phase; if a learned
        # phase branch is wanted, it needs its own conv stack.
        enhanced_amp = self.fftConv2(x_amp)
        return torch.fft.irfft2(
            enhanced_amp * torch.exp(1j * x_phase), s=spatial_size, norm='backward'
        )


class ResBlock(nn.Module):
    """Residual block computing ``yyBlock(conv1(x)) + x``.

    Args:
        in_channel: channels of the input tensor.
        out_channel: channels produced by ``conv1``. ``yyBlock`` is built with
            ``in_channel`` but fed ``conv1``'s output, and the result is added
            to ``x``, so presumably ``in_channel == out_channel`` -- confirm.
        mode: accepted for interface compatibility; not used here.
        filter: stored as ``self.filter``; not read by ``forward``.
    """

    def __init__(self, in_channel, out_channel, mode, filter=False):
        super().__init__()
        self.conv1 = BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True)

        # NOTE(review): spaBlock and freBlock are constructed, so their
        # weights appear in parameters()/state_dict(), but forward() never
        # calls them. They are kept so existing checkpoint keys still load.
        # Under DistributedDataParallel they may require
        # find_unused_parameters=True.
        self.spaBlock = SpaBlock(in_channel, out_channel, relu_slope=0.2)
        self.freBlock = FreBlock(in_channel, out_channel, relu_slope=0.2)
        self.filter = filter

        self.yyBlock = YYBlock(in_channel, out_channel, relu_slope=0.2)

    def forward(self, x):
        """Apply conv1 then yyBlock, and add the input back as a residual."""
        features = self.yyBlock(self.conv1(x))
        return features + x