Building a Super-Resolution Model with PyTorch: Common Modules

Published: 2023-03-24 23:36:58 · Author: 马路野狼

Import required libraries:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
from torchvision.datasets import ImageFolder

Define a simple convolutional block (Conv-BatchNorm-ReLU)

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)
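
As a quick, illustrative shape check: with kernel_size=3, stride=1, padding=1 the block preserves spatial size while changing the channel count.

block = ConvBlock(3, 64, kernel_size=3, stride=1, padding=1)
x = torch.randn(1, 3, 64, 64)      # dummy RGB input
print(block(x).shape)              # torch.Size([1, 64, 64, 64])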

Define a simple upscaling block using sub-pixel convolution

class UpscaleBlock(nn.Module):
    def __init__(self, in_channels, scale_factor):
        super(UpscaleBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.relu(x)
        return x
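
The convolution expands the channels by scale_factor ** 2 and nn.PixelShuffle rearranges them into space, so height and width grow by scale_factor while the channel count is restored. An illustrative check:

up = UpscaleBlock(in_channels=32, scale_factor=2)
x = torch.randn(1, 32, 64, 64)
print(up(x).shape)                 # torch.Size([1, 32, 128, 128])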

Define a custom super-resolution model (e.g., using ConvBlocks and UpscaleBlocks)

class SuperResolutionModel(nn.Module):
    def __init__(self, upscale_factor):
        super(SuperResolutionModel, self).__init__()
        self.conv1 = ConvBlock(3, 64, kernel_size=9, stride=1, padding=4)
        self.conv2 = ConvBlock(64, 32, kernel_size=1, stride=1, padding=0)
        self.upscale = UpscaleBlock(32, upscale_factor)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.upscale(x)
        x = self.conv3(x)
        return x
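
A forward-pass sanity check with a random tensor (purely for illustration) confirms the output is enlarged by upscale_factor:

sr_model = SuperResolutionModel(upscale_factor=2)
lr = torch.randn(1, 3, 128, 128)   # low-resolution RGB input
print(sr_model(lr).shape)          # torch.Size([1, 3, 256, 256])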

Create a custom dataset for image super-resolution

class SuperResolutionDataset(torch.utils.data.Dataset):
    def __init__(self, image_folder, input_transform, target_transform):
        self.dataset = ImageFolder(image_folder)
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        img, _ = self.dataset[index]
        # Build both views from the original PIL image: the high-resolution
        # target and its bicubic-downscaled low-resolution counterpart
        target = self.target_transform(img)
        input = self.input_transform(img)
        return input, target

    def __len__(self):
        return len(self.dataset)

Instantiate the model, loss function, and optimizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

upscale_factor = 2
model = SuperResolutionModel(upscale_factor).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
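
The vgg19 import above is not used in this minimal pipeline; it is typically brought in for a perceptual (feature-space) loss. Below is a minimal sketch, assuming torchvision >= 0.13 and inputs in [0, 1] (ImageNet normalization is omitted for brevity); the class name and feature-layer index are illustrative.

class VGGPerceptualLoss(nn.Module):
    def __init__(self, feature_layer=8):
        super(VGGPerceptualLoss, self).__init__()
        # Frozen early VGG19 features; on older torchvision use vgg19(pretrained=True)
        vgg = vgg19(weights="IMAGENET1K_V1").features[:feature_layer].eval()
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg
        self.mse = nn.MSELoss()

    def forward(self, sr, hr):
        # MSE between VGG feature maps of the super-resolved and target images
        return self.mse(self.vgg(sr), self.vgg(hr))

# Example weighting: total loss = pixel MSE + 0.01 * perceptual term
# perceptual_criterion = VGGPerceptualLoss().to(device)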

Define input and target transformations for data preprocessing

input_transform = transforms.Compose([
    transforms.Resize((256 // upscale_factor, 256 // upscale_factor), interpolation=TF.InterpolationMode.BICUBIC),
    transforms.ToTensor()
])

target_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=TF.InterpolationMode.BICUBIC),
    transforms.ToTensor()
])

Create DataLoader for training and validation data

train_dataset = SuperResolutionDataset("path/to/train_data", input_transform, target_transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)

val_dataset = SuperResolutionDataset("path/to/val_data", input_transform, target_transform)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
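
Once the placeholder paths point at real image folders (note that ImageFolder expects class subdirectories), pulling one batch confirms the LR/HR pairing; the shapes below assume upscale_factor = 2 and the 256-pixel target defined above.

lr_batch, hr_batch = next(iter(train_loader))
print(lr_batch.shape)              # torch.Size([16, 3, 128, 128])
print(hr_batch.shape)              # torch.Size([16, 3, 256, 256])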

Training loop

num_epochs = 100  # adjust to the dataset size

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}] Training Loss: {train_loss:.4f}")

Validation loop

model.eval()
val_loss = 0.0

with torch.no_grad():
    for inputs, targets in val_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        val_loss += loss.item()

val_loss /= len(val_loader)
print(f"Validation Loss: {val_loss:.4f}")