LLaMA 3的Rust实现

借助Candle中LLaMA 3的高质量Rust代码,可以帮助我们更直观地理解大羊驼架构、注意力与推演缓存机制。

不要被大语言模型动辄几十亿的参数吓到,也无需为技术报告中晦涩难懂的概念所痛苦。借助Candle中LLaMA 3的高质量Rust代码,可以帮助我们更直观地理解大羊驼架构、注意力与推演缓存机制。

LLaMA 3概览

LLaMA 3直接相关结构体的定义如下

pub struct Llama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Linear,
}

这个结构体的内容包含四部分:一个Embedding、一堆Block、一个RMS Norm和一个线性层。至于具体的结构则可以通过模型推演 forward 函数来推算

pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
    let (_b_sz, seq_len) = x.dims2()?;
    let mut x = self.wte.forward(x)?;
    for (block_idx, block) in self.blocks.iter().enumerate() {
        x = block.forward(&x, index_pos, block_idx, cache)?;
    }
    let x = self.ln_f.forward(&x)?;
    let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
    let logits = self.lm_head.forward(&x)?;
    logits.to_dtype(DType::F32)
}

输入依次经过Embedding、多个Block、RMS Norm以及线性层,得到输出。可以大致绘制如下图

LLaMA架构示意图

而每个模块的规格,则是通过load 函数从配置结构体Config中获得

pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
    let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
    let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
    let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
    let blocks: Vec<_> = (0..cfg.num_hidden_layers)
        .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap())
        .collect();

    Ok(Self {
        wte,
        blocks,
        ln_f,
        lm_head,
    })
}

配置结构体Config 里的内容则来自配置文件config.json 。例如LLaMA 3的配置文件

{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.0.dev0",
  "use_cache": true,
  "vocab_size": 128256
}

其中有几个与架构相关的参数包括vocab_size=128256hidden_size=4096。这两个参数决定了开始的Embedding与最后的Linear的规格。Block的个数num_hidden_layers=32,决定了渐进式处理的层级数。