借助Candle中LLaMA 3的高质量Rust代码,可以帮助我们更直观地理解大羊驼架构、注意力与推演缓存机制。
不要被大语言模型动辄几十亿的参数吓到,也无需为技术报告中晦涩难懂的概念所痛苦。借助Candle中LLaMA 3的高质量Rust代码,可以帮助我们更直观地理解大羊驼架构、注意力与推演缓存机制。
LLaMA 3概览
LLaMA 3直接相关结构体的定义如下
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,
ln_f: RmsNorm,
lm_head: Linear,
}
这个结构体的内容包含四部分:一个Embedding、一堆Block、一个RMS Norm和一个线性层。至于具体的结构则可以通过模型推演 forward
函数来推算
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx, cache)?;
}
let x = self.ln_f.forward(&x)?;
let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32)
}
输入依次经过Embedding、多个Block、RMS Norm以及线性层,得到输出。可以大致绘制如下图
而每个模块的规格,则是通过load
函数从配置结构体Config
中获得
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap())
.collect();
Ok(Self {
wte,
blocks,
ln_f,
lm_head,
})
}
配置结构体Config
里的内容则来自配置文件config.json
。例如LLaMA 3的配置文件
{
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 128000,
"eos_token_id": 128001,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 14336,
"max_position_embeddings": 8192,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 500000.0,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.40.0.dev0",
"use_cache": true,
"vocab_size": 128256
}
其中有几个与架构相关的参数包括vocab_size=128256
、hidden_size=4096
。这两个参数决定了开始的Embedding与最后的Linear的规格。Block的个数num_hidden_layers=32
,决定了渐进式处理的层级数。