【代码分享】使用 avx512 + 查表法,优化凯撒加密

发布时间 2023-10-17 15:22:33作者: ahfuzhang

作者:张富春(ahfuzhang),转载时请注明作者和引用链接,谢谢!


关于凯撒加密,具体请看:https://en.wikipedia.org/wiki/Caesar_cipher
总而言之就是玩点没什么用的小心眼,把字母的顺序变化一下。

第一版:根据业务逻辑直接实现:

void caesarEncodeV0(uint8_t* out, uint8_t* in, int len, int rot){
    rot = rot % 26;
    uint8_t* end = in + len; 
    uint32_t* line = table.Table[rot];
    for (;in<end; in++, out++){
        if (*in>='a' && *in<='z'){
            *out = (*in - 'a' + rot)%26 + 'a';
        } else if (*in>='A' && *in<='Z'){
            *out = (*in - 'A' + rot)%26 + 'A';
        } else {
            *out = *in;
        }
    }
}

void testCaesar(){
    const char* s = "QAULi2jah2eqSD1zQAULhuG0Qs9mhOF9TDGtFAGtFqB=";
    int len = strlen(s);
    uint8_t* out = malloc(len+1);
    int rot = 4;
    caesarEncodeV0(out, s, len, rot);
    out[len] = '\0';
    printf("in :%s\n", s);
    printf("out:%s | caesarEncodeV0\n", out);
}

第二版:使用查表法

很明显,字符间的替换,可以预先放在一个数组里,然后查表就行了。

typedef struct{
    uint32_t Table[26][256];  // uint32_t 要比 uint8_t 更好,猜测是因为字节对齐的原因
} __attribute__((packed)) CaesarTable;

// 预先计算替换规则后的结果
void initTable(uint32_t *table[26][256]){
    for (int i=0; i<26; i++){
        for (int j=0; j<256; j++){
            if (j>='a' && j<='z'){
                table[i][j] = (uint8_t)((j-'a'+i)%26 + 'a');
            } else if (j>='A' && j<='Z'){
                table[i][j] = (uint8_t)((j-'A'+i)%26 + 'A');
            } else {
                table[i][j] = j;
            }
        }
    }
}

void caesarEncodeV1(uint8_t* out, uint8_t* in, int len, int rot, uint32_t *table[26][256]){
    rot = rot % 26;
    uint8_t* end = in + len; 
    uint32_t* line = table[rot];
    for (;in<end; in++, out++){
        *out = (uint8_t)line[*in];  // 直接查表得到结果
    }
}

CaesarTable table;

void testCaesar(){
    initTable(&table.Table);
    //
    const char* s = "QAULi2jah2eqSD1zQAULhuG0Qs9mhOF9TDGtFAGtFqB=";
    int len = strlen(s);
    uint8_t* out = malloc(len+1);
    int rot = 4;
    caesarEncodeV1(out, s, len, rot, &table.Table);
    out[len] = '\0';
    printf("in :%s\n", s);
    printf("out:%s | caesarEncodeV1\n", out);
}

第三版:使用 simd,一个周期内查表多次

查看文档发现,只有 avx512 指令集才能很好的支持查表操作。

void caesarEncodeSIMDV2(uint8_t* out, uint8_t* in, int len, int rot, uint32_t *table[26][256]){
    rot = rot % 26;
    uint32_t* line = table[rot];
    const batchSize = 16;
    uint8_t* end = in + len - (len&0x0f);   // 每个批次处理 16 字节,不够 16 字节的尾部要单独处理
    uint8_t* start = in;
    for (; start<end; start += batchSize, out += batchSize){
        _mm_storeu_epi8(                           // step5: 把 16 个 int8 存储到目的地址
            out, _mm512_cvtepi32_epi8(        // step4: 把  16  个 int32 的查表结果,转换成  16 个  int8
                _mm512_i32gather_epi32(       // step3: 把 16 个 int32 当成偏移量,在 table 开始的地址里面查询. 最后一个参数 4,表示查表中每个元素的偏移量是 4 字节
                    _mm512_cvtepu8_epi32(     // step2: 把 16 个 int8 转换成 16 个  int32 
                        _mm_loadu_si128(start)  // step1: 以非对齐的方式,从源地址加载 16 字节
                    ), line, 4))
        );
    }
    end = in + len;
    for (; start<end; start++, out++){
        *out = (uint8_t)line[*start];
    }
}

编译命令行为:

gcc -o caesar caesar.c -g -w -mavx -mavx2 -mavx512f -mavx512vl -mavx512bw -O2

最后测试的结果为:

  • 逐个字符查表:67.041 ns/op
  • avx512 查表:36.371 ns/op