# Multiprocess preprocessing for DeepFashion2 (with progress bar).
# Originally published 2023-07-03 19:13:14 by cold_moon.
import json
from PIL import Image, ImageDraw
import os
from tqdm import tqdm
import multiprocessing
from functools import partial

def convert_rgba_to_rgb(image, output_path):
	"""Convert *image* to RGB mode and write it to *output_path*.

	JPEG cannot store an alpha channel, so RGBA/L inputs are converted
	to plain RGB before saving.
	"""
	image.convert('RGB').save(output_path)

def save_bounding_box(image, bounding_box, save_path):
	"""Crop *bounding_box* (left, upper, right, lower) out of *image*
	and save the resulting region to *save_path*."""
	image.crop(bounding_box).save(save_path)

def save_segmentation(image, segmentation):
	"""Return *image* masked to the union of the given polygons.

	Parameters
	----------
	image : PIL.Image.Image
	    Source image.
	segmentation : list[list[float]]
	    COCO-style polygons; each entry is a flat [x0, y0, x1, y1, ...] list.

	Returns
	-------
	PIL.Image.Image
	    RGBA image whose pixels outside every polygon are transparent
	    (fully black/zero for an empty polygon list).
	"""
	# Accumulate every polygon into a single 8-bit mask.
	mask_image = Image.new('L', image.size)
	draw = ImageDraw.Draw(mask_image)
	for polygon in segmentation:
		points = [(polygon[i], polygon[i + 1]) for i in range(0, len(polygon), 2)]
		draw.polygon(points, fill=255)

	# Paste once, after the mask is complete. The original created and
	# pasted `masked_image` inside the loop, which both redid the paste
	# per polygon and raised UnboundLocalError when `segmentation` was empty.
	masked_image = Image.new('RGBA', image.size)
	masked_image.paste(image, mask=mask_image)
	return masked_image

def crop_object_region(image, output_path):
	"""Crop *image* to its non-zero (non-black) region and save it as RGB.

	The bounding box is computed on a grayscale view of the image; if the
	image is entirely black, nothing is saved and a message is printed.
	"""
	# Bounding box of all non-zero pixels in the grayscale projection.
	bbox = image.convert('L').getbbox()
	if bbox is None:
		print("No object region found in the image.")
		return
	# Cut out the region and save it (RGBA -> RGB so JPEG can store it).
	convert_rgba_to_rgb(image.crop(bbox), output_path)

def create_directory(path):
	"""Create *path* (including any missing parents) if it does not exist.

	Uses ``os.makedirs(..., exist_ok=True)`` because this function is
	called concurrently from a multiprocessing pool: the original
	exists-then-makedirs sequence could raise FileExistsError when two
	workers raced to create the same category directory.
	"""
	if not os.path.isdir(path):
		os.makedirs(path, exist_ok=True)
		print("Directory created:", path)

def detect_seg(json_file, image_dir, out_dir, category_id_dict):
	"""Process one DeepFashion2 annotation file: for every clothing item it
	describes, save the detection crop and the segmentation crop into
	per-category output directories.

	Parameters
	----------
	json_file : str
	    Path to one annotation JSON file.
	image_dir : str
	    Directory holding the matching ``.jpg`` images.
	out_dir : str
	    Output root; ``detect/<category>`` and ``seg/<category>`` subtrees
	    are created beneath it as needed.
	category_id_dict : dict[int, str]
	    Maps ``category_id`` to the directory name for that category.
	"""
	json_filename = os.path.basename(json_file)

	# The image shares the annotation's base name.
	jpg_filename = os.path.splitext(json_filename)[0] + '.jpg'
	jpg_file = os.path.join(image_dir, jpg_filename)

	if not os.path.exists(jpg_file):
		print("JPG File not found:", jpg_file)
		return

	with open(json_file, 'r') as file:
		json_data = json.load(file)

	# Top-level keys look like 'item1', 'item2', ... plus metadata keys
	# such as 'source'/'pair_id'; only the item entries are processed.
	for item in json_data:
		if 'item' not in item:
			continue

		item_data = json_data[item]
		category_id = item_data['category_id']

		# Ensure both per-category output directories exist.
		out_detect_dir = os.path.join(out_dir, 'detect', category_id_dict[category_id])
		create_directory(out_detect_dir)
		out_seg_dir = os.path.join(out_dir, 'seg', category_id_dict[category_id])
		create_directory(out_seg_dir)

		# Output paths for the detection crop and the segmentation crop.
		bb_path = os.path.join(out_detect_dir, jpg_filename.replace('.jpg', f'_{item}.jpg'))
		seg_path = os.path.join(out_seg_dir, jpg_filename.replace('.jpg', f'_{item}.jpg'))

		# Skip items whose outputs already exist (lets the job resume).
		if os.path.exists(bb_path) and os.path.exists(seg_path):
			continue

		try:
			# Context manager closes the file handle deterministically;
			# the original leaked one open handle per processed item.
			with Image.open(jpg_file) as image:
				# Detection crop from the annotated bounding box.
				save_bounding_box(image, item_data['bounding_box'], bb_path)

				# Segmentation crop: mask the image, then crop to content.
				seg_region = save_segmentation(image, item_data['segmentation'])
				crop_object_region(seg_region, seg_path)
		# Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
		# still propagate and can stop the worker pool.
		except Exception:
			print(f'该图像存在问题:{jpg_file}')

if __name__ == "__main__":

	category_id_dict = {
		1: 'short sleeve top',
		2: 'long sleeve top',
		3: 'short sleeve outwear',
		4: 'long sleeve outwear',
		5: 'vest',
		6: 'sling',
		7: 'shorts',
		8: 'trousers',
		9: 'skirt',
		10: 'short sleeve dress',
		11: 'long sleeve dress',
		12: 'vest dress',
		13: 'sling dress'
	}

	# 将字典值中的空格替换为下划线
	category_id_dict = {key: value.replace(' ', '_') for key, value in category_id_dict.items()}

	print(category_id_dict)

	# 训练集
	anno_dir = 'train/train/annos'
	image_dir = 'train/train/image'
	out_dir = 'train_clsv2'

	# 验证集
	# anno_dir = 'validation/annos'
	# image_dir = 'validation/image'
	# out_dir = 'validation_cls'

	# 获取所有 JSON 文件路径
	# json_files = glob.glob(os.path.join(anno_dir, '*.json'))
	# 替换 golb 加快推理速度
	json_files = []
	for entry in os.scandir(anno_dir):
		if entry.is_file() and entry.name.endswith('.json'):
			json_files.append(entry.path)

	with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool: # 使用CPU核心数作为进程数
		total_files = len(json_files)
		process_func = partial(detect_seg, image_dir=image_dir, out_dir=out_dir, category_id_dict=category_id_dict)
		with tqdm(total=total_files, desc='Processing files') as pbar:
			for _ in pool.imap_unordered(process_func, json_files):
				pbar.update(1)