class LLCaps(nn.Module):
    """Low-light image enhancement network.

    Pipeline: a shallow input conv lifts the image to ``n_feat`` channels,
    a stack of ``n_RRG`` recursive residual groups refines the features, an
    output conv projects back to ``out_channels``, a global residual adds the
    input image, and the result is passed through a Gaussian diffusion
    procedure (``rd_procedure``) for final refinement.

    Note: the global residual ``h + x`` requires ``out_channels`` to match
    ``in_channels`` (both default to 3).
    """

    def __init__(self, device, in_channels=3, out_channels=3, n_feat=64,
                 kernel_size=3, stride=2, n_RRG=3, n_MSRB=2, height=3,
                 width=2, bias=False):
        super(LLCaps, self).__init__()
        # Kept for external use; not referenced inside this class.
        self.device = device

        # "same" padding so spatial dimensions are preserved by both convs.
        pad = (kernel_size - 1) // 2
        self.conv_in = nn.Conv2d(in_channels, n_feat,
                                 kernel_size=kernel_size, padding=pad,
                                 bias=bias)
        self.conv_out = nn.Conv2d(n_feat, out_channels,
                                  kernel_size=kernel_size, padding=pad,
                                  bias=bias)

        # Feature-refinement trunk: n_RRG recursive residual groups.
        rrg_blocks = [
            RRG(n_feat, n_MSRB, height, width, stride, bias)
            for _ in range(n_RRG)
        ]
        self.body = nn.Sequential(*rrg_blocks)

        # Denoising U-Net backbone for the diffusion stage.
        self.rd_model = Unet(dim=6, init_dim=None, out_dim=None,
                             dim_mults=(1, 2, 4, 8), channels=3,
                             self_condition=False, resnet_block_groups=6,
                             learned_variance=False,
                             learned_sinusoidal_cond=False,
                             random_fourier_features=False,
                             learned_sinusoidal_dim=16)

        # Gaussian diffusion wrapper applied after the residual enhancement.
        self.rd_procedure = GaussianDiffusion(self.rd_model, image_size=256,
                                              timesteps=1000,
                                              sampling_timesteps=None,
                                              loss_type='l1',
                                              objective='pred_noise',
                                              beta_schedule='sigmoid',
                                              schedule_fn_kwargs=dict(),
                                              ddim_sampling_eta=0.,
                                              auto_normalize=True)

    def forward(self, x):
        """Enhance input image batch ``x`` (N, C, H, W) and return the
        diffusion-refined result."""
        features = self.conv_in(x)
        features = self.body(features)
        features = self.conv_out(features)
        # Global residual connection back to the raw input.
        features += x
        return self.rd_procedure(features)