《Language Model Cascades》论文学习

发布时间 2023-07-21 15:14:50作者: 郑瀚Andrew

一、Introduction

语言模型 (LM) 已展现出令人印象深刻的小样本学习能力,很多人建议应该将LM视为一个基础通用推理计算器,这个基础通用推理计算器可以被用于例如:

  • scratchpads
  • chain of thought prompting
  • learned verifiers
  • selection-inference
  • bootstrapping
  • been applied in formal mathematics settings to guide theorem provers

等场景中。

要完成上述这些”计算“场景中,有两种和LM的交互范式可以采用:

  • 通过prompt技术引导单个LM进行step-by-step的交互式推理
  • 通过将多个LM(采用了不同的微调方式,或者使用了不同的prompt方式)进行串联进行交互式推理

我们认为,概率编程语言(probabilistic programming languages,PPL)会在未来成为统一上述问题的一种通用框架。

和传统编程框架最大的不同是,概率变成语言(PPL)不使用整型或者浮点型数字作为输入和输出,取而代之的是自然语言字符串。

也就是说,我们使用PPL来定义基于字符串值随机变量的联合概率模型,同时以字符串作为输入参数,LM模型计算得到的后验结果(model inference)也是一个字符串。

我们将这个概率程序称为”语言模型级联(language model cascade)“。我们希望language model cascade能够使面向业务目标的端到端开发(通用计算和推理)成为可能。 

从概率统计的角度,我们来重新理解一下LM prompting技术。

zero-shot prompt的预测过程可以视为如下的条件概率预测,

  • Q:代表我们向LLM输入的问题
  • θ:代表预训练模型的模型参数
  • A:代表模型预测答案的概率分布

few-shot prompt的预测过程可以视为如下的条件概率预测, 

  • D:代表着一组question-answer pairs,

fine-tune/sft model的预测过程可以视为如下的条件概率预测, 

  • θ’:代表fine-tune/sft模型的模型参数

给LM推理过程增加thought辅助,可以视为如下的条件概率预测,

  • T:代表问题分解思考

上述条件概率分解公式,很直观地阐述了,如果希望LLM进行thought inference,需要在prompt中提供一个对应的question-thought-answer示例。

chain-of-thought prompt的预测过程可以视为如下的条件概率预测, 

 

  • prompt中的question-thought-answer示例可以是多个 

参考链接:

https://model-cascades.github.io/
https://arxiv.org/pdf/2207.10342.pdf

 

二、Cascades介绍

在本节中,我们将展示如何创建 cascades LM,以此解决各种基于语言的推理问题。

级联是一个概率程序,包含从LM采样得到的字符串随机变量。下图 2 是一个用于简单问答任务的cascade程序。

上图是一个Chain of thought的cascade程序例子,每一个Yield 表达式返回一个符合语言模型 S 的字符串概率分布。这个程序定义了一个包含”question“、”thought“、”answer“变量的联合概率分布。

其他的cascade程序例子如下:

The basic question answering graph directly generates the answer given the question

Self critique introduces a step in which the model critiques its own reasoning in natural languag 

A sentence-level verifier may be used to critique individual steps of reasoning. Furthermore, when to halt generation may itself be a random variable 

Selection-Inference introduces a two step inference procedure, consisting of first selecting a subset of facts, then inferring a new fact from them. 

我们将cascade作为一种基于跟踪的概率编程语言,嵌入 Python 中,构建出了一种新的概率编程语言。同时cascade支持任意控制流和递归。

虽然我们在本文中的演示是关于因果语言模型的few-shot promping,但理论上,cascade可以适用于微调模型、特殊 LM mask 设定,以及其他复杂数据类型(例如图像)。

 

三、Cascades代码示例

Cascades是一个Python库,支持语言模型的复杂组合,例如

  • scratchpads
  • chain of thought
  • tool use
  • selection-inference

Cascades可以被嵌入Python程序中,作为一个通用的、基于跟踪的概率编程库(universal trace-based probabilistic programming library)来使用。

0x1:Scratchpads and Chain of thought

# Copyright 2023 The cascades Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Basic tests for Scratchpads."""
from absl.testing import absltest
import cascades as cc
from cascades.examples import scratchpad


class ScratchpadTest(absltest.TestCase):

  def test_sample_solution(self):
    examples = scratchpad.load_chain_examples()
    print("examples: ", examples)
    target = scratchpad.ReasonIO(
        question='Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
        reason=None,
        answer='72',
        id='test/123')

    mock_lm = cc.mock_lm(
        response=': Half of 48 is 24. 24 + 48 is 72.\nAnswer: 72\n===')

    model = scratchpad.sample_with_prompts(
        lm=mock_lm,
        target=target,
        examples=examples,
        n_prompts=3)
    trace = model.sample(seed=0)
    print("trace: ", trace)

    self.assertEqual('72', trace.return_value)
    self.assertEqual('test/123', trace['problem_id'].value)


if __name__ == '__main__':
  absltest.main()


/*

examples:  (ReasonIO(question='There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?', reason='We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees.', answer='6', id=None), ReasonIO(question='If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?', reason='There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars.', answer='5', id=None), ReasonIO(question='Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?', reason='Leah had 32 chocolates and Leah’s sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates.', answer='39', id=None), ReasonIO(question='Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?', reason='Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops.', answer='8', id=None), ReasonIO(question='Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?', reason='He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys.', answer='9', id=None), ReasonIO(question='There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?', reason='There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers.', answer='29', id=None), ReasonIO(question='Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?', reason='Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls.', answer='33', id=None), ReasonIO(question='Olivia has $23. She bought five bagels for $3 each. How much money does she have left?', reason='She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8.', answer='8', id=None))


trace:  Record(
  problem_id: Log(name='problem_id', score=0.0, value='test/123', should_stop=False, replayed=False, metadata=None)
  choose_prompts: Sample(name='choose_prompts', score=Array(-6.2383246, dtype=float32, weak_type=True), value=[ReasonIO(question='There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?', reason='We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees.', answer='6', id=None), ReasonIO(question='Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?', reason='He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys.', answer='9', id=None), ReasonIO(question='Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?', reason='Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops.', answer='8', id=None)], should_stop=False, replayed=False, metadata=None)
  thought: Sample(name='thought', score=Array(0., dtype=float32, weak_type=True), value=': Half of 48 is 24. 24 + 48 is 72.\nAnswer: 72\n', should_stop=False, replayed=False, metadata=None)
  return_value: Log(name='return_value', score=0.0, value='72', should_stop=False, replayed=False, metadata=None)
)

*/

0x2:Semi-supervised learning

Figure 3. QTA model with hidden thoughts. 

0x3:Selection-Inference

Selection Inference是一种多步推理prompt技术,它将prompt推理分为两个模块:

  • 选择模块(selection module):从给定问题的事实集合中,选择一个事实子集
  • 推理模块(inference module):根据被选择出的事实子集推断出新的事实

可以用下图4的模型来表示这个流程,

 

Figure 4. Selection inference as a cascade. Here S is the selected subset of facts and I is an inference driven by this subset. 

这里S是从预先指定的“事实”集合中选择的一组事实,I 是由该事实驱动的推论。

S 和 I 节点可以迭代以进行多步推理。

同时,基于few-shot prompt技术,我们可以通过给出示例来“训练”模型,以此来改变模型预测的前验概率,

0x4:Verifiers

尽管向模型添加明确的“thought”变量已经可以提高性能,但模型仍然可能得到得到错误的答案,或者从错误的推理中得到的”正确“的答案。

解决这个问题的一个直观地办法是,引导模型在推理的过程中判断答案和想法是否有效。

 

Figure 5. Verifier model. The small double-ringed nodes are deterministic buffer nodes that concatenate their inputs, accumulating all past strings. All other nodes are stochastic. The verifiers are observed to take on the “correct” value. 

0x5:Tool-use 

0x6:Twenty questions 

Figure 6. Twenty questions. 

参考链接:

https://github.com/google-research/cascades