生产者消费者模式下实现多batch延时推理

发布时间 2023-10-30 18:04:11作者: wildkid1024

生产者消费者模式下实现多batch延时推理

需求分析

在实际推理过程中为了实现较高的吞吐量和较高的资源利用率,往往会使用多线程来收集多次请求,并组合形成多batch下的模型推理,一种常见的实现便是生产者和消费者模式,其需求如下:

  1. 生产者收集提交的请求,消费者对请求进行消费,并将结果返回。
  2. 资源是有限的,可运行的最大max_batch_size是有限的,队列中可等待的max_queue_num也是有限的。
  3. 邻近的资源可进行等待,最长要等待timeout的时间
  4. 资源封装完备,哪里使用就在哪里释放。

设计实现

  1. 在设计实现上,首先是需要有线程安全的生产者消费者队列,通过promise和future完成返回消息。
  2. 需要有设置各个参数的公开接口,以便于设置各个参数。
  3. 需要有一把锁和延时机制,使用条件变量condition_variable的wait_for函数来等待后续参数的传入。
  4. 资源的申请和释放都放置在消费者线程中,并使用promise和future来返回中间结果。
  5. 使用抽象类实现接口和封装,使用多态进行实现。

具体实现

  1. 头文件

这里声明了一个抽象类,关键要实现forward函数,create_infer外部函数实现多态调用。

#ifndef INFER_HPP
#define INFER_HPP
#include <memory>
#include <string>
#include <map>
#include <future>
namespace Infer {
using Tensor = std::string;
using RET_TYPE = std::map<std::string, Tensor>;
class InferInterface {
 public:
  // pic_tensor 表示输入给生产者的tensor, timeout 表示超时时间
  // wait_timeout 表示当队列满了的情况下,愿意最多等待 wait_timeout ms 来等待
  virtual std::shared_future<RET_TYPE> forward(const Tensor& pic_tensor, int wait_timeout = 10) = 0;
  virtual void set_timeout(const int timeout) = 0;
  virtual void set_max_queue_num(const int max_queue_num) = 0;
  virtual ~InferInterface() {}
  explicit InferInterface() {}
 protected:
  InferInterface(const InferInterface&) = delete;
  InferInterface(InferInterface&&) = delete;
  InferInterface& operator=(const InferInterface&) = delete;
};
std::shared_ptr<InferInterface> create_infer(const std::string& filepath, int max_bactch_size);
}
#endif // INFER_HPP
  1. cpp实现类

InferImpl是InferInterface的具体实现,是一种推理引擎的多态封装。

这里有4个原子操作,分别对应着running、超时时间timeout、最大可推理的batch_size以及队列的最大长度max_queue_num。
有一个锁,用以等待后续的请求,并合为一个batch。
有两个条件变量,cond_var_对应的是消费者,用以wait_for不为空且系统在runnning状态,并且queue的请求数量大于batch_size时取出batch_size请求。cond_queue_overflow_对应的是生产者,需要等待当前队列中的数据不满时才能进行填充。

forward函数里的实现生产者,当队列不满的时候向队列中插入请求。
worker函数是消费者的具体实现,会等待当前队列中请求数大于要推理的请求时才会组成batch一同推理,并将结果通过promise返回得到最终结果。

#include "infer.hpp"
#include <cstdio>
#include <queue>

namespace Infer
{
    using std::atomic;
    using std::future;
    using std::map;
    using std::mutex;
    using std::promise;
    using std::queue;
    using std::string;
    using std::thread;
    class InferImpl : public InferInterface
    {
    protected:
        /* data */
        using RET_PROMISE_PTR = std::shared_ptr<promise<RET_TYPE>>;
        struct Job
        {
            /* data */
            Tensor data;
            RET_PROMISE_PTR ret_promise;
        };

        atomic<bool> running;
        atomic<int> timeout;
        atomic<int> max_batch_size;
        atomic<int> max_queue_num;

        mutex lock;
        thread thread_;
        queue<Job> queue_;
        string context;

        std::condition_variable cond_var_;
        std::condition_variable cond_queue_overflow_;

    public:
        void set_timeout(const int timeout)
        {
            this->timeout = timeout;
        }
        void set_max_queue_num(const int max_queue_num)
        {
            this->max_queue_num = max_queue_num;
        }
        bool set_batch_size(const int batch_size)
        {
            if (batch_size < 1)
                return false;
            this->max_batch_size = batch_size;
            return true;
        }

        std::shared_future<RET_TYPE> forward(const Tensor &pic_tensor, int wait_timeout = 10)
        {
            Job job;
            job.data = pic_tensor;
            job.ret_promise = RET_PROMISE_PTR(new promise<RET_TYPE>());
            {
                std::unique_lock<mutex> l(lock);
                if (queue_.size() >= max_queue_num)
                {
                    if (0 == wait_timeout)
                    {
                        throw std::runtime_error("exhausted resource");
                    }
                    cond_queue_overflow_.wait_for(l, std::chrono::milliseconds(wait_timeout), [&]()
                                                 { return queue.size() < max_queue_num; });
                }

                if (queue_.size() >= max_queue_num)
                {
                    throw std::runtime_error("exhausted resource");
                }

                queue_.push(job);
            }
            cond_var_.notify_one();
            return job.ret_promise->get_future();
        }
        explicit InferImpl()
        {
            running = false;
            max_batch_size = 1;
            max_queue_num = 5;
            timeout = 0;
        }
        ~InferImpl() override {
            running = false;
            cond_var_.notify_one();
            if (thread_.joinable())
            {
                thread_.join();
            }
        }

        bool load_model(const string &filepath)
        {
            promise<bool> init_promise;
            thread_ = thread(&InferImpl::worker, this, filepath, std::ref(init_promise));
            running = true;
            return init_promise.get_future().get();
        }

    protected:
        void worker(const string &filepath, promise<bool> &init_promise)
        {
            // 加载模型
            context = filepath;
            if (context.empty())
            {
                init_promise.set_value(false);
                return;
            }

            init_promise.set_value(true);
            std::vector<Job> jobs;
            int batch_id = 0;

            while (running)
            {
                {
                    std::unique_lock<mutex> l(lock);
                    cond_var_.wait(l, [&](){
                        // true 则退出等待
                        return !queue_.empty() || !running; });
                    if (!running)
                        break;
                    if (0 != timeout)
                    {
                        cond_var_.wait(l, [&]()
                                       { return queue_.size() >= max_batch_size; });
                    }
                    for (int i = 0; !queue_.empty() && i < max_batch_size; i++)
                    {
                        jobs.emplace_back(queue_.front());
                        queue_.pop();
                    }
                }

                // 此处假装inference,得到运行结果

                int sz = jobs.size();
                for (auto& job:jobs){
                    auto bbox = job.data + "_result";
                    RET_TYPE handle_result;
                    handle_result["bbox"] = bbox;
                    job.ret_promise->set_value(handle_result);
                }

                jobs.clear();
                std::this_thread::sleep_for(std::chrono::milliseconds(1000 * sz));
                printf("batch id: %d job size: %d \n", batch_id, sz);
                cond_queue_overflow_.notify_one();
                ++batch_id;
            }

            context.clear();
            puts("context_ has cleared");
            puts("Workder done!");      
        }
    };

    std::shared_ptr<InferInterface> create_infer(const string& file_path, int max_batch_size){
        auto infer_ptr = new InferImpl();
        infer_ptr->set_batch_size(max_batch_size);

        if (! infer_ptr->load_model(file_path)){
            delete infer_ptr;
            return nullptr;
        }
        return std::shared_ptr<InferInterface>(infer_ptr);
    }
}
  1. main文件

在main文件中,创建了24个请求,每5个请求将组成一个batch进行推理。
这里有个小问题是,当最后一个请求小于5个时会一直等待,因此应当设置超时机制对剩余的请求进行一并处理。

#include <iostream>
#include <vector>
#include "infer.hpp"

using std::string;
using std::errc;

using RET_FUTURE = std::shared_future<Infer::RET_TYPE>;

int main(int argc, char const *argv[])
{
    /* code */
    string file_path = ""; 
    std::shared_ptr<Infer::InferInterface>infer_ptr = Infer::create_infer(file_path, 5);
    if (infer_ptr == nullptr){
        printf("create infer engine error\n");
        return -1;
    }
    // 每个实例推理愿意等待的时间
    infer_ptr->set_timeout(5);
    infer_ptr->set_max_queue_num(10);
    printf("create infer engine success!\n");

    std::vector<RET_FUTURE> shared_ptrs;
    char buffer[100];
    for (int i=0; i<24; ++i){
        sprintf(buffer, "%d.tensor", i);

        // 在队列满的时候,生产者愿意等1000ms
        shared_ptrs.push_back(infer_ptr->forward(buffer, 1000));
    }

    for (auto &shard_ptr:shared_ptrs){
        shard_ptr.get();
    }

    return 0;
}