# 代码 (Code):
# -*- coding: utf-8 -*-
import numpy as np
def nchw_to_nc1hwc0(data, batch_size, num_channels, height, width, block_size):
    """Pack an NCHW array into the channel-blocked NC1HWC0 layout.

    Channel ``i`` of the input is stored at
    ``out[:, i // block_size, :, :, i % block_size]``. The channel axis is
    zero-padded up to the next multiple of ``block_size``, so it splits into
    ``C1 = ceil(num_channels / block_size)`` blocks of ``C0 = block_size``
    lanes. Unused lanes of the last block are zero.

    Args:
        data: Array-like of shape ``(batch_size, num_channels, height, width)``.
        batch_size: Expected size of axis 0 of ``data``.
        num_channels: Expected size of axis 1 of ``data``.
        height: Expected size of axis 2 of ``data``.
        width: Expected size of axis 3 of ``data``.
        block_size: Channels per block (C0); must be positive.

    Returns:
        A new C-contiguous array of shape
        ``(batch_size, C1, height, width, block_size)`` with ``data``'s dtype.

    Raises:
        ValueError: If ``block_size`` is not positive or ``data`` does not
            have the expected shape.
    """
    data = np.asarray(data)
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    expected = (batch_size, num_channels, height, width)
    if data.shape != expected:
        raise ValueError(f"expected NCHW shape {expected}, got {data.shape}")

    c0 = block_size
    c1 = (num_channels + c0 - 1) // c0  # ceil(num_channels / c0)

    # Zero-pad the channel axis to C1 * C0 so it divides evenly into blocks.
    # The padding lanes stay zero, as in the element-wise formulation.
    padded = np.zeros((batch_size, c1 * c0, height, width), dtype=data.dtype)
    padded[:, :num_channels] = data

    # (N, C1*C0, H, W) -> (N, C1, C0, H, W): channel i maps to [i // c0, i % c0].
    blocked = padded.reshape(batch_size, c1, c0, height, width)
    # (N, C1, C0, H, W) -> (N, C1, H, W, C0). `padded` is already a fresh
    # array, so the result never aliases the caller's input.
    return np.ascontiguousarray(blocked.transpose(0, 1, 3, 4, 2))
def nc1hwc0_to_nchw(data, batch_size, num_channels, height, width, block_size):
    """Unpack a channel-blocked NC1HWC0 array back into NCHW layout.

    This is the inverse of the NCHW -> NC1HWC0 packing. Channel ``i`` is read
    from ``data[:, i // block_size, :, :, i % block_size]``. Padding lanes
    beyond ``num_channels`` in the last block are ignored.

    Args:
        data: Array-like of shape
            ``(batch_size, C1, height, width, block_size)``, where
            ``C1 = ceil(num_channels / block_size)``.
        batch_size: Expected size of axis 0 of ``data``.
        num_channels: Number of real (unpadded) channels to recover.
        height: Expected size of axis 2 of ``data``.
        width: Expected size of axis 3 of ``data``.
        block_size: Channels per block (C0); must be positive.

    Returns:
        A new C-contiguous array of shape
        ``(batch_size, num_channels, height, width)`` with ``data``'s dtype.

    Raises:
        ValueError: If ``block_size`` is not positive or ``data`` does not
            have the expected shape.
    """
    data = np.asarray(data)
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    c0 = block_size
    c1 = (num_channels + c0 - 1) // c0  # ceil(num_channels / c0)
    expected = (batch_size, c1, height, width, c0)
    if data.shape != expected:
        raise ValueError(f"expected NC1HWC0 shape {expected}, got {data.shape}")

    # (N, C1, H, W, C0) -> (N, C1, C0, H, W) -> (N, C1*C0, H, W), so that
    # flat channel index g * c0 + l corresponds to block g, lane l.
    unblocked = data.transpose(0, 1, 4, 2, 3).reshape(
        batch_size, c1 * c0, height, width
    )
    # Drop the padding lanes. `.copy()` guarantees a fresh C-ordered array
    # even when reshape returned a view of `data`.
    return unblocked[:, :num_channels].copy()
def nchw_to_nc1hwc0_1(data, batch_size, num_channels, height, width, block_size):
    """Pack NCHW data into NC1HWC0 by copying one channel plane at a time.

    Slice-based alternative to the element-wise packer. Channel ``ch`` is
    written to block ``ch // block_size`` at lane ``ch % block_size`` for
    every batch entry at once. Lanes of the last block that have no source
    channel keep their initial zero value.

    Returns an array of shape
    ``(batch_size, ceil(num_channels / block_size), height, width, block_size)``
    with the same dtype as ``data``.
    """
    assert data.shape == (batch_size, num_channels, height, width)
    num_blocks = (num_channels + block_size - 1) // block_size
    packed = np.zeros(
        (batch_size, num_blocks, height, width, block_size), dtype=data.dtype
    )
    for ch in range(num_channels):
        block, lane = divmod(ch, block_size)
        # The slice covers the whole batch axis, so no per-batch loop is needed.
        packed[:, block, :, :, lane] = data[:, ch, :, :]
    return packed
def nc1hwc0_to_nchw_1(data, batch_size, num_channels, height, width, block_size):
    """Unpack NC1HWC0 data into NCHW by copying one channel plane at a time.

    Slice-based inverse of the NC1HWC0 packing. Channel ``ch`` is read from
    block ``ch // block_size``, lane ``ch % block_size``, for every batch
    entry at once. Padding lanes beyond ``num_channels`` are never read.

    Returns a new array of shape ``(batch_size, num_channels, height, width)``
    with the same dtype as ``data``.
    """
    assert data.shape == (
        batch_size,
        (num_channels + block_size - 1) // block_size,
        height,
        width,
        block_size,
    )
    unpacked = np.zeros((batch_size, num_channels, height, width), dtype=data.dtype)
    for ch in range(num_channels):
        block, lane = divmod(ch, block_size)
        unpacked[:, ch, :, :] = data[:, block, :, :, lane]
    return unpacked
# Demo / self-check: pack random NCHW data with both implementations,
# unpack each result with its matching inverse, and verify the round trip.
batch_size = 6
num_channels = 11  # any positive integer works
height = 7
width = 11
block_size = 16
data = np.random.rand(batch_size, num_channels, height, width)

nc1hwc0_data = nchw_to_nc1hwc0(data, batch_size, num_channels, height, width, block_size)
nc1hwc0_data_1 = nchw_to_nc1hwc0_1(data, batch_size, num_channels, height, width, block_size)
print(nc1hwc0_data.shape)  # (6, 1, 7, 11, 16) for the settings above
print(nc1hwc0_data_1.shape)
# Both packers must produce bit-identical output, including the zero padding.
assert np.array_equal(nc1hwc0_data, nc1hwc0_data_1)

# Exercise each unpacker on its own packer's output so that both code
# paths are actually tested. Previously the "_1" unpacker was never called.
nchw_data = nc1hwc0_to_nchw(nc1hwc0_data, batch_size, num_channels, height, width, block_size)
nchw_data_1 = nc1hwc0_to_nchw_1(nc1hwc0_data_1, batch_size, num_channels, height, width, block_size)
print(nchw_data.shape)  # (6, 11, 7, 11) for the settings above
print(nchw_data_1.shape)
# The layout round trip only moves elements around, so it must be exact,
# not merely close: the unpacked data must equal the original data.
assert np.array_equal(data, nchw_data)
assert np.array_equal(data, nchw_data_1)
assert np.array_equal(nchw_data, nchw_data_1)