# Import the required libraries.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
from torchvision.datasets import ImageFolder

# Define a simple convolutional block (Conv -> BatchNorm -> ReLU).

class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero-padding added to each side of the input.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            # The conv bias is redundant: BatchNorm2d immediately re-centres
            # each channel and applies its own learnable shift.
            nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU to ``x``."""
        return self.conv(x)

# Define a simple upscaling block using sub-pixel convolution (PixelShuffle).

class UpscaleBlock(nn.Module):
    """Sub-pixel upscaling stage.

    A 3x3 convolution expands the feature map to
    ``in_channels * scale_factor ** 2`` channels. ``nn.PixelShuffle`` folds
    those extra channels into space, and a ReLU follows. The result keeps
    ``in_channels`` channels, while height and width grow by ``scale_factor``.

    Args:
        in_channels: Channels of the incoming (and outgoing) feature map.
        scale_factor: Integer spatial upscaling factor.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()
        expanded_channels = in_channels * scale_factor ** 2
        self.conv = nn.Conv2d(
            in_channels, expanded_channels, kernel_size=3, stride=1, padding=1
        )
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Upscale ``x`` by ``scale_factor`` in height and width."""
        return self.relu(self.pixel_shuffle(self.conv(x)))

# Define a custom super-resolution model built from ConvBlocks and an UpscaleBlock.

class SuperResolutionModel(nn.Module):
    """Small sub-pixel super-resolution network.

    Stages, applied in order:
      1. 9x9 ConvBlock, 3 -> 64 channels (feature extraction)
      2. 1x1 ConvBlock, 64 -> 32 channels (channel reduction)
      3. UpscaleBlock, x ``upscale_factor`` in height and width
      4. 9x9 Conv2d, 32 -> 3 channels (reconstruction)

    Every convolution uses "same" padding with stride 1. The output is
    therefore ``upscale_factor`` times larger than the input spatially.

    Args:
        upscale_factor: Integer spatial upscaling factor.
    """

    def __init__(self, upscale_factor):
        super().__init__()
        # Registration order is kept stable: it fixes the state_dict key
        # order and the order of model.parameters().
        self.conv1 = ConvBlock(3, 64, kernel_size=9, stride=1, padding=4)
        self.conv2 = ConvBlock(64, 32, kernel_size=1, stride=1, padding=0)
        self.upscale = UpscaleBlock(32, upscale_factor)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Map a low-resolution batch to its upscaled reconstruction."""
        for stage in (self.conv1, self.conv2, self.upscale, self.conv3):
            x = stage(x)
        return x

# Create a custom dataset for image super-resolution.

class SuperResolutionDataset(Dataset):
    """Map-style dataset yielding ``(input, target)`` image pairs.

    Images are read with ``ImageFolder``, and their class labels are discarded.
    ``target_transform`` turns each raw image into the high-resolution target.
    ``input_transform`` is then applied to that target to derive the model
    input.

    Args:
        image_folder: Root directory passed to ``ImageFolder``.
        input_transform: Callable mapping the transformed target to the input.
        target_transform: Callable mapping a raw image to the target.
    """

    def __init__(self, image_folder, input_transform, target_transform):
        self.dataset = ImageFolder(image_folder)
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair for sample ``index``."""
        img, _ = self.dataset[index]
        target = self.target_transform(img)
        # The input is derived from the already-transformed target, so it is
        # always a degraded copy of exactly the target that is returned.
        low_res = self.input_transform(target)
        return low_res, target

    def __len__(self):
        """Return the number of images found by ``ImageFolder``."""
        return len(self.dataset)

# Instantiate the model, loss function, and optimizer.

upscale_factor = 2
# `device` must exist before the model is placed on it. Prefer the GPU
# when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SuperResolutionModel(upscale_factor).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Define the input and target transformations used for data preprocessing.

# SuperResolutionDataset applies input_transform to the *target tensor*
# produced by target_transform. The input side therefore only downsamples by
# upscale_factor, and it must not contain ToTensor.
input_transform = transforms.Compose([
    transforms.Resize(
        (256 // upscale_factor, 256 // upscale_factor),
        interpolation=TF.InterpolationMode.BICUBIC,
    ),
])

# High-resolution targets are resized to 256x256. ToTensor then converts them
# to float tensors so DataLoader can batch them and MSELoss can consume them.
target_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=TF.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
])

# Create DataLoaders for the training and validation data.

# Both splits share the same transforms; only the root directory differs.
train_dataset = SuperResolutionDataset(
    image_folder="path/to/train_data",
    input_transform=input_transform,
    target_transform=target_transform,
)
val_dataset = SuperResolutionDataset(
    image_folder="path/to/val_data",
    input_transform=input_transform,
    target_transform=target_transform,
)

# Only the training split is shuffled.
# NOTE(review): with num_workers > 0 on spawn-based platforms (Windows/macOS),
# worker processes re-import this module. Consider moving this script section
# under an `if __name__ == "__main__":` guard.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)

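# Training loop. NOTE: only the validation pass is present below; the per-epoch training pass appears to be missing.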

val_loss = 0.0
num_samples = 0

# Move each batch to whichever device the model's parameters live on.
eval_device = next(model.parameters()).device

# eval() makes BatchNorm use its running statistics instead of per-batch
# statistics, so the metric does not depend on how batches are composed.
model.eval()
with torch.no_grad():
    for inputs, targets in val_loader:
        inputs, targets = inputs.to(eval_device), targets.to(eval_device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # criterion is nn.MSELoss() with its default mean reduction, so `loss`
        # is a per-batch mean. Weight it by batch size so a smaller final
        # batch does not skew the overall average.
        batch_size = inputs.size(0)
        val_loss += loss.item() * batch_size
        num_samples += batch_size

if num_samples == 0:
    raise RuntimeError("Validation loader produced no samples.")
val_loss /= num_samples
print(f"Validation Loss: {val_loss:.4f}")

