Faster sorting algorithms discovered using deep reinforcement learning

发布时间 2023-06-26 09:14:13作者: 叫我小辰就好了

摘要:

  • AlphaDev模型优化排序算法,将排序算法提速70%。通过强化学习,AlphaDev发现了更加有效的算法,直接超越了科学家和工程师们几十年来的精心打磨。现在,新的算法已经成为两个标准C++编码库的一部分,每天都会被全球的程序员使用数万亿次。

介绍

  • 优化目标为排序算法的CPU延迟时间

  • 两类排序问题:

    • 固定长度排序:sort 3排序模型只能处理序列长度为3的排序问题
    • 可变长度排序:sort 5排序模型可以处理序列长度为1-5的排序问题
  • 作者将排序问题建立为一个名为AssemblyGame的single-player game。在这个游戏中,玩家(agent)需要选择一系列的汇编指令,组合成为一种新的排序算法

  • 这个问题有两个难点:

    • 搜索空间很大
    • reward function很难设计
      AssemblyGame中的一条错误指令可能会使整个算法失效,这使得在这个游戏空间中的探索极具挑战性
  • 设计了一种名为AlphaDev的模型,由两个核心部分构成:

    • a learning algorithm,与AlphaZero深度强化学习算法结构十分相似
    • a representation function,基于Transformers
  • 效果:使用AlphaDev发现了长度分别为345的三种固定排序算法已经集成进入LLVM标准C++库中

如何将算法表示为低级的CPU汇编指令

  • 编译原理背景知识
    • 编译过程:高级语言(C++)→汇编指令→机器码→CPU执行
    • 变量转移:内存→寄存器→不同寄存器数据运算→内存→输出
    • 汇编指令集取决于处理器架构,如x86与x32,amd与arm
  • 案例:图a为最大长度为2的可变长度排序C++代码,图b为对应的汇编指令,其中%eax, %ecx, %edx, %edi分别为4个寄存器地址,(%rsi), 4(%rsi)表示两个内存地址

如何使用DRL发现更快的排序算法

  • 状态被定义为\(S_{t} = <P_{t}, Z_{t}>\),其中\(P_{t}\)表示当前时刻生成的汇编指令集合,\(Z_{t}>\)表示寄存器的状态。\(t\)时刻,接收当前状态\(S_{t}\),采取动作\(a_{t}\),将一条汇编指令添加到汇编指令集合中,

  • reward设计:

    • 正确性奖励:给予一组测试输入,比较测试输出与真值
    • 时间成本奖励:增加序列长度,判断算法是否受到长度影响很大;测量实际算法耗时
  • 输掉游戏:汇编指令不正确/汇编指令耗时很长

  • 使用类似AlphaZero的深度强化学习网络结构,输入为动作\(S_{t}\),输出为策略与价值预测

  • AlphaDev需要具备一种表征复杂算法的能力,从而可以加速探索。因此设计了一种表征网络,如下图a,使用Transformer Encoder用于表示当前算法的结构,使用CPU state encoder用于预测当前内存与寄存器的状态

  • Transformer encoder:首先使用one-hot编码表示汇编指令集与地址,随后通过TF encoder生成embedding,如下图b所示

  • 如何预测代码的执行效果?

    • two value function heads:one predicting algorithm correctness and the second predicting algorithm latency

Method

Background

  • AlphaZero,由两部分组成,(1)表征网络\(f^{rep}\)用于当前状态\(S_{t}\)的潜在表征\(h_{t}\),(2)预测网络\(f^{pred}\)用于预测期望回报\(\hat{v}_{t}\)与策略\(\hat{\pi}_{t}\)。在达到新状态时,AlphaZero首先通过表示网络将状态编码为潜在表示

\[\mathrm{h}_{t}=f^{r e p}\left(\mathbf{S}_{t}\right) \]

\[\hat{v}_{t}, \hat{\pi}_{t}=f^{p r e d}\left(\mathrm{~h}_{t}\right) \]

  • Sorting networks,不需要依赖数据的排序方法,所有比较操作都与数据无关,并且排序网络可以实现并行,大大降低算法复杂度