import torch.nn.functional as F


def pixel_unshuffle(input, downscale_factor):
    """Rearrange non-overlapping spatial blocks of a 4-D tensor into channels.

    Intended to match ``torch.nn.functional.pixel_unshuffle``. An input of
    shape ``(B, C, H, W)`` becomes ``(B, C * r**2, H // r, W // r)`` with
    ``r = downscale_factor``, where
    ``out[b, c * r**2 + i * r + j, h, w] == input[b, c, h * r + i, w * r + j]``.

    Args:
        input: Tensor of shape ``(batch, channels, height, width)``.
        downscale_factor: Block size ``r`` (passed to ``F.unfold`` as both
            ``kernel_size`` and ``stride``).

    Returns:
        Tensor of shape ``(batch, channels * r**2, height // r, width // r)``.
        If ``height`` or ``width`` is not divisible by ``r``, the trailing
        rows/columns that do not fill a complete block are dropped. This is
        how ``F.unfold`` behaves when ``stride == kernel_size``.
    """
    r = downscale_factor
    batch_size, channels, height, width = input.size()
    out_h, out_w = height // r, width // r

    # unfold -> (B, C * r * r, out_h * out_w).
    # The second axis is ordered (channel, row-in-block, col-in-block).
    # The last axis is the row-major block index (h * out_w + w).
    # That is already the pixel-unshuffle layout, so a plain reshape is
    # enough. Permuting axes here would interleave unrelated pixels.
    unfolded = F.unfold(input, kernel_size=r, stride=r)
    return unfolded.reshape(batch_size, channels * r * r, out_h, out_w)