16GB显卡推理80G大模型

发布时间 2023-10-19 14:25:38作者: sunshine丶23

最近看到一篇推文是在不量化、不损失精度的情况下使用一张16G的显卡推理70B的大模型。方案来自于kaggle的一个方案,具体流程为:

1.创建一个空的(例如,没有权重的)模型

2.决定每一层将要去哪里(当有多个设备可用时)

3.在内存中加载其权重的一部分

4.在空模型中加载这些权重

5.将权重移动到设备上进行推理 

6.从第3步重复,直到所有的权重都被加载

PyTorch 1.9引入了一种新的设备,称为元设备(meta device)。

这使我们能够创建没有任何数据附加的张量,元设备上的张量只需要一个shape,只要你在元设备上,你就可以创建任意大的张量,而不必担心CPU(或GPU)的RAM够不够。

比如下面的代码,内存不够的话就会崩掉

1 import torch
2 large_tensor = torch.randn(100000, 100000)

这个大张量需要4 * 10**10字节(默认精度是FP32,所以张量的每个元素占用4字节),因此需要40GB的RAM。然而,在元设备上执行相同的操作就可以正常运行:

1 import torch
2 large_tensor = torch.randn(100000, 100000, device="meta")

这个张量没有关联的数据,只有一个形状。你可以直接在元设备上实例化一个模型:

1 large_model = torch.nn.Linear(100000, 100000, device="meta")

但是对于现成的模型来说,这种语法需要你重写所有的建模代码,以便每个模型的子部分都接受并传递一个设备关键字参数。由于这对Transformers库的预训练模型来说不切实际,accelerate库有一个context manager,整合了meta device可以实例化一个空模型。

1 # Load meta model (no memory used)
2 with init_empty_weights():
3     self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
4     self.model.tie_weights()

这一步很关键,我们知道每个权重的形状,因此我们可以知道一旦我们完全加载预训练的张量,它们将消耗多少内存。因此,我们可以决定如何在CPU和GPU之间分割我们的模型。

除此之外,定义了两个关键的方法,分别是load_layer_to_cpu,负责把 权重从disk挪到CPU,另外一个是move_layer_to_device,负责把权重从cpu挪到显卡。还有一个释放显存的方法clean_memory,负责清空显存。

 1 def load_layer_to_cpu(self, layer_name):
 2     self.weights_loader.set_state_dict(layer_name, self.device)
 3     state_dict = self.weights_loader.get_state_dict(self.device)
 4     if "value_head.weight" in state_dict:
 5         state_dict = {"lm_head.weight" : state_dict["value_head.weight"]}
 6     return state_dict
 7     
 8 def move_layer_to_device(self, state_dict):
 9     for param_name, param in state_dict.items():
10         assert param.dtype != torch.int8, "int8 not supported (need to add fp16_statistics)"
11         set_module_tensor_to_device(self.model, param_name, self.device, value=param, dtype=self.dtype)
12 
13 def clean_memory():
14     gc.collect()
15     ctypes.CDLL("libc.so.6").malloc_trim(0)
16     torch.cuda.empty_cache()

下面展示完整的代码

  1 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
  2 from accelerate import init_empty_weights
  3 from accelerate.utils.modeling import set_module_tensor_to_device
  4 from safetensors.torch import load_file
  5 from optimum.bettertransformer import BetterTransformer
  6 
  7 N_BATCHES = 3
  8 MAX_LENGTH = 4096
  9 
 10 def clean_memory():
 11     gc.collect()
 12     ctypes.CDLL("libc.so.6").malloc_trim(0)
 13     torch.cuda.empty_cache()
 14 
 15 
 16 # Class for sharded llama
 17 class ShardedLlama:
 18     def __init__(self, checkpoint_path, weights_loader, device="cuda:0", dtype=torch.float16):
 19 
 20         # Save parameters
 21         self.checkpoint_path = Path(checkpoint_path)
 22         self.weights_loader = weights_loader
 23         self.device = device 
 24         self.dtype = dtype
 25 
 26         # Create model
 27         self.config = AutoConfig.from_pretrained(self.checkpoint_path)   
 28         self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
 29         self.tokenizer.pad_token = self.tokenizer.eos_token
 30         self.tokenizer.padding_side = "right"
 31         self.init_model()
 32         self.layer_names = ["model.embed_tokens"] + [f"model.layers.{i}" for i in range(len(self.model.model.layers))] + ["model.norm""value_head"]
 33 
 34     def init_model(self):
 35     
 36         # Load meta model (no memory used)
 37         with init_empty_weights():
 38             self.model = AutoModelForCausalLM.from_config(self.config)
 39             self.model.lm_head = torch.nn.Linear(8192, 8, bias=False) # originally 32k
 40             self.model.eval()
 41             self.model = BetterTransformer.transform(self.model) # enable flash attention
 42             self.model.tie_weights()
 43             
 44         self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) + [self.model.model.norm, self.model.lm_head]
 45 
 46         # Move buffers to device (note that much GPU memory used)
 47         for buffer_name, buffer in self.model.named_buffers():
 48             set_module_tensor_to_device(self.model, buffer_name, self.device, value=buffer, dtype=self.dtype)
 49 
 50     def load_layer_to_cpu(self, layer_name):
 51         self.weights_loader.set_state_dict(layer_name, self.device)
 52         state_dict = self.weights_loader.get_state_dict(self.device)
 53         if "value_head.weight" in state_dict:
 54             state_dict = {"lm_head.weight" : state_dict["value_head.weight"]}
 55         return state_dict
 56         
 57     def move_layer_to_device(self, state_dict):
 58         for param_name, param in state_dict.items():
 59             assert param.dtype != torch.int8, "int8 not supported (need to add fp16_statistics)"
 60             set_module_tensor_to_device(self.model, param_name, self.device, value=param, dtype=self.dtype)
 61 
 62     def __call__(self, inputs):
 63         # inputs = [(prefix, suffix), ...] with prefix.shape[0] = 1 and suffix.shape[0] = 5
 64         
 65         # Reboot the model to make sure buffers are loaded and memory is clean
 66         del self.model
 67         clean_memory()
 68         self.init_model()
 69         
 70        # Send batch to device
 71         batch = [(prefix.to(self.device), suffix.to(self.device)) for prefix, suffix in inputs]
 72         n_suffixes = len(batch[0][1])
 73         suffix_eos = [(suffix != self.tokenizer.pad_token_id).sum(1) - 1 for _, suffix in inputs]
 74 
 75         # Create attention mask for the largest input, and position ids to use KV cache
 76         attention_mask = torch.ones(MAX_LENGTH, MAX_LENGTH)
 77         attention_mask = attention_mask.triu(diagonal=1)[None, None, ...] == 0
 78         attention_mask = attention_mask.to(self.device)
 79         position_ids = torch.arange(MAX_LENGTH, dtype=torch.long, device=self.device)[None, :]
 80 
 81         with ThreadPoolExecutor() as executor, torch.inference_mode():
 82 
 83             # Load first layer
 84             future = executor.submit(self.load_layer_to_cpu, "model.embed_tokens")
 85 
 86             for i, (layer_name, layer) in tqdm(enumerate(zip(self.layer_names, self.layers)), desc=self.device, total=len(self.layers)):
 87 
 88                 # Load current layer and prepare next layer
 89                 state_dict = future.result()
 90                 if (i + 1) < len(self.layer_names):
 91                     future = executor.submit(self.load_layer_to_cpu, self.layer_names[i + 1])
 92                 self.move_layer_to_device(state_dict)
 93                 
 94                 # Run layer
 95                 for j, (prefix, suffix) in enumerate(batch):
 96                     if layer_name == "model.embed_tokens":
 97                         batch[j] = (layer(prefix), layer(suffix))
 98                     elif layer_name == "model.norm":
 99                         # Only keep the last token at this point
100                         batch[j] = (None, layer(suffix[torch.arange(n_suffixes), suffix_eos[j]][:, None]))
101                     elif layer_name == "value_head":
102                         batch[j] = layer(suffix)[:, 0].mean(1).detach().cpu().numpy()
103                     else:
104                         # Run prefix
105                         len_p, len_s = prefix.shape[1], suffix.shape[1]
106                         new_prefix, (k_cache, v_cache) = layer(prefix, use_cache=True, attention_mask=attention_mask[:, :, -len_p:, -len_p:])
107                         
108                         # Run suffix
109                         pos = position_ids[:, len_p:len_p + len_s].expand(n_suffixes, -1)
110                         attn = attention_mask[:, :, -len_s:, -len_p - len_s:].expand(n_suffixes, -1, -1, -1)
111                         kv_cache = (k_cache.expand(n_suffixes, -1, -1, -1), v_cache.expand(n_suffixes, -1, -1, -1))
112                         new_suffix = layer(suffix, past_key_value=kv_cache, position_ids=pos, attention_mask=attn)[0]
113                         batch[j] = (new_prefix, new_suffix)
114 
115                 # Remove previous layer from memory (including buffers)
116                 layer.to("meta")
117                 clean_memory() # proposed by CPMP
118 
119         # Get scores
120         return batch
121 
122 
123 
124 
125 def run_model(device, df, weights_loader):
126     model = ShardedLlama(checkpoint_path, weights_loader, device=device)
127     f = partial(get_tokens, tokenizer=model.tokenizer)
128     inputs = df.apply(f, axis=1).values
129     batches = np.array_split(inputs, N_BATCHES)
130     outputs = []
131     for i, batch in enumerate(batches):
132         outputs += model(batch)
133     return outputs

完整代码参考:https://www.kaggle.com/code/simjeg/platypus2-70b-without-wikipedia-rag

文章来源:

模型黑科技:单张16G卡推70B模型