[fastllm]多线程下动态组batch实现解析

发布时间 2023-08-26 15:23:38作者: wildkid1024

[fastllm]多线程下动态组batch实现解析

需求分析

新版本的fastllm中添加了ForwardBatch的功能,用于处理批量推理请求,单次推理请求会被视为batch为1的批量请求,这样做似乎没什么问题。
然而在具体实践中,用户的请求往往是一个一个来的,每来一个请求都要等上一个请求完成之后才能开始推理下一个请求,一旦并发数起来,体验将及其糟糕。

所幸,stream流式输出能够在一定程度上缓解这个问题,web前端调用是一个异步线程队列,那么多个用户间的web 前端IO的时间差恰好可以给推理留出一定时间,虽然后台依然是一个token一个token地进行推理,但前端却看起来能够多用户并发使用。

但不幸的是,这种方法只在一定程度上解决问题,当用户量变多时,后台由于大量堆积的待推理的tokens,用户的体验又将变得十分糟糕。

在fastllm中,使用的是动态组batch的方案,即当A请求正在运行的token和B请求正在运行的token进行组合,合并为一个batch,在模型中并行推理,以提高模型实际运行时的吞吐。

具体实现

主要实现为两个函数,LaunchResponseTokens和FetchResponseTokens,其中LaunchResponseTokens主要根据当前输入增加一个监听token推理结果的线程,并返回context的handle;而FetchResponseTokens则根据给定的handle去队列中fetch对应的token结果。

LaunchResponseTokens的实现

LaunchResponseTokens函数可以拆成两部分去看,第一部分是mainLoopLocker.lock();到mainLoopLocker.unlock();这部分主要是创建并维护主线程。第二部分则是从dictLocker.lock();到dictLocker.unlock();这部分则是创建handle并向responseContextDict中添加初始化参数。

第二部分比较简单,可以忽略,要重点看看第一部分。第一部分是一个循环,可细分为前处理期、运行期和后处理期3部分,以model->dictLocker.lock();进行区分,是因为修改的是全局的dict。其中前处理期主要根据model config得到attentionMasks和positionIds,这两者都是以 std::vector <Data*> positionIds的类型存储的,不同的handle存储不同的参数即可,比较有意思的是inputIds,它是一个[1, all inputs len]的向量,所以需要一个seqLens来记录每个线程对应的inputs的长度。在运行期则将合成的batch输出送入到模型当中去,以并行的方式运行,不过在笔者的这版源码中,除了涉及inputs_ids部分是并行处理外,其他都是将batch进行展开计算的,也即在Attention之前的layerNorm、QKV Linear以及Attention后的FFN是多batch并行计算的,关键的Attention部分由于涉及到attentionMasks和positionIds还是需要拆batch来进行计算。不过在最新的代码中,作者已经将所有线程都处在token len为1时的状况进行了优化,这在多个长文本回复将有比较明显的加速。最后是后处理部分,这部分将各个线程对应的token取出,放入到消费者队列中,等待FetchResponseTokens的fetch。

PS: 这里mainLoop线程在启动前加了双重判断,理应来说mainLoopLocker.lock();应该是放在第一重判断和第二重判断之间的,如果在外面加lock的话,一重判断就应该是ok的。不知道是我理解错了,还是作者手滑,有待验证。

 int ChatGLMModel::LaunchResponseTokens(const std::vector<int> &inputTokens,
                                           const GenerationConfig &generationConfig) {
        mainLoopLocker.lock();
        if (mainLoop == nullptr) {
            if (mainLoop == nullptr) {
                mainLoop = new std::thread([](ChatGLMModel *model) {
                    while (true) {
                        std::vector <Data*> attentionMasks;
                        std::vector <Data*> positionIds;
                        std::vector <std::pair <Data*, Data*> > pastKeyValues;
                        std::vector <float> ids;
                        std::vector <int> seqLens;
                        std::vector <int> handles;
                        std::vector <GenerationConfig> generationConfigs;
                        LastTokensManager tokensManager;
                        model->dictLocker.lock();
                        for (auto &it: model->responseContextDict.dicts) {
                            if (it.second->isEnding) {
                                continue;
                            }
                            generationConfigs.push_back(it.second->generationConfig);
                            tokensManager.units.push_back(it.second->tokens);
                            handles.push_back(it.first);
                            for (int i = 0; i < it.second->currentTokens.size(); i++) {
                                ids.push_back(it.second->currentTokens[i]);
                            }
                            if (it.second->preTokens == 0) {
                                int seqLen = it.second->currentTokens.size();
                                if (model->GetVersion() == 1) {
                                    int gmask_token_id =
                                            model->weight.dicts.find("gmask_token_id") != model->weight.dicts.end() ?
                                            atoi(model->weight.dicts["gmask_token_id"].c_str()) : 130001;
                                    if (it.second->currentTokens.size() < 2 ||
                                        it.second->currentTokens.back() != model->bos_token_id) {
                                        ids.push_back(gmask_token_id);
                                        ids.push_back(model->bos_token_id);
                                        seqLen += 2;
                                    }
                                } else {
                                    if (it.second->currentTokens.size() < 2 ||
                                        it.second->currentTokens[0] != 64790) {
                                        ids.insert(ids.begin() + (ids.size() - it.second->currentTokens.size()), 64790);
                                        ids.insert(ids.begin() + (ids.size() - it.second->currentTokens.size()), 64792);
                                        seqLen += 2;
                                    }
                                }

                                seqLens.push_back(seqLen);
                                std::vector<float> vmask = std::vector<float>(seqLen * seqLen, 0);
                                std::vector<float> vpids = std::vector<float>(seqLen * 2, 0);
                                for (int i = 0; i < seqLen - 1; i++) {
                                    vmask[i * seqLen + seqLen - 1] = 1;
                                    vpids[i] = i;
                                }
                                vpids[seqLen - 1] = seqLen - 2;
                                vpids[seqLen * 2 - 1] = 1;

                                if (model->GetVersion() == 2) {
                                    for (int i = 0; i < seqLen; i++) {
                                        vpids[i] = i;
                                        for (int j = i + 1; j < seqLen; j++) {
                                            vmask[i * seqLen + j] = 1;
                                        }
                                    }
                                }

                                it.second->intParams["maskIds"] = seqLen - (model->GetVersion() == 1 ?  2 : 0);
                                it.second->intParams["len"] = 1;

                                attentionMasks.push_back(new Data(DataType::FLOAT32, {seqLen, seqLen}, vmask));
                                positionIds.push_back(new Data(DataType::FLOAT32, {2, seqLen}, vpids));
                            } else {
                                seqLens.push_back(1);
                                it.second->intParams["len"]++;
                                attentionMasks.push_back(nullptr);
                                positionIds.push_back(new Data(DataType::FLOAT32, {2, 1}, {(float)it.second->intParams["maskIds"], (float)(it.second->intParams["len"])}));
                                if (model->GetVersion() == 2) {
                                    it.second->intParams["maskIds"]++;
                                }
                            }

                            it.second->preTokens += seqLens.back();
                            for (int i = 0; i < model->block_cnt; i++) {
                                pastKeyValues.push_back(std::make_pair(&it.second->pastKeyValues[i].first,
                                                                       &it.second->pastKeyValues[i].second));
                            }
                        }

                        if (seqLens.size() > 0) {
                            model->dictLocker.unlock();
#ifdef USE_CUDA
                            FastllmCudaClearBigBuffer();
#endif
                            Data inputIds = Data(DataType::FLOAT32, {1, (int) ids.size()}, ids);
                            std::vector<int> ret = model->ForwardBatch(seqLens.size(), inputIds, attentionMasks,
                                                                       positionIds, seqLens, pastKeyValues, generationConfigs, tokensManager);
                            model->dictLocker.lock();
                            for (int i = 0; i < handles.size(); i++) {
                                auto &it = *model->responseContextDict.dicts.find(handles[i]);
                                int curRet = ret[i];
                                if (curRet == model->eos_token_id) {
                                    it.second->isEnding = true;
                                } else {
                                    it.second->currentTokens = std::vector<int>{curRet};
                                    it.second->resultTokenQueue.push(curRet);
                                    it.second->tokens.Push(curRet);
                                    it.second->curTokens++;
                                    if (it.second->curTokens == it.second->generationConfig.output_token_limit) {
                                        it.second->isEnding = true;
                                    }
                                }
                            }
                        }

                        for (int i = 0; i < attentionMasks.size(); i++) {
                            delete attentionMasks[i];
                        }
                        for (int i = 0; i < positionIds.size(); i++) {
                            delete positionIds[i];
                        }

                        model->dictLocker.unlock();
                        MySleep(0);
                    }
                }, this);
            }
        }
        mainLoopLocker.unlock();

        dictLocker.lock();
        int handleId = responseContextDict.CreateHandle();
        ResponseContext *context = responseContextDict.GetHandle(handleId);
        context->Init(this->block_cnt);
        context->currentTokens = inputTokens;
        context->generationConfig = generationConfig;
        context->tokens = LastTokensUnit(generationConfig.last_n);
        dictLocker.unlock();
        return handleId;
    }

FetchResponseTokens函数的实现

这部分功能就是消费者,从消费者队列中取之前生成的token即可。

实现逻辑上比较简单,从responseContextDict根据handle,找到对应的context,然后循环不断地fetch他token直到ending即可。这里有个有意思的问题,while (true)是在不断轮询队列中的token,实际上是一种简单但不太高效的写法,生产者消费者问题在系统中是一个很经典的问题。

 int ChatGLMModel::FetchResponseTokens(int handleId) {
        dictLocker.lock();
        ResponseContext *context = responseContextDict.GetHandle(handleId);
        if (context == nullptr) {
            dictLocker.unlock();
            return -1;
        } else {
            while (true) {
                if (context->resultTokenQueue.size() > 0) {
                    int ret = context->resultTokenQueue.front();
                    context->resultTokenQueue.pop();
                    dictLocker.unlock();
                    return ret;
                } else {
                    if (context->isEnding) {
                        responseContextDict.RemoveHandle(handleId);
                        dictLocker.unlock();
                        return -1;
                    }
                }
                dictLocker.unlock();
                MySleep(0);
                dictLocker.lock();
            }
        }
    }

总结与讨论

通过构建context封装的的方式来对token进行管理,通过context字典来记录不同线程的的tokens,主线程中则对多个线程下的token和输入配置进行拼接,batch并行推理后并将结果写入到各个context中,前台则通过不同handle取对应的token,这种设计可以极大提高系统的吞吐,增强用户体验。

不过仍然有一些可讨论的点,比如forwardbatch中参数可以改为纯Data*类型的数据,不过这样的话就需要1. 区分第一次batch和后续batch,在实现上第一次运行不组batch 或者2. 进行padding,但这不是一个太好的思路。另外就是自定义手写的函数可以这么玩,但如果是onnx、Trt类似的静态图,做这样的实现可能会有一些困扰。

参考

  1. Orca: A Distributed Serving System for Transformer-Based Generative Models[OSDI22][SNU]