NCHW 与 NC1HWC0 数据格式互转

发布时间: 2023-03-31 14:55:55  作者: 开发者的灵感

代码:

# -*- coding: utf-8 -*-
import numpy as np

def nchw_to_nc1hwc0(data, batch_size, num_channels, height, width, block_size):
    """Convert an NCHW array to the blocked NC1HWC0 layout.

    Channel ``c`` of the input is stored at
    ``[n, c // block_size, h, w, c % block_size]`` in the output. When
    ``num_channels`` is not a multiple of ``block_size``, the unused lanes of
    the last block are left as zeros.

    Args:
        data: NumPy array of shape
            ``(batch_size, num_channels, height, width)``.
        batch_size: Size of the N axis.
        num_channels: Size of the C axis.
        height: Size of the H axis.
        width: Size of the W axis.
        block_size: Number of channels stored per block (C0).

    Returns:
        A newly allocated array of shape
        ``(batch_size, c1, height, width, block_size)`` where
        ``c1 = (num_channels + block_size - 1) // block_size``, with the same
        dtype as ``data``.

    Raises:
        AssertionError: If ``data.shape`` does not match the given dimensions.
    """
    assert data.shape == (batch_size, num_channels, height, width)

    c0 = block_size
    c1 = (num_channels + c0 - 1) // c0

    nc1hwc0_data = np.zeros((batch_size, c1, height, width, c0), dtype=data.dtype)

    # Channels that fill whole blocks: split C into (block, lane) and move the
    # lane axis to the end, all in one vectorized assignment.
    full_blocks = num_channels // c0
    split = full_blocks * c0
    nc1hwc0_data[:, :full_blocks] = (
        data[:, :split]
        .reshape(batch_size, full_blocks, c0, height, width)
        .transpose(0, 1, 3, 4, 2)
    )

    # Leftover channels occupy the first lanes of the final, partially filled
    # block; the remaining lanes keep their zero padding.
    tail = num_channels - split
    if tail:
        nc1hwc0_data[:, full_blocks, :, :, :tail] = data[:, split:].transpose(0, 2, 3, 1)

    return nc1hwc0_data

def nc1hwc0_to_nchw(data, batch_size, num_channels, height, width, block_size):
    """Convert a blocked NC1HWC0 array back to the NCHW layout.

    Output channel ``c`` is read from
    ``data[n, c // block_size, h, w, c % block_size]``. Lanes of the last
    block beyond ``num_channels`` are ignored.

    Args:
        data: NumPy array of shape
            ``(batch_size, c1, height, width, block_size)`` where
            ``c1 = (num_channels + block_size - 1) // block_size``.
        batch_size: Size of the N axis.
        num_channels: Number of channels (C) in the NCHW result.
        height: Size of the H axis.
        width: Size of the W axis.
        block_size: Number of channels stored per block (C0).

    Returns:
        A newly allocated array of shape
        ``(batch_size, num_channels, height, width)`` with the same dtype as
        ``data``.

    Raises:
        AssertionError: If ``data.shape`` does not match the given dimensions.
    """
    c0 = block_size
    c1 = (num_channels + c0 - 1) // c0
    assert data.shape == (batch_size, c1, height, width, c0)

    # Move the lane axis next to the block axis, then merge the two into a
    # single channel axis of length c1 * c0.
    merged = data.transpose(0, 1, 4, 2, 3).reshape(batch_size, c1 * c0, height, width)

    # Copy into a fresh buffer so the result never shares memory with `data`,
    # dropping the padded lanes of the last block.
    nchw_data = np.empty((batch_size, num_channels, height, width), dtype=data.dtype)
    nchw_data[...] = merged[:, :num_channels]

    return nchw_data

def nchw_to_nc1hwc0_1(data, batch_size, num_channels, height, width, block_size):
    """Pack an NCHW array into NC1HWC0 layout, one channel plane at a time.

    Each input channel ``c`` is written to block ``c // block_size`` and lane
    ``c % block_size`` of the output. Lanes that receive no channel stay zero.

    Args:
        data: NumPy array of shape
            ``(batch_size, num_channels, height, width)``.
        batch_size: Size of the N axis.
        num_channels: Size of the C axis.
        height: Size of the H axis.
        width: Size of the W axis.
        block_size: Number of channels stored per block (C0).

    Returns:
        A new array of shape ``(batch_size, c1, height, width, block_size)``
        with ``c1 = (num_channels + block_size - 1) // block_size`` and the
        dtype of ``data``.

    Raises:
        AssertionError: If ``data.shape`` does not match the given dimensions.
    """
    expected_shape = (batch_size, num_channels, height, width)
    assert data.shape == expected_shape

    num_blocks = (num_channels + block_size - 1) // block_size
    packed = np.zeros(
        (batch_size, num_blocks, height, width, block_size),
        dtype=data.dtype,
    )

    for channel in range(num_channels):
        block, lane = divmod(channel, block_size)
        # Copy the (H, W) plane of this channel for every batch entry at once.
        packed[:, block, :, :, lane] = data[:, channel, :, :]

    return packed

def nc1hwc0_to_nchw_1(data, batch_size, num_channels, height, width, block_size):
    """Unpack an NC1HWC0 array into NCHW layout, one channel plane at a time.

    Output channel ``c`` is read from block ``c // block_size`` and lane
    ``c % block_size`` of ``data``. Lanes of the last block beyond
    ``num_channels`` are ignored.

    Args:
        data: NumPy array of shape
            ``(batch_size, c1, height, width, block_size)`` where
            ``c1 = (num_channels + block_size - 1) // block_size``.
        batch_size: Size of the N axis.
        num_channels: Number of channels (C) in the NCHW result.
        height: Size of the H axis.
        width: Size of the W axis.
        block_size: Number of channels stored per block (C0).

    Returns:
        A new array of shape ``(batch_size, num_channels, height, width)``
        with the same dtype as ``data``.

    Raises:
        AssertionError: If ``data.shape`` does not match the given dimensions.
    """
    c0 = block_size
    assert data.shape == (batch_size, (num_channels + c0 - 1) // c0, height, width, c0)

    nchw_data = np.zeros((batch_size, num_channels, height, width), dtype=data.dtype)
    for i in range(num_channels):
        c1_idx, c0_idx = divmod(i, c0)
        # One slice assignment moves this channel's plane for all batches.
        nchw_data[:, i, :, :] = data[:, c1_idx, :, :, c0_idx]

    return nchw_data

# Demo / self-check: round-trip random data through both implementations.
batch_size = 6
num_channels = 11  # may be set to any positive integer
height = 7
width = 11
block_size = 16

data = np.random.rand(batch_size, num_channels, height, width)

# NCHW -> NC1HWC0 with the element-wise and the slice-based implementation.
nc1hwc0_data = nchw_to_nc1hwc0(data, batch_size, num_channels, height, width, block_size)
nc1hwc0_data_1 = nchw_to_nc1hwc0_1(data, batch_size, num_channels, height, width, block_size)
print(nc1hwc0_data.shape)  # (6, 1, 7, 11, 16) for the values above
print(nc1hwc0_data_1.shape)  # (6, 1, 7, 11, 16)

# Both forward implementations must produce the same packed tensor.
assert np.allclose(nc1hwc0_data, nc1hwc0_data_1)

# NC1HWC0 -> NCHW; each inverse is fed the output of its matching forward
# implementation so that both code paths are actually exercised.
nchw_data = nc1hwc0_to_nchw(nc1hwc0_data, batch_size, num_channels, height, width, block_size)
nchw_data_1 = nc1hwc0_to_nchw_1(nc1hwc0_data_1, batch_size, num_channels, height, width, block_size)
print(nchw_data.shape)  # (6, 11, 7, 11)
print(nchw_data_1.shape)  # (6, 11, 7, 11)

assert np.allclose(data, nchw_data)  # the round trip must reproduce the original data
assert np.allclose(data, nchw_data_1)
assert np.allclose(nchw_data, nchw_data_1)