# 11
#
# 发布时间 (published): 2023-10-30 10:37:35  作者 (author): helloWorldhelloWorld
import torch
import torch.nn as nn

class BasicConv(nn.Module):
    """Convolution unit: (transposed) conv -> optional BatchNorm2d -> optional GELU.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Square kernel size.
        stride: Convolution stride.
        bias: Whether the conv layer has a bias. It is forced off when ``norm``
            is set, since the following BatchNorm makes it redundant.
        norm: Append ``nn.BatchNorm2d(out_channel)`` after the conv.
        relu: Append an activation after the conv/norm. Note that the
            activation used is ``nn.GELU``, not ReLU, despite the flag name.
        transpose: Use ``nn.ConvTranspose2d`` with padding
            ``kernel_size // 2 - 1`` instead of ``nn.Conv2d`` with padding
            ``kernel_size // 2``. With kernel_size=4 and stride=2 this
            doubles the spatial size.

    The layers are stored, in order, in the ``nn.Sequential`` ``self.main``.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
        super().__init__()
        # A bias before BatchNorm is cancelled by the normalisation, so drop it.
        use_bias = False if (bias and norm) else bias

        if not transpose:
            first = nn.Conv2d(in_channel, out_channel, kernel_size,
                              padding=kernel_size // 2, stride=stride, bias=use_bias)
        else:
            first = nn.ConvTranspose2d(in_channel, out_channel, kernel_size,
                                       padding=kernel_size // 2 - 1, stride=stride, bias=use_bias)

        stages = [first]
        if norm:
            stages.append(nn.BatchNorm2d(out_channel))
        if relu:
            stages.append(nn.GELU())
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv/norm/activation stack to ``x``."""
        y = self.main(x)
        return y

class Network(nn.Module):
    """Block that mixes a spatial conv branch with Fourier amplitude/phase branches.

    Pipeline (see ``forward``):
      1. ``preConv`` lifts the input to ``out_channel`` feature maps.
      2. ``spatialConv`` (two 3x3 conv + LeakyReLU) is added residually.
      3. The result is moved to the frequency domain with ``rfft2``. Its
         amplitude and phase are each refined by ``fftConv2``. Two
         reconstructions are formed, one with refined phase and one with
         refined amplitude, and both are brought back with ``irfft2``.
      4. The spatial features and the two reconstructions are concatenated on
         dim=1, fused with a 1x1 conv, and projected back to ``in_channel``
         by ``proConv``.

    Args:
        in_channel: Channels of the input and of the output.
        out_channel: Width of the internal feature maps.
        relu_slope: Negative slope of the LeakyReLU activations.

    Input is expected as (N, in_channel, H, W), since features are
    concatenated on dim=1. The output has the same shape as the input.
    """

    def __init__(self, in_channel=3, out_channel=20, relu_slope=0.2):
        super(Network, self).__init__()

        # No trailing comma: the conv must be a registered submodule, not a tuple.
        self.preConv = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=True)
        self.spatialConv = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(relu_slope, inplace=False),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(relu_slope, inplace=False),
        )
        # NOTE(review): this single 1x1 stack is applied to BOTH the phase and
        # the amplitude in forward(), so the two branches share weights. The
        # name suggests a separate `fftConv1` may have been intended; confirm.
        self.fftConv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 1, 1, 0),
            nn.LeakyReLU(relu_slope, inplace=False),
            nn.Conv2d(out_channel, out_channel, 1, 1, 0),
        )

        # forward() concatenates three out_channel-wide tensors
        # (spatial features + two inverse-FFT reconstructions).
        self.fusion = nn.Conv2d(out_channel * 3, out_channel, 1, 1, 0)

        # padding=1 keeps H and W unchanged so the output matches the input size.
        self.proConv = nn.Conv2d(out_channel, in_channel, 3, 1, padding=1, bias=True)

    def forward(self, x1):
        """Run the spatial + frequency pipeline on ``x1`` and return a same-shaped tensor."""
        x = self.preConv(x1)
        out = self.spatialConv(x) + x

        # rfft2 keeps only W // 2 + 1 frequency columns. irfft2 must be told the
        # original (H, W); otherwise an odd W comes back as W - 1 and the
        # concatenation below fails.
        spatial_size = out.shape[-2:]
        x_fft = torch.fft.rfft2(out, norm='backward')
        x_amp = torch.abs(x_fft)
        x_phase = torch.angle(x_fft)

        enhanced_phase = self.fftConv2(x_phase)
        enhanced_amp = self.fftConv2(x_amp)
        # Rebuild complex spectra as amp * e^{i*phase}: one with refined phase,
        # one with refined amplitude.
        x_fft_out1 = torch.fft.irfft2(x_amp * torch.exp(1j * enhanced_phase), s=spatial_size, norm='backward')
        x_fft_out2 = torch.fft.irfft2(enhanced_amp * torch.exp(1j * x_phase), s=spatial_size, norm='backward')

        out = self.fusion(torch.cat([out, x_fft_out1, x_fft_out2], dim=1))
        return self.proConv(out)