从零开始训练LLaMA模型

2024/12/13 21:03:26

本文主要是介绍从零开始训练LLaMA模型,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

作者用AI生成的图片

在这篇文章里,我们会看到怎么训练上一篇文章中提到的LLaMA模型。

上一篇文章 :- 自己动手搭建 LLaMA 模型

# 导入必要的库
from typing import Optional  
import torch  
import time  
from pathlib import Path  
import json  
from sentencepiece import SentencePieceProcessor  
from tqdm import tqdm  

# 从model模块导入ModelArgs和llamaModel类
from model import ModelArgs, llamaModel

首先,导入所有必要的库,同时还要导入我们在前一篇文章中实现的 model.py 文件,以便后续使用。

    类 LLaMA:  

        def __init__(self, model: llamaModel, tokenizer: SentencePieceProcessor, model_args: ModelArgs):  
            self.model = model  
            self.tokenizer = tokenizer  
            self.args = model_args

在 LLaMA 类中定义一个初始化方法,初始化方法的参数包括:
model :- llamaModel 类的一个实例,表示语言模型。
tokenizer :- SentencePieceProcessor 类的一个实例,处理文本的分词和解词。
model_args :- ModelArgs 类的一个实例,包含模型的配置和参数。

    @staticmethod  
        def build(checkpoints_dir: str, tokenizer_path: str, load_model: bool, max_seq_len: int, max_batch_size: int, device: str):  
            prev_time = time.time()  
            if load_model:  
                checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))  
                assert len(checkpoints) > 0, f"在{checkpoints_dir}中未找到任何检查点文件"  
                ckpt_path = checkpoints[0]  
                print(f'加载检查点"{ckpt_path}"')  
                checkpoint = torch.load(ckpt_path, map_location="cpu")  
                print(f"检查点加载完成,耗时约为{(time.time() - prev_time):.2f}")  
                prev_time = time.time()  
            with open(Path(checkpoints_dir) / "params.json", "r") as f:  
                params = json.loads(f.read())  

            model_args: ModelArgs = ModelArgs(  
                max_seq_len=max_seq_len,  
                max_batch_size=max_batch_size,  
                device=device,  
                **params  
            )  

            tokenizer = SentencePieceProcessor()  
            tokenizer.load(tokenizer_path)  
            model_args.vocab_size = tokenizer.vocab_size()  

            if device == "cuda":  
                torch.set_default_tensor_type(torch.cuda.HalfTensor)  
            else:  
                torch.set_default_tensor_type(torch.BFloat16Tensor)  

            model = llamaModel(model_args).to(device)  

            if load_model:  
                # 检查点中唯一的不匹配项是rope.freqs,这里需要将其移除  
                del checkpoint['rope.freqs']  
                model.load_state_dict(checkpoint, strict=True)  
                print(f"状态字典加载完成,耗时约为{(time.time() - prev_time):.2f}")  

            return LLaMA(model, tokenizer, model_args)

这是 **LLaMA** 类中的第二个方法。**@staticmethod** —— 表示这是一个静态方法,不依赖于类或特定实例的数据,。

检查点加载:
如果 load_modelTrue,则会在指定目录中查找.pth文件(PyTorch模型检查点文件),确保存在至少一个检查点文件。然后将找到的第一个检查点加载到内存,并映射到CPU。并打印加载检查点所用的时间。

参数加载:
从检查点目录中的 params.json 文件中读取模型参数。使用这些参数,以及提供的 max_seq_len、max_batch_size 和设备(如 GPU)信息,构建一个 ModelArgs 对象。

分词器加载:
初始化一个 SentencePieceProcessor 分词器对象,并从指定路径加载,根据加载的分词器设定 model_args 中的词汇量。

张量类型配置: 根据指定的设备,将默认的张量类型设置为对于使用CUDA的设备为半精度,对于CPU设备则使用BFloat16。

模型初始化:
使用给定的模型参数初始化llamaModel并创建它,然后将其移动到指定设备。如果 load_modelTrue ,则从检查点移除 rope.freqs 键(以避免不匹配问题),并将状态字典加载到模型。

返回语句:
会创建并返回使用模型、分词器和模型参数初始化的LLaMA类实例。

推理的工作原理

在进行推理时,我们一次只处理一个 token 以减少不必要的计算。我们需要一种方法从词汇表中确定下一个 token,这称为 logits

假设一个句子 “爱是__ ,我们需要填最后一个词,我们可以想到很多不同的词,例如:仁慈、永恒、痛苦、纯粹、无条件等。填入的词取决于我们的知识、教育和经验。

大型语言模型也面临同样的问题,预测下一个词取决于它们的训练和预测策略,比如:贪心策略、束搜索法、温度采样、随机采样、Top-K、Top-P 等。

自注意力的输出是一个序列,而在使用KV缓存的情况下,它只是一个单一令牌。然后经过归一化之后,我们将其通过一个线性层,将自注意力输出的向量转换成一个表示该令牌在词汇表中概率的数值列表。如果词汇表大小为1000,我们将得到一个包含1000个数字的列表。经过Softmax后,这些数字将变成该令牌可能是下一个令牌的概率。

现在从这些众多的可能性中选择下一个token。如何选择?为此,我们采用策略。(在这里,token指代某种符号或标识。)

贪心策略 — 在每一步中,我们选择概率最高的词并将其添加到输入中以生成下一个词。如果初始词选择不当,那么下一个词也很可能错误,因此表现不佳。

束搜索算法 在每个步骤中,我们保留 K 条最佳路径,而其他路径将被淘汰。因此推理时间会增加,因为每一步都需要探索 K 种可能的选择。通常,束搜索的表现优于贪心策略。

温度 → 这个思路是在应用softmax之前对logits进行调整。低温度: 让模型更确定;高温度: 让模型更不确定。

随机采样 → 我们从Softmax输出的分布中抽取样本。

    logits = torch.Tensor([-2.5, -3, -0.6])  
    distribution = torch.softmax(logits, dim=0)  
    distribution  
    # 输出结果为 --> tensor([0.1206, 0.0731, 0.8063])

第一个令牌有12.06%的概率被选中,第二个有7.31%的概率,第三个有80.63%的概率。概率越高,被选中的可能性就越大。问题在于,可能会以极小的概率选择到毫无意义的令牌。

保留前K个最高概率 → 通过这种方式,我们只保留最高的K个概率,这样低概率的词元就不会被选中了。问题是,即使给定以下分布,低概率的词元仍然可能进入前K个词元(K=2)。
如下分布1: 0.50.4 , 0.05, 0.025, 0.025
如下分布2: 0.90.05 , 0.025, 0.020, 0.005

Top P → 通过这种方式,我们只保留累积概率大于或等于参数P的那些token。这种情况下,我们会得到更多的token,而在分布较为集中的情况下,则得到较少的token。较为均匀的分布会带来更多的token,而分布有显著峰值的情况则会带来较少的token。

在 LLaMA 中采用了top P策略,我们也会这样做。

    def text_completion(self, prompts: list[str], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None):  
            if max_gen_len is None:  
                max_gen_len = self.args.max_seq_len - 1  
            prompt_tokens = [self.tokenizer.encode(prompt, out_type=int, add_bos=True, add_eos=False) for prompt in prompts]  
            batch_size = len(prompt_tokens)  

            assert batch_size <= self.args.max_batch_size, f"批大小必须不超过 {self.args.max_batch_size}"  
            max_prompt_len = max(len(prompt) for prompt in prompt_tokens)  

            assert max_prompt_len <= self.args.max_seq_len, f"提示长度需小于或等于 {self.args.max_seq_len}"  
            total_len = min(self.args.max_seq_len, max_gen_len + max_prompt_len)  

            pad_id = self.tokenizer.pad_id()  
            tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=device)  

            for k, t in enumerate(prompt_tokens):  
                tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)  

            eos_reached = torch.tensor([False] * batch_size, device=device)  
            prompt_tokens_mask = tokens != pad_id # 如果token是提示token则为True,否则为False  
            cur_iterator = tqdm(range(1, total_len), desc="生成 tokens")  

            for cur_pos in cur_iterator:  
                with torch.no_grad():  
                    logits = self.model.forward(tokens[:, cur_pos-1:cur_pos], cur_pos)  
                if temperature > 0:  
                    probs = torch.softmax(logits[:, -1] / temperature, dim=-1)  
                    next_token = self._sample_top_p(probs, top_p)  
                else:  
                    next_token = torch.argmax(logits[:, -1], dim=-1)  

                next_token = next_token.reshape(-1)  
                next_token = torch.where(prompt_tokens_mask[:, cur_pos], tokens[:, cur_pos], next_token)  
                tokens[:, cur_pos] = next_token  
                eos_reached |= (~prompt_tokens_mask[:, cur_pos]) & (next_token == self.tokenizer.eos_id)  

                if all(eos_reached):  
                    break  

            out_tokens = []  
            out_text = []  

            for prompt_index, current_prompt_tokens in enumerate(tokens.tolist()):  
                if self.tokenizer.eos_id in current_prompt_tokens:  
                    eos_idx = current_prompt_tokens.index(self.tokenizer.eos_id)  
                    current_prompt_tokens = current_prompt_tokens[:eos_idx]  
                out_tokens.append(current_prompt_tokens)  
                out_text.append(self.tokenizer.decode(current_prompt_tokens))  

            return (out_tokens, out_text)

提示分词
每个提示都用分词器分词,包括序列起始标记(BOS),但不包括序列终止标记(EOS)。

确保批次大小和提示长度的限制
确保批次大小不超过最大允许值。同时确保提示长度不超过最大限制。计算生成序列所需的总长度。

初始化Token矩阵
初始化一个只填充了padding token的矩阵。然后用prompt tokens填充该矩阵。

生成 Token 循环
遍历序列中的每个位置。获取当前位置的 logits 值。应用温度缩放,使用 top-p 抽样或当温度为 0 时使用贪心抽样选择下一个 token。用生成的 token 更新 token 矩阵。检查 EOS token 来判断是否可以提前终止生成。

处理输出令牌
将生成的令牌矩阵转换成令牌列表。如果有EOS令牌,则在每个序列的EOS令牌处截断序列。将令牌序列解码成文本。返回令牌及其对应的文本完成。

    def _sample_top_p(self, probs, p):  
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) # (B, vocab_size)  
        probs_sum = torch.cumsum(probs_sort, dim=-1) # (B, vocab_size)  

        mask = probs_sum - probs_sort > p # (B, vocab_size)  

        probs_sort[mask] = 0.0   
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))  

        next_token = torch.multinomial(probs_sort, num_samples=1)  
        next_token = torch.gather(probs_idx, -1, next_token)   
        return next_token

首先将概率按降序排序,然后记录原始概率对应的索引。这用于在采样后映射回原始标记索引。
计算排序概率沿最后一维的累积和,生成一个布尔掩码,用来标记累积概率(不包括当前概率)超过阈值 P 的标记。
将部分概率重置为零后,重新归一化这些概率,使它们的总和为 1。
**next_token** 从重新归一化的概率中采样一个标记,并使用收集到的索引来获取该采样标记对应的原始索引。

    if __name__ == '__main__':  
        torch.manual_seed(0)  

        allow_cuda = False  
        device = 'cuda' if torch.cuda.is_available() and allow_cuda else 'cpu'  

        prompts = [  
            "简单地说,相对论表明",  
            "如果谷歌是一家在米兰成立的意大利公司,那么它会",  
            # Few shot prompt  
            """将英语翻译成法语:  

            海獭 => loutre de mer  
            薄荷 => menthe poivrée  
            毛绒长颈鹿玩偶 => girafe peluche  
            奶酪 =>""",  
            # Zero shot prompt  
            """告诉我以下人物是否实际上是戴着人类面具的哆啦A梦:  
            名字: Ebad Sayed  
            判断:   
            """  
        ]  

        model = LLaMA.build(  
            checkpoints_dir='llama-2-7b/',  
            tokenizer_path='tokenizer.model',  
            load_model=True,  
            max_seq_len=1024,  
            max_batch_size=len(prompts),  
            device=device  
        )  

        out_tokens, out_texts = (model.text_completion(prompts, max_gen_len=64))  
        assert len(out_texts) == len(prompts)  
        for i in range(len(out_texts)):  
            print(f'{out_texts[i]}')  
            print('-' * 50)

该脚本准备环境,初始化LLaMA模型并生成一组输入提示的文本完成,并打印生成的文本。主要包括配置设备、定义输入提示、搭建模型以及处理和展示输出结果。

从我的GitHub下载一下**download.sh**文件来下载LLaMA模型的权重。您可以在我的GitHub上找到所有相关的代码文件。



这篇关于从零开始训练LLaMA模型的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程