Reading the peft source code with breakpoint debugging: prefix tuning

Published 2023-08-07 22:29:40  Author: 鸽鸽的书房

Today we read the peft source code, mainly to figure out how prefix tuning works and the details of its implementation.

Model definition

from transformers import AutoModelForSeq2SeqLM
from peft import PrefixTuningConfig, TaskType, get_peft_model

# Prefix tuning for a seq2seq LM task, with 20 trainable virtual tokens
peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, num_virtual_tokens=20)

# Download the pretrained T5 model; type `model` in the debug console to see its structure
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

The key line is model = get_peft_model(model, peft_config), so we set a breakpoint here.
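If you are not working in an IDE, a minimal alternative (my own sketch, not part of the original script) is to pause with pdb right before the call and then step into it:

import pdb

pdb.set_trace()  # execution pauses here; inspect model and peft_config, then type "s" to step into get_peft_model
model = get_peft_model(model, peft_config)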

The debugger first jumps to the get_peft_model function in peft->mapping.py. I read it line by line and annotate it with comments below.

def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> PeftModel:
    """
    Returns a Peft model object from a model and a config.

    Args:
        model ([`transformers.PreTrainedModel`]): Model to be wrapped.
        peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
    """
    model_config = getattr(model, "config", {"model_type": "custom"})  # get the T5 model's config; type model_config in the debug console to inspect it
    if hasattr(model_config, "to_dict"):
        model_config = model_config.to_dict()  # serialize the config attributes into a Python dict

    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)

    # <TaskType.SEQ_2_SEQ_LM: 'SEQ_2_SEQ_LM'>
    # dict_keys(['SEQ_CLS', 'SEQ_2_SEQ_LM', 'CAUSAL_LM', 'TOKEN_CLS', 'QUESTION_ANS', 'FEATURE_EXTRACTION'])
    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance(
        peft_config, PromptLearningConfig
    ):
        return PeftModel(model, peft_config, adapter_name=adapter_name)
    if isinstance(peft_config, PromptLearningConfig):
        peft_config = _prepare_prompt_learning_config(peft_config, model_config)
    return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)

Stepping into that last line brings us to the PeftModelForSeq2SeqLM(PeftModel) class in peft->peft_model.py. So MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type] resolves our task type to PeftModelForSeq2SeqLM, a subclass of PeftModel, which is then instantiated with model, peft_config, adapter_name=adapter_name.
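For reference, this is roughly what MODEL_TYPE_TO_PEFT_MODEL_MAPPING contains, reconstructed from the keys we saw in the debug console (a sketch; the exact entries depend on your peft version):

MODEL_TYPE_TO_PEFT_MODEL_MAPPING = {
    "SEQ_CLS": PeftModelForSequenceClassification,
    "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
    "CAUSAL_LM": PeftModelForCausalLM,
    "TOKEN_CLS": PeftModelForTokenClassification,
    "QUESTION_ANS": PeftModelForQuestionAnswering,
    "FEATURE_EXTRACTION": PeftModelForFeatureExtraction,
}

So for our SEQ_2_SEQ_LM task the lookup returns the PeftModelForSeq2SeqLM class, and the return statement simply constructs an instance of it.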

The docstring of PeftModelForSeq2SeqLM reads:

"""
    Peft model for sequence-to-sequence language modeling.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): Peft config.


    Example:

        ```py
        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import PeftModelForSeq2SeqLM, get_peft_config

        >>> config = {
        ...     "peft_type": "LORA",
        ...     "task_type": "SEQ_2_SEQ_LM",
        ...     "inference_mode": False,
        ...     "r": 8,
        ...     "target_modules": ["q", "v"],
        ...     "lora_alpha": 32,
        ...     "lora_dropout": 0.1,
        ...     "fan_in_fan_out": False,
        ...     "enable_lora": None,
        ...     "bias": "none",
        ... }

        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> peft_model = PeftModelForSeq2SeqLM(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 884736 || all params: 223843584 || trainable%: 0.3952474242013566
        ```
    """

The example above fine-tunes a t5-base model with LoRA in just a few lines. Very convenient!
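By analogy, a prefix tuning config in the same dict style would look roughly like this (my own sketch based on the PrefixTuningConfig arguments used above; check PeftType in your peft version for the exact peft_type string):

>>> config = {
...     "peft_type": "PREFIX_TUNING",
...     "task_type": "SEQ_2_SEQ_LM",
...     "inference_mode": False,
...     "num_virtual_tokens": 20,
... }
>>> peft_config = get_peft_config(config)
>>> peft_model = PeftModelForSeq2SeqLM(model, peft_config)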

prefix tuning

We could not find the prefix tuning code for a while, so we open /root/miniconda3/envs/peft-practice/lib/python3.10/site-packages/peft/tuners/prefix_tuning.py directly and find that its PrefixEncoder is based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py:

class PrefixEncoder(torch.nn.Module):
    r'''
    The torch.nn model to encode the prefix

    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)

    In our debug run, prefix-length (num_virtual_tokens) is 20 and hidden_size is 768;
    prefix_hidden_size is the hidden width of the projection MLP below.
    '''
    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
            )
        else:
            # Without the projection MLP, embed each virtual token directly into num_hidden_layers * 2 * hidden_size dims
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)       # (batch, prefix_length, hidden_size)
            past_key_values = self.trans(prefix_tokens)  # (batch, prefix_length, num_hidden_layers * 2 * hidden_size)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values
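
To see the shapes concretely, here is a small sketch that instantiates the encoder with made-up config values (only pre_seq_len/num_virtual_tokens = 20 and hidden_size = 768 come from the notes above; prefix_hidden_size = 512 and num_hidden_layers = 12 are placeholders of my own):

import torch
from types import SimpleNamespace

config = SimpleNamespace(
    prefix_projection=True,     # use the two-layer MLP reparameterization
    pre_seq_len=20,             # num_virtual_tokens
    hidden_size=768,
    prefix_hidden_size=512,     # placeholder value
    num_hidden_layers=12,       # placeholder value
)

encoder = PrefixEncoder(config)
prefix = torch.arange(config.pre_seq_len).unsqueeze(0)  # (batch_size=1, prefix_length=20)
past_key_values = encoder(prefix)
print(past_key_values.shape)  # torch.Size([1, 20, 18432]), i.e. (batch, prefix_length, 12 * 2 * 768)

These 2*layers*hidden values are what peft later reshapes into per-layer key/value pairs and feeds to the frozen base model as past_key_values, which is exactly the prefix tuning trick.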