# 11
#
# Published 2023-07-16 21:21:27 - author: helloWorldhelloWorld
import os
import glob
import h5py
import random
from PIL import Image
from matplotlib import pyplot as plt

import torch.utils.data as data
from torch.utils.data import DataLoader
import torchvision.transforms as tfs
from torchvision.transforms import functional as FF

from Utils.metrics import *
from Utils.option import *


class train_Dataset(data.Dataset):
    """Paired training dataset with random crop, flip and 90-degree rotation.

    Input images are read from ``<path>/low`` and their references from
    ``<path>/high``. The reference filename is the part of the input
    filename before the first underscore (the whole filename when it
    contains no underscore).

    Args:
        path: Dataset root containing the ``low`` and ``high`` directories.
        mode: Stored on the instance as ``self.mode``; not otherwise used here.
        size: Side length of the random square crop. When it is an ``int``,
            inputs smaller than ``size`` in either dimension are replaced by
            a randomly chosen sample. When it is a ``str``, no crop is done.
        **kwargs: Optional ``restored_mask``; it is converted with ``.long()``
            and indexed per sample (presumably a tensor with one entry per
            image -- confirm against the caller).
    """

    def __init__(self, path, mode='train', size=opt.crop_size, **kwargs):
        super(train_Dataset, self).__init__()
        self.size = size
        self.mode = mode
        # NOTE(review): this binds the builtin ``format`` function, not a file
        # extension; kept only so the attribute continues to exist.
        self.format = format
        self.haze_imgs_dir = os.path.join(path, 'low')  # input image directory
        self.haze_imgs_list = os.listdir(self.haze_imgs_dir)  # input filenames
        self.haze_imgs = [
            os.path.join(self.haze_imgs_dir, img) for img in self.haze_imgs_list
        ]  # full input paths
        self.clear_dir = os.path.join(path, 'high')  # reference image directory
        self.mask = False
        if 'restored_mask' in kwargs:
            self.restored_mask = kwargs['restored_mask'].long()
            self.mask = True

        self.length = len(self.haze_imgs_list)

    def __getitem__(self, index):
        """Return ``(haze, clear, index, mask_flag)`` for one sample.

        ``index`` in the result is the index actually used, which differs
        from the requested one when the requested image was too small for
        the crop. ``mask_flag`` is ``restored_mask[index]`` when a mask was
        supplied, otherwise ``-1``.
        """
        haze = Image.open(self.haze_imgs[index])
        if isinstance(self.size, int):
            # Resample until the image is large enough for a size x size crop.
            while haze.size[0] < self.size or haze.size[1] < self.size:
                # random.randint includes both endpoints, so the upper bound
                # must be length - 1 to stay inside the list.
                index = random.randint(0, self.length - 1)
                haze = Image.open(self.haze_imgs[index])
        # basename() works with whatever separator os.path.join produced.
        haze_name = os.path.basename(self.haze_imgs[index])
        clear_name = haze_name.split('_')[0]
        clear = Image.open(os.path.join(self.clear_dir, clear_name))
        # PIL size is (width, height); CenterCrop expects (height, width).
        clear = tfs.CenterCrop(haze.size[::-1])(clear)

        if not isinstance(self.size, str):
            # Same crop window for both images so they stay aligned.
            i, j, h, w = tfs.RandomCrop.get_params(
                haze, output_size=(self.size, self.size)
            )
            haze = FF.crop(haze, i, j, h, w)
            clear = FF.crop(clear, i, j, h, w)
        # One shared draw so both images receive the same augmentation.
        rand_hor = random.randint(0, 1)
        rand_rot = random.randint(0, 3)
        haze = self.augData_haze(haze.convert("RGB"), rand_hor, rand_rot)
        clear = self.augData_clear(clear.convert("RGB"), rand_hor, rand_rot)
        mask_flag = self.restored_mask[index] if self.mask else -1

        return haze, clear, index, mask_flag

    @staticmethod
    def _augment(img, rand_hor, rand_rot):
        """Flip/rotate ``img`` as selected and convert it with ``ToTensor``."""
        # rand_hor is 0 or 1, so used as the flip probability it is an on/off switch.
        img = tfs.RandomHorizontalFlip(rand_hor)(img)
        if rand_rot:
            img = FF.rotate(img, 90 * rand_rot)
        return tfs.ToTensor()(img)

    def augData_haze(self, haze, rand_hor, rand_rot):
        """Apply the selected flip/rotation to the input image and tensorize it."""
        return self._augment(haze, rand_hor, rand_rot)

    def augData_clear(self, clear, rand_hor, rand_rot):
        """Apply the selected flip/rotation to the reference image and tensorize it."""
        return self._augment(clear, rand_hor, rand_rot)

    def __len__(self):
        """Return the number of input images found at construction time."""
        return self.length


class test_Dataset(data.Dataset):
    """Paired evaluation dataset without cropping or augmentation.

    Input images are read from ``<path>/low`` and their references from
    ``<path>/high``. The reference filename is the part of the input
    filename before the first underscore (the whole filename when it
    contains no underscore).

    Args:
        path: Dataset root containing the ``low`` and ``high`` directories.
        mode: Stored on the instance as ``self.mode``; not otherwise used here.
    """

    def __init__(self, path, mode='test'):
        super(test_Dataset, self).__init__()
        self.mode = mode
        # NOTE(review): this binds the builtin ``format`` function, not a file
        # extension; kept only so the attribute continues to exist.
        self.format = format
        self.haze_imgs_dir = os.path.join(path, 'low')  # input image directory
        self.haze_imgs_list = os.listdir(self.haze_imgs_dir)  # input filenames
        self.haze_imgs = [
            os.path.join(self.haze_imgs_dir, img) for img in self.haze_imgs_list
        ]  # full input paths
        self.clear_dir = os.path.join(path, 'high')  # reference image directory

        self.length = len(self.haze_imgs_list)

    def __getitem__(self, index):
        """Return ``(haze, clear, haze_name)`` for the sample at ``index``.

        ``haze_name`` is the input image's filename without its directory.
        """
        haze = Image.open(self.haze_imgs[index])
        # basename() works with whatever separator os.path.join produced,
        # unlike splitting on a hard-coded '/'.
        haze_name = os.path.basename(self.haze_imgs[index])
        clear_name = haze_name.split('_')[0]
        clear = Image.open(os.path.join(self.clear_dir, clear_name))
        # PIL size is (width, height); CenterCrop expects (height, width).
        clear = tfs.CenterCrop(haze.size[::-1])(clear)

        haze = self.augData_haze(haze.convert("RGB"))
        clear = self.augData_clear(clear.convert("RGB"))
        return haze, clear, haze_name

    def augData_haze(self, haze):
        """Convert the input image with ``ToTensor`` (no augmentation)."""
        return tfs.ToTensor()(haze)

    def augData_clear(self, clear):
        """Convert the reference image with ``ToTensor`` (no augmentation)."""
        return tfs.ToTensor()(clear)

    def __len__(self):
        """Return the number of input images found at construction time."""
        return self.length