import torch.nn as nn import torch def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size//2), bias=bias) class PALayer(nn.Module): def __init__(self, channel): super(PALayer, self).__init__() self.pa = nn.Sequential( nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True), nn.ReLU(inplace=True), nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True), nn.Sigmoid() ) def forward(self, x): y = self.pa(x) return x * y class CALayer(nn.Module): def __init__(self, channel): super(CALayer, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.ca = nn.Sequential( nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True), nn.ReLU(inplace=True), nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True), nn.Sigmoid() ) def forward(self, x): y = self.avg_pool(x) y = self.ca(y) return x * y class Block(nn.Module): def __init__(self, conv, dim, kernel_size,): super(Block, self).__init__() self.conv1 = conv(dim, dim, kernel_size, bias=True) self.act1 = nn.ReLU(inplace=True) self.conv2 = conv(dim, dim, kernel_size, bias=True) self.calayer = CALayer(dim) self.palayer = PALayer(dim) def forward(self, x): res =self.act1(self.conv1(x)) res = self.conv2(res) res_ca = self.calayer(res) res_pa = self.palayer(res) res = x + res_ca + res_pa return res class Group(nn.Module): def __init__(self, conv, dim, kernel_size, blocks): super(Group, self).__init__() modules = [ Block(conv, dim, kernel_size) for _ in range(blocks)] modules.append(conv(dim, dim, kernel_size)) self.gp = nn.Sequential(*modules) def forward(self, x): res = self.gp(x) res += x return res class FFA(nn.Module): def __init__(self, gps, blocks, conv=default_conv): super(FFA, self).__init__() self.gps = gps self.dim = 64 kernel_size = 3 pre_process = [conv(3, self.dim, kernel_size)] # assert self.gps==1 self.g1 = Group(conv, self.dim, kernel_size, blocks=blocks) self.g2 = Group(conv, self.dim, kernel_size, blocks=blocks) self.g3 = Group(conv, self.dim, kernel_size, blocks=blocks) self.ca = nn.Sequential(*[ nn.AdaptiveAvgPool2d(1), nn.Conv2d(self.dim * self.gps, self.dim // 16, 1, padding=0), nn.ReLU(inplace=True), nn.Conv2d(self.dim // 16, self.dim * self.gps, 1, padding=0, bias=True), nn.Sigmoid() ]) self.palayer = PALayer(self.dim) post_precess = [ conv(self.dim, self.dim, kernel_size), conv(self.dim, 3, kernel_size)] self.pre = nn.Sequential(*pre_process) self.post = nn.Sequential(*post_precess) def forward(self, x1): x = self.pre(x1) res1 = self.g1(x) res2 = self.g2(res1) res3 = self.g3(res2) w = self.ca(torch.cat([res1, res2, res3], dim=1)) w = w.view(-1,self.gps, self.dim)[:,:,:,None,None] out = w[:,0,::] * res1 + w[:,1,::] * res2 + w[:,2,::] * res3 out = self.palayer(out) x = self.post(out) return x + x1 if __name__ == "__main__": net = FFA(gps=3,blocks=19) print(net)
import torch import torch.nn as nn def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size//2), bias=bias) class PALayer(nn.Module): def __init__(self, channel): super(PALayer, self).__init__() self.pa = nn.Sequential( nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True), nn.ReLU(inplace=True), nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True), nn.Sigmoid() ) def forward(self, x): y = self.pa(x) return x * y class CALayer(nn.Module): def __init__(self, channel): super(CALayer, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.ca = nn.Sequential( nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True), nn.ReLU(inplace=True), nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True), nn.Sigmoid() ) def forward(self, x): y = self.avg_pool(x) y = self.ca(y) return x * y class Block(nn.Module): def __init__(self, conv, dim, kernel_size,): super(Block, self).__init__() self.conv1 = conv(dim, dim, kernel_size, bias=True) self.act1 = nn.ReLU(inplace=True) self.conv2 = conv(dim, dim, kernel_size, bias=True) self.calayer = CALayer(dim) self.palayer = PALayer(dim) def forward(self, x): res=self.act1(self.conv1(x)) res=self.conv2(res) res_ca=self.calayer(res) res_pa=self.palayer(res) res = x + res_ca + res_pa return res class Group(nn.Module): def __init__(self, conv, dim, kernel_size, blocks): super(Group, self).__init__() modules = [ Block(conv, dim, kernel_size) for _ in range(blocks)] modules.append(conv(dim, dim, kernel_size)) self.gp = nn.Sequential(*modules) def forward(self, x): res = self.gp(x) res += x return res class FFA(nn.Module): def __init__(self,gps,blocks,conv=default_conv): super(FFA, self).__init__() self.gps = gps self.dim = 64 kernel_size = 3 pre_process = [conv(3, self.dim, kernel_size)] # assert self.gps==1 self.g1 = Group(conv, self.dim, kernel_size, blocks=blocks) self.g2 = Group(conv, self.dim, kernel_size, blocks=blocks) self.g3 = Group(conv, self.dim, kernel_size, blocks=blocks) self.ca = nn.Sequential(*[ nn.AdaptiveAvgPool2d(1), nn.Conv2d(self.dim * self.gps, self.dim // 16, 1, padding=0), nn.ReLU(inplace=True), nn.Conv2d(self.dim // 16, self.dim * self.gps, 1, padding=0, bias=True), nn.Sigmoid() ]) self.palayer=PALayer(self.dim) post_precess = [ conv(192, self.dim, kernel_size), conv(self.dim, 3, kernel_size)] self.pre = nn.Sequential(*pre_process) self.post = nn.Sequential(*post_precess) def forward(self, x1): x = self.pre(x1) res1 = self.g1(x) res2 = self.g2(res1) res3 = self.g3(res2) out = torch.cat([res1, res2, res3], dim=1) # w = self.ca(torch.cat([res1, res2, res3], dim=1)) # w = w.view(-1,self.gps, self.dim)[:,:,:,None,None] # out = w[:,0,::] * res1 + w[:,1,::] * res2 + w[:,2,::] * res3 # out = self.palayer(out) x = self.post(out) return x + x1 if __name__ == "__main__": net = FFA(gps=3,blocks=19) print(net)