1111

发布时间 2023-07-06 10:50:31 作者: helloWorldhelloWorld
class LLCaps(nn.Module):
    """Residual convolutional network whose output is fed to a diffusion stage.

    The convolutional part is ``conv_in -> body (n_RRG RRG blocks) -> conv_out``
    with a global skip connection that adds the input back onto the output.
    That result is then passed to ``self.rd_procedure``, a ``GaussianDiffusion``
    wrapper around a ``Unet``.

    Args:
        device: Stored as ``self.device``; not used inside this class.
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by ``conv_out``. Must equal
            ``in_channels`` because of the residual ``h + x``. It is also used
            as the channel count of the diffusion ``Unet``; that was previously
            hard-coded to 3, which is still the default.
        n_feat: Feature width of the convolutional trunk.
        kernel_size: Kernel size of ``conv_in``/``conv_out``. With padding
            ``(kernel_size - 1) // 2`` and stride 1, an odd size preserves H/W.
        stride, n_MSRB, height, width: Forwarded to every ``RRG``.
            NOTE(review): their exact meaning is defined by ``RRG``.
        n_RRG: Number of ``RRG`` blocks stacked in ``self.body``.
        bias: Whether ``conv_in``/``conv_out`` use a bias; also forwarded to
            ``RRG``.
        image_size: ``image_size`` passed to ``GaussianDiffusion``. It was
            previously hard-coded to 256, which remains the default.

    Raises:
        ValueError: If ``in_channels != out_channels`` (the residual add
            could never succeed).
    """

    def __init__(self, device, in_channels=3, out_channels=3, n_feat=64, kernel_size=3, stride=2, n_RRG=3, n_MSRB=2, height=3, width=2, bias=False, image_size=256):
        super().__init__()
        # The global skip connection ``h + x`` in forward() needs matching
        # channel counts; fail at construction instead of mid-forward.
        if in_channels != out_channels:
            raise ValueError(
                f"LLCaps requires in_channels == out_channels for the residual "
                f"connection, got in_channels={in_channels}, out_channels={out_channels}"
            )
        self.device = device
        padding = (kernel_size - 1) // 2
        self.conv_in = nn.Conv2d(in_channels, n_feat, kernel_size=kernel_size, padding=padding, bias=bias)
        self.body = nn.Sequential(*[RRG(n_feat, n_MSRB, height, width, stride, bias) for _ in range(n_RRG)])
        self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
        # The diffusion stage consumes conv_out's output, so its channel count
        # must follow out_channels rather than a fixed 3.
        self.rd_model = Unet(dim=6, init_dim=None, out_dim=None, dim_mults=(1, 2, 4, 8), channels=out_channels, self_condition=False, resnet_block_groups=6, learned_variance=False, learned_sinusoidal_cond=False, random_fourier_features=False, learned_sinusoidal_dim=16)
        self.rd_procedure = GaussianDiffusion(self.rd_model, image_size=image_size, timesteps=1000, sampling_timesteps=None, loss_type='l1', objective='pred_noise', beta_schedule='sigmoid', schedule_fn_kwargs=dict(), ddim_sampling_eta=0., auto_normalize=True)

    def forward(self, x):
        """Refine ``x`` with the residual CNN, then pass it to the diffusion stage.

        Args:
            x: Input batch. It is presumably ``(batch, in_channels, H, W)`` —
                NOTE(review): confirm. ``GaussianDiffusion`` may also require
                ``H == W == image_size``.

        Returns:
            Whatever ``self.rd_procedure`` returns for the refined tensor.
            NOTE(review): if this is lucidrains' ``GaussianDiffusion``, its
            ``forward`` returns a training loss rather than an image. Confirm
            that returning that value (and discarding the refined tensor) is
            intended.
        """
        h = self.conv_in(x)
        h = self.body(h)
        h = self.conv_out(h)
        # Out-of-place add, with the same values as ``h += x``, so no tensor
        # in the autograd graph is modified in place.
        h = h + x
        return self.rd_procedure(h)