LLM采样后处理总结:LLM的后处理的cpp实现

发布时间 2023-10-11 18:15:19作者: wildkid1024

LLM采样后处理总结:LLM的后处理的cpp实现

在经过LLM的lm_head之后,会得到[batch, vocab_size]大小的矩阵向量,此时需要对输出的逻辑张量进行采样,除了beam_search的贪心策略,还有repetition_penalty、temperature、top_k、top_p等几种控制采样的方法。

repetition_penalty

repetition_penalty的主要作用是控制重复,这里first和last分别为vocab中的第一个元素和最后一个元素的位置,input_ids为之前输出的文本id。
也即是把之前输出过的内容全部变小,那么就可以防止文本出现不断重复的情况,penalty越小,惩罚力度越大,penalty越大,惩罚力度越小,重复概率就会增加。

void sampling_repetition_penalty(float *first, float *last, const std::vector<int> &input_ids,
                                                       float penalty) {
    std::unordered_set<int> unique_input_ids(input_ids.begin(), input_ids.end());
    for (int id : unique_input_ids) {
        if (first[id] > 0) {
            first[id] /= penalty;
        } else {
            first[id] *= penalty;
        }
    }
}

temperature

temperature是控制softmax下的平滑参数,相当于在softmax前每个逻辑值都进行了放缩。
当temp越大的时候,此时softmax值之间的差距会减小,分布就越均匀,此时采样出的结果就越随机,反之就会使得原本高概率的的变得更高低的更低减少了随机性。

void sampling_temperature(float *first, float *last, float temp) {
    float inv_temp = 1.f / temp;
    for (float *it = first; it != last; it++) {
        *it *= inv_temp;
    }
}

top_k

top_k是取前k个,直接排序拿到概率最大的前k个。

void sampling_top_k(TokenIdScore *first, TokenIdScore *kth, TokenIdScore *last) {
    std::nth_element(first, kth, last, std::greater<TokenIdScore>());
}

top_p

top_p是先对所有的值进行softmax,然后找到满足sum_p <= top_p的最小集合,然后对这个集合内的数再进行softmax和采样。
一种简单的做法是将所有值进行排序,然后贪心找到满足条件的前k个。
示例代码中使用了一种类似于快速排序的方法,每次找mid点,将大于mid和小于mid的分为两堆,要么在大的一堆要么在小的一堆。
当在大的一堆中时就mid往前移动,在小的一堆时则更新top_p = top_p-sum_p,直至找到对应的位置。
时间复杂度上会稍微比先排序快一些。

void sampling_softmax_inplace(TokenIdScore *first, TokenIdScore *last) {
    float max_score = std::max_element(first, last)->score;
    float sum = 0.f;
    for (TokenIdScore *p = first; p != last; p++) {
        float s = std::exp(p->score - max_score);
        p->score = s;
        sum += s;
    }
    float inv_sum = 1.f / sum;
    for (TokenIdScore *p = first; p != last; p++) {
        p->score *= inv_sum;
    }
}
TokenIdScore *sampling_top_p(TokenIdScore *first, TokenIdScore *last, float top_p) {
    // fast top_p in expected O(n) time complexity
    sampling_softmax_inplace(first, last);

    while (first + 1 < last) {
        float pivot_score = (last - 1)->score; // use mid score?
        TokenIdScore *mid =
            std::partition(first, last - 1, [pivot_score](const TokenIdScore &x) { return x.score > pivot_score; });
        std::swap(*mid, *(last - 1));

        float prefix_sum =
            std::accumulate(first, mid, 0.f, [](float sum, const TokenIdScore &x) { return sum + x.score; });
        if (prefix_sum >= top_p) {
            last = mid;
        } else if (prefix_sum + mid->score < top_p) {
            first = mid + 1;
            top_p -= prefix_sum + mid->score;
        } else {
            return mid + 1;
        }
    }
    return last;
}