# Published 2023-04-23 20:06:33, author: helloWorldhelloWorld
import torch.nn.functional as F

def pixel_unshuffle(input, downscale_factor):
    """Move each ``r x r`` spatial block into the channel dimension.

    This is the inverse of pixel shuffle. With ``r = downscale_factor``, an input
    of shape ``(*, C, H, W)`` becomes ``(*, C * r**2, H // r, W // r)``.
    Output channel ``c * r**2 + i * r + j`` holds ``input[..., c, i::r, j::r]``.
    This is the same layout as ``torch.nn.functional.pixel_unshuffle``.

    Args:
        input: Tensor of shape ``(*, C, H, W)`` with at least three dimensions.
        downscale_factor: Positive integer ``r`` by which H and W are reduced.

    Returns:
        Tensor of shape ``(*, C * r**2, H // r, W // r)``.

    Raises:
        ValueError: If ``input`` has fewer than three dimensions, if
            ``downscale_factor`` is not positive, or if H or W is not
            divisible by ``downscale_factor``.
    """
    if input.dim() < 3:
        raise ValueError(
            f"pixel_unshuffle expects input of shape (*, C, H, W), got {tuple(input.shape)}"
        )
    if downscale_factor < 1:
        raise ValueError(f"downscale_factor must be positive, got {downscale_factor}")

    r = downscale_factor
    *batch_dims, channels, height, width = input.shape
    # Reject non-divisible sizes; otherwise unfold would silently crop the
    # trailing rows/columns.
    if height % r or width % r:
        raise ValueError(
            f"height ({height}) and width ({width}) must be divisible by "
            f"downscale_factor ({r})"
        )
    out_h, out_w = height // r, width // r

    # F.unfold takes exactly one batch dim: add one for (C, H, W) input,
    # or merge all leading dims into one for higher-rank input.
    flat = input.unsqueeze(0) if input.dim() == 3 else input.flatten(0, -4)

    # unfold -> (N, C * r * r, out_h * out_w). Its channel axis is ordered
    # channel-major, then kernel row, then kernel column. That is already the
    # pixel-unshuffle channel order, so only the spatial axis is reshaped.
    # Any permute here would scramble the result.
    unfolded = F.unfold(flat, kernel_size=r, stride=r)
    return unfolded.reshape(*batch_dims, channels * r * r, out_h, out_w)