wgml/models/llama2/
cpu.rs

1//! The CPU version of the llama2 transformer.
2
3use crate::gguf::Gguf;
4use nalgebra::{
5    vector, DMatrix, DVector, DVectorViewMut, Dyn, OMatrix, OVector, Rotation2, Storage,
6    StorageMut, Vector,
7};
8use std::ffi::c_int;
9
// All dimensions below are dynamically sized (`Dyn`). These aliases are purely
// documentary: they name the role each dimension plays in the type signatures,
// the compiler treats them all as the same type.
/// The transformer/embedding dimension.
type Dim = Dyn;
/// The feed-forward hidden dimension.
type HiddenDim = Dyn;
/// The number of attention heads.
type NumHeads = Dyn;
/// The maximum sequence length.
type SeqLen = Dyn;
14
/// The binary model-file header in the llama2.c format, read verbatim from disk.
///
/// `#[repr(C)]` guarantees the field order and `c_int` widths match the C struct
/// written by the llama2.c exporter, so the header can be read with `bytemuck`.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
pub struct RawConfig {
    /// The transformer dimension.
    /// In particular, this is the size of an embedding.
    dim: c_int,
    /// Number of dimensions of the feed-forward neural net.
    hidden_dim: c_int,
    /// Number of layers.
    n_layers: c_int,
    /// Number of query heads.
    n_q_heads: c_int,
    /// Number of key/value heads (can be < than `n_q_heads` because of multiquery).
    /// See <https://youtu.be/Mn_9W1nCFLo?si=UnkLuzaHlX8JKyjl&t=3808> (Grouped-query diagram).
    n_kv_heads: c_int,
    /// Vocabulary size, usually 256 (byte-level).
    /// A negative value signals that the classifier weights are NOT shared with
    /// the token embedding (see the `From<RawConfig>` impl in this file).
    vocab_size: c_int,
    /// Max sequence length.
    seq_len: c_int,
}
35
/*
 * Important note: the original code (like most of the LLM literature) assumes row-major matrices
 * with left-multiplication (vector * Matrix).
 * nalgebra uses column-major with right-multiplication (Matrix * vector). So in the end the data layout still matches;
 * we just have to swap all the matrix dimensions, access columns instead of rows (and vice versa),
 * and replace left-multiplication by right-multiplication.
 */
/// The hyperparameters of a llama2 model, in native Rust types.
///
/// Built either from a llama2.c binary header ([`Llama2Config::read`]) or from
/// GGUF metadata ([`Llama2Config::from_gguf`]).
#[derive(Copy, Clone, Debug)]
pub struct Llama2Config {
    /// The transformer dimension.
    /// In particular, this is the size of an embedding.
    pub dim: usize,
    /// Number of dimensions of the feed-forward neural net.
    pub hidden_dim: usize,
    /// Number of layers.
    pub n_layers: usize,
    /// Number of query heads.
    pub n_q_heads: usize,
    /// Number of key/value heads (can be < than `n_q_heads` because of multiquery).
    /// See <https://youtu.be/Mn_9W1nCFLo?si=UnkLuzaHlX8JKyjl&t=3808> (Grouped-query diagram).
    pub n_kv_heads: usize,
    /// Vocabulary size, usually 256 (byte-level).
    pub vocab_size: usize,
    /// Max sequence length.
    pub seq_len: usize,
    /// `true` when the output classifier shares its weights with the token
    /// embedding matrix (the llama2.c format encodes this in the sign of `vocab_size`).
    pub shared_weights: bool,
}
63
64impl Llama2Config {
65    pub fn read(bytes: &[u8]) -> Self {
66        let elts: &[RawConfig] = bytemuck::cast_slice(&bytes[..std::mem::size_of::<RawConfig>()]);
67        elts[0].into()
68    }
69
70    pub fn from_gguf(gguf: &Gguf) -> Self {
71        Self {
72            dim: gguf.metadata["llama.embedding_length"].unwrap_u32() as usize,
73            hidden_dim: gguf.metadata["llama.feed_forward_length"].unwrap_u32() as usize,
74            n_layers: gguf.metadata["llama.block_count"].unwrap_u32() as usize,
75            n_q_heads: gguf.metadata["llama.attention.head_count"].unwrap_u32() as usize,
76            n_kv_heads: gguf.metadata["llama.attention.head_count_kv"].unwrap_u32() as usize,
77            vocab_size: gguf.metadata["tokenizer.ggml.tokens"].unwrap_array_len(),
78            seq_len: gguf.metadata["llama.context_length"].unwrap_u32() as usize,
79            shared_weights: true, // ???
80        }
81    }
82}
83
impl From<RawConfig> for Llama2Config {
    /// Converts the raw llama2.c header into a usable config.
    ///
    /// The llama2.c format encodes the shared-classifier-weights flag in the
    /// sign of `vocab_size`: a positive value means the output classifier
    /// shares its weights with the token embedding, a negative value means it
    /// has its own weights — hence the `unsigned_abs` and the `> 0` test below.
    fn from(c: RawConfig) -> Self {
        Self {
            dim: c.dim as usize,
            hidden_dim: c.hidden_dim as usize,
            n_layers: c.n_layers as usize,
            n_q_heads: c.n_q_heads as usize,
            n_kv_heads: c.n_kv_heads as usize,
            vocab_size: c.vocab_size.unsigned_abs() as usize,
            seq_len: c.seq_len as usize,
            shared_weights: c.vocab_size > 0,
        }
    }
}
98
/// The dequantized (`f32`) weights of a single transformer layer.
///
/// Shapes below follow the loading code in `TransformerWeights::from_gguf`,
/// where `head_size = dim / n_q_heads` and `kv_rows = n_kv_heads * head_size`.
pub struct TransformerLayerWeights {
    /// Key projection, `kv_rows × dim`.
    pub attn_k: DMatrix<f32>,
    /// RMS-norm weights applied before attention (expected length `dim`).
    pub attn_norm: DVector<f32>,
    /// Query projection, `dim × dim`.
    pub attn_q: DMatrix<f32>,
    /// Value projection, `kv_rows × dim`.
    pub attn_v: DMatrix<f32>,
    /// Feed-forward down projection, `dim × hidden_dim`.
    pub ffn_down: DMatrix<f32>,
    /// Feed-forward gate projection (SwiGLU), `hidden_dim × dim`.
    pub ffn_gate: DMatrix<f32>,
    /// RMS-norm weights applied before the feed-forward net (expected length `dim`).
    pub ffn_norm: DVector<f32>,
    /// Feed-forward up projection, `hidden_dim × dim`.
    pub ffn_up: DMatrix<f32>,
    /// Attention output projection, `dim × dim`.
    pub attn_output: DMatrix<f32>,
}
110
/// All the dequantized (`f32`) weights of a llama2 model.
pub struct TransformerWeights {
    /// Per-layer weights, `n_layers` entries.
    pub layers: Vec<TransformerLayerWeights>,
    /// Token embedding matrix, `dim × vocab_size`: one column per token.
    pub token_embd: DMatrix<f32>,
    /// Output classifier, `vocab_size × dim`. When the model shares weights,
    /// this is the transpose of `token_embd` (see `from_gguf`).
    pub output: DMatrix<f32>,
    /// RMS-norm weights applied before the classifier (expected length `dim`).
    pub output_norm: DVector<f32>,
}
117
118impl TransformerWeights {
119    pub fn from_gguf(config: &Llama2Config, gguf: &Gguf) -> Self {
120        let head_size = config.dim / config.n_q_heads;
121        let num_kv_heads_times_head_size = config.n_kv_heads * head_size;
122
123        let mut layers = vec![];
124
125        for i_layer in 0..config.n_layers {
126            let attn_q = format!("blk.{}.attn_q.weight", i_layer);
127            let attn_k = format!("blk.{}.attn_k.weight", i_layer);
128            let attn_v = format!("blk.{}.attn_v.weight", i_layer);
129            let attn_output = format!("blk.{}.attn_output.weight", i_layer);
130            let ffn_down = format!("blk.{}.ffn_down.weight", i_layer);
131            let ffn_gate = format!("blk.{}.ffn_gate.weight", i_layer);
132            let ffn_up = format!("blk.{}.ffn_up.weight", i_layer);
133            let ffn_norm = format!("blk.{}.ffn_norm.weight", i_layer);
134            let attn_norm = format!("blk.{}.attn_norm.weight", i_layer);
135
136            let attn_q = &gguf.tensors[&attn_q].data().dequantize().unwrap();
137            let attn_k = &gguf.tensors[&attn_k].data().dequantize().unwrap();
138            let attn_v = &gguf.tensors[&attn_v].data().dequantize().unwrap();
139            let attn_output = &gguf.tensors[&attn_output].data().dequantize().unwrap();
140            let ffn_down = &gguf.tensors[&ffn_down].data().dequantize().unwrap();
141            let ffn_gate = &gguf.tensors[&ffn_gate].data().dequantize().unwrap();
142            let ffn_up = &gguf.tensors[&ffn_up].data().dequantize().unwrap();
143            let ffn_norm = gguf.tensors[&ffn_norm].data().as_f32().unwrap();
144            let attn_norm = gguf.tensors[&attn_norm].data().as_f32().unwrap();
145
146            let ffn_norm = DVector::from_row_slice(ffn_norm);
147            let attn_norm = DVector::from_row_slice(attn_norm);
148
149            let attn_q = DMatrix::from_row_slice(config.dim, config.dim, attn_q);
150            let attn_k = DMatrix::from_row_slice(num_kv_heads_times_head_size, config.dim, attn_k);
151            let attn_v = DMatrix::from_row_slice(num_kv_heads_times_head_size, config.dim, attn_v);
152            let attn_output = DMatrix::from_row_slice(config.dim, config.dim, attn_output);
153            let ffn_down = DMatrix::from_row_slice(config.dim, config.hidden_dim, ffn_down);
154            let ffn_gate = DMatrix::from_row_slice(config.hidden_dim, config.dim, ffn_gate);
155            let ffn_up = DMatrix::from_row_slice(config.hidden_dim, config.dim, ffn_up);
156
157            layers.push(TransformerLayerWeights {
158                attn_q,
159                attn_k,
160                attn_v,
161                attn_output,
162                ffn_down,
163                ffn_gate,
164                ffn_up,
165                ffn_norm,
166                attn_norm,
167            });
168        }
169
170        let token_embd = "token_embd.weight";
171        let output = "output.weight";
172        let output_norm = "output_norm.weight";
173
174        let token_embd = &gguf.tensors[token_embd].data().dequantize().unwrap();
175        let output = gguf
176            .tensors
177            .get(output)
178            .map(|v| v.data().dequantize().unwrap());
179        let output_norm = gguf.tensors[output_norm].data().as_f32().unwrap();
180
181        let token_embd = DMatrix::from_column_slice(config.dim, config.vocab_size, token_embd);
182        let output = output
183            .map(|data| DMatrix::from_row_slice(config.vocab_size, config.dim, &data))
184            .unwrap_or_else(|| token_embd.transpose());
185        let output_norm = DVector::from_row_slice(output_norm);
186
187        Self {
188            layers,
189            token_embd,
190            output,
191            output_norm,
192        }
193    }
194}
195
/// Scratch buffers holding the current wave of activations during a forward pass.
struct RunState {
    // Current wave of activations.
    /// Activation at current time stamp (`dim` elements).
    x: OVector<f32, Dim>,
    /// Activation at current time stamp, inside a residual branch (`dim` elements).
    xb: OVector<f32, Dim>,
    /// Additional buffer for convenience (`dim` elements).
    xb2: OVector<f32, Dim>,
    /// Buffer for hidden dimension in the Feed-Forward net (`hidden_dim` elements).
    hb: OVector<f32, HiddenDim>,
    /// Another buffer for hidden dimension in the Feed-Forward net (`hidden_dim` elements).
    hb2: OVector<f32, HiddenDim>,
    /// Query (`dim` elements).
    q: OVector<f32, Dim>,
    /// Scores/attention values, `seq_len × n_q_heads` (one column per query head).
    att: OMatrix<f32, SeqLen, NumHeads>,
    /// Output logits.
    /// NOTE(review): despite the `SeqLen` alias, this is allocated with
    /// `vocab_size` elements in `RunState::new` — the alias is misleading here.
    logits: OVector<f32, SeqLen>,
    // KV cache. Each Vec contains `n_layers` matrices of shape `kv_dim × seq_len`,
    // where `kv_dim = dim * n_kv_heads / n_q_heads` (see `RunState::new`) — i.e.
    // the row count is smaller than `dim` when there are fewer kv heads than
    // query heads, despite the `Dim` alias used below.
    key_cache: Vec<OMatrix<f32, Dim, SeqLen>>,
    value_cache: Vec<OMatrix<f32, Dim, SeqLen>>,
}
218
/// A llama2 transformer running entirely on the CPU.
pub struct Transformer {
    /// The hyperparameters of the architecture (the blueprint).
    config: Llama2Config,
    /// The weights of the model.
    weights: TransformerWeights,
    /// Buffer of the "wave" of activations in the forward pass.
    state: RunState,
}
227
228impl Transformer {
229    pub fn new(config: Llama2Config, weights: TransformerWeights) -> Self {
230        Self {
231            state: RunState::new(&config),
232            config,
233            weights,
234        }
235    }
236
237    pub fn logits_mut(&mut self) -> &mut OVector<f32, SeqLen> {
238        &mut self.state.logits
239    }
240}
241
242impl RunState {
243    pub fn new(config: &Llama2Config) -> Self {
244        let kv_dim = (config.dim * config.n_kv_heads) / config.n_q_heads;
245        Self {
246            x: DVector::zeros(config.dim),
247            xb: DVector::zeros(config.dim),
248            xb2: DVector::zeros(config.dim),
249            hb: DVector::zeros(config.hidden_dim),
250            hb2: DVector::zeros(config.hidden_dim),
251            q: DVector::zeros(config.dim),
252            // TODO: for these two, the `kv_dim` doesn’t match the dimension in the field’s comment.
253            key_cache: (0..config.n_layers)
254                .map(|_| DMatrix::zeros(kv_dim, config.seq_len))
255                .collect(),
256            value_cache: (0..config.n_layers)
257                .map(|_| DMatrix::zeros(kv_dim, config.seq_len))
258                .collect(),
259            att: DMatrix::zeros(config.seq_len, config.n_q_heads),
260            logits: DVector::zeros(config.vocab_size),
261        }
262    }
263}
264
265/*
266 *
267 *
268 * Neural net blocks. The dynamics of the Transformer.
269 *
270 *
271 */
272/// Implementation of the Root Mean Square Normalization.
273///
274/// This implementation of the RMS normalization from the "Root Mean Square
275/// Normalization" paper by Zhang & Sennrich.
276fn rms_norm<SW: Storage<f32, Dyn>>(
277    out: &mut DVector<f32>,
278    a: &DVector<f32>,
279    w: &Vector<f32, Dyn, SW>,
280) {
281    const NUDGE_FACTOR: f32 = 1.0e-5;
282    let rms = 1.0 / (a.norm_squared() / (a.nrows() as f32) + NUDGE_FACTOR).sqrt();
283    out.zip_zip_apply(a, w, |o, a, w| *o = (a * rms) * w);
284}
285
286/// The softmax function.
287///
288/// Converts a set of real number into a probability distribution.
289/// See <https://fr.wikipedia.org/wiki/Fonction_softmax>
290pub fn softmax<S: StorageMut<f32, Dyn>>(vals: &mut Vector<f32, Dyn, S>) {
291    // Note that llama2.c also introduces a bias based on the max value
292    // to improve numerical stability. So it is effectively computing:
293    // softmax(z) = (e^z - max) / (e^z - max).sum()
294    let max_val = vals.max();
295    let mut sum = 0.0;
296
297    vals.apply(|x| {
298        *x = (*x - max_val).exp();
299        sum += *x;
300    });
301
302    *vals /= sum;
303}
304
/// Dense matrix-vector product: `out = w * x`.
///
/// Most expensive part of the inference.
// TODO llama2.c also takes the dimensions n and do, but it’s unclear if this isn’t just
//      because the dimensions are not part of the float* input type.
fn matmul<SOut: StorageMut<f32, Dyn>>(
    out: &mut Vector<f32, Dyn, SOut>,
    x: &DVector<f32>,
    w: &DMatrix<f32>,
) {
    // TODO: parallelize per column? llama2.c parallelizes with openmp.
    // TODO: use blas/faer?
    // gemv with alpha = 1, beta = 0 overwrites `out` with `w * x`.
    out.gemv(1.0, w, x, 0.0);
}
317
impl Transformer {
    /// Runs one full forward pass of the transformer for `token` at sequence position `pos`.
    ///
    /// The resulting next-token logits are written into the internal run state
    /// (retrieve them through [`Transformer::logits_mut`]); nothing is returned.
    /// `pos` indexes into the KV cache, so it must be `< config.seq_len`.
    pub fn forward(&mut self, token: usize, pos: usize) {
        // A few convenience variables.
        let config = &self.config;
        let w = &self.weights;
        let s = &mut self.state;
        let dim = config.dim;
        // This is the number of key/value heads multiplied by the size of a query head: NumKvHeadsTimesHeadSize
        let kv_dim = (config.dim * config.n_kv_heads) / config.n_q_heads;
        // The number of embedding vector elements associated to each query head.
        let head_size = dim / config.n_q_heads;

        // Copy the token embedding into x.
        // TODO: rename `x` to `token_embedding`?
        s.x.copy_from(&w.token_embd.column(token));

        // Forward all the layers.
        for l in 0..config.n_layers {
            let wl = &w.layers[l];

            // RMS norm before attention.
            // See https://youtu.be/Mn_9W1nCFLo?si=Ogz_O_6LUsumWovB&t=1367
            // TODO: rename `xb` to `normalized_token_embedding`?
            rms_norm(&mut s.xb, &s.x, &wl.attn_norm);

            // Key and value point into the KV cache (the column for position `pos`).
            let mut k_cache = s.key_cache[l].column_mut(pos);
            let mut v_cache = s.value_cache[l].column_mut(pos);

            // qkv matmuls for this position.
            // This is self-attention, so `xb` is used for query, key, and value.
            // These are essentially one row of Q’, K’, V’ from https://youtu.be/Mn_9W1nCFLo?si=7B_g41B2iGZ5238a&t=2422
            // Note that despite keys/values having a different number of heads than the queries,
            // the dimension of each k/v head is the same as for the query heads. The dimension
            // change happens through the multiplication by the weight matrices wk/wv.
            matmul(&mut s.q, &s.xb, &wl.attn_q);
            matmul(&mut k_cache, &s.xb, &wl.attn_k);
            matmul(&mut v_cache, &s.xb, &wl.attn_v);

            // Rotary Positional Encoding (RoPE): rotates the query and the freshly cached key.
            Self::rotary_positional_encoding(&mut s.q, &mut k_cache, head_size, dim, kv_dim, pos);

            // Batched multi-query attention. Writes the attention output into `xb2`.
            Self::attention(config, s, w, pos, l);

            // Residual connection back into x.
            // See the LLama graph on the right: https://youtu.be/Mn_9W1nCFLo?si=XMDdHlXxON2QhFCd&t=320
            // This step is the first big circled +
            s.x += &s.xb2;

            // RMSnorm before feed-forward.
            // /!\ xb changes semantic again. It now contains the normalized {attention output+input}.
            rms_norm(&mut s.xb, &s.x, &wl.ffn_norm);

            // Feed-forward. Writes its result into `xb2`.
            Self::ffn_silu(s, wl);

            // Residual connection.
            s.x += &s.xb2;
            // Loop on the next layer. This layer’s output is the next layer’s input.
        }

        // Final rmsnorm.
        // This is the top-most rmsnorm from https://youtu.be/Mn_9W1nCFLo?si=KO-aBXZo0DqCL4Qs&t=275
        // (diagram on the right).
        rms_norm(&mut s.xb, &s.x, &w.output_norm);

        // Classifier into logits.
        // This is the final "Linear" part from https://youtu.be/Mn_9W1nCFLo?si=-GT74rBY6j5TbbBO&t=275
        matmul(&mut s.logits, &s.xb, &w.output);
    }

    /// Rotary Positional Encoding (RoPE): complex-valued rotation of `q` and `k` in each head.
    ///
    /// Each consecutive pair of vector elements is treated as a complex number and rotated by an
    /// angle depending on the position `pos` and on the pair's index within its head. Only the
    /// first `kv_dim` elements of `k` are rotated (see the comment inside the loop).
    pub fn rotary_positional_encoding(
        q: &mut DVector<f32>,
        k: &mut DVectorViewMut<f32>,
        head_size: usize,
        dim: usize,
        kv_dim: usize,
        pos: usize,
    ) {
        for i in (0..dim).step_by(2) {
            // For RoPE, we have one rotation matrix like https://youtu.be/Mn_9W1nCFLo?si=GLIXuFLGVG8q6v2u&t=1963
            // for each head. So we need to transform `i` into the corresponding index within
            // the head.
            let head_dim = (i % head_size) as f32;
            // Note that the formula from the video linked above would be:
            //     10000.0.powf(-2.0 * ((i / 2) as f32 - 1.0) / dim as f32)
            // Although in the paper shown in the video, their index is 1-based, which is why they
            // have to subtract 1.0 whereas we don’t need to. The `i / 2` and multiplication by 2.0
            // are both accounted for by stepping only on even values for `i`.
            // Therefore, the formula below is equivalent to the RoPE paper’s formula.
            let theta = 10000.0_f32.powf(-head_dim / head_size as f32);
            let m_theta = pos as f32 * theta;
            let rot = Rotation2::new(m_theta);

            let qi = vector![q[i], q[i + 1]];
            let mut out_q = q.fixed_rows_mut::<2>(i);
            out_q.copy_from(&(rot * qi));

            // When i >= kv_dim, we are done rotating all the elements from the keys. That’s
            // because there are fewer key heads than query heads, but each key head sub-vector has
            // the same dimension as the query head (they lose dimensions when multiplied with the
            // key weight matrices).
            if i < kv_dim {
                let ki = vector![k[i], k[i + 1]];
                let mut out_k = k.fixed_rows_mut::<2>(i);
                out_k.copy_from(&(rot * ki));
            }
        }
    }

    /// Batched multi-query attention for layer `l` at position `pos`.
    ///
    /// Reads the query from `s.q` and the keys/values from the layer's KV cache, then stores
    /// the attention result (after the output projection) into `s.xb2`.
    fn attention(
        config: &Llama2Config,
        s: &mut RunState,
        w: &TransformerWeights,
        pos: usize,
        l: usize,
    ) {
        // The number of embedding vector elements associated to each query head.
        let head_size = config.dim / config.n_q_heads;
        // The number of query heads associated to one key/value head.
        let kv_mul = config.n_q_heads / config.n_kv_heads;

        // Multihead attention. Iterate over all heads.
        // TODO: in llama2.c, each head is iterated on in parallel.
        for h in 0..config.n_q_heads {
            // Get the query vector for this head.
            let q = s.q.rows(h * head_size, head_size);
            // Attention scores for this head.
            let mut att = s.att.column_mut(h);

            // Iterate over all timesteps (tokens in the sequence), including the current one, but
            // not past the current one due to causality.
            // See the KV cache explanation there: https://youtu.be/Mn_9W1nCFLo?si=3n4GH9f2OzMb5Np0&t=2940
            // -> This is iterating through all the green columns (from K^t) that were rotated
            //    (by RoPE). The values set in this loop into the `att` variable here (attention
            //    scores) are the elements in the pink row (at the bottom of the QK^t matrix) divided
            //    by sqrt(head_size) (in other words, this is what’s given to softmax afterward).
            for t in 0..=pos {
                // Get the key vector for this head and at this timestep.
                let k = s.key_cache[l].column(t);
                let k_head = k.rows((h / kv_mul) * head_size, head_size);

                // Calculate the attention score as the dot product of q and k.
                let mut score = q.dot(&k_head);
                score /= (head_size as f32).sqrt();
                // Save the score to the attention buffer.
                att[t] = score;
            }

            // Softmax the scores to get attention weights from 0..=pos inclusively.
            softmax(&mut att.rows_mut(0, pos + 1));

            // Weighted sum of the values, store back into xb.
            // /!\ xb is now changing semantic, storing the weighted sums for all the heads.
            //       Now xb contains the "Attention 4" row from https://youtu.be/Mn_9W1nCFLo?si=550ar5aUg1I1k60l&t=2940.
            let mut xb = s.xb.rows_mut(h * head_size, head_size);
            xb.fill(0.0);
            for t in 0..=pos {
                let v = s.value_cache[l].column(t);
                let v_head = v.rows((h / kv_mul) * head_size, head_size);
                xb.axpy(att[t], &v_head, 1.0);
            }
        }

        // Final matmul to get the output of the attention.
        // TODO: rename xb2 to `attention_output`?
        matmul(&mut s.xb2, &s.xb, &w.layers[l].attn_output);
    }

    /// SwiGLU feed-forward net: `xb2 = ffn_down(silu(ffn_gate(xb)) * ffn_up(xb))`.
    ///
    /// Reads its input from `s.xb` and writes the result into `s.xb2`, using `s.hb` and `s.hb2`
    /// as hidden-dimension scratch buffers.
    fn ffn_silu(s: &mut RunState, wl: &TransformerLayerWeights) {
        // We have: self.w2(F.silu(self.w1(x)) * self.w3(x)). First calculate self.w1(x) and
        // self.w3(x).
        //
        // For this part, see https://youtu.be/Mn_9W1nCFLo?si=Ub9m1NeAzkmn-G8G&t=3973
        // We have: w1 := W, w3 := V, w2 := W2
        s.hb.gemv(1.0, &wl.ffn_gate, &s.xb, 0.0);
        s.hb2.gemv(1.0, &wl.ffn_up, &s.xb, 0.0);

        // SwiGLU non-linearity.
        fn swish(x: f32, beta: f32) -> f32 {
            // This is the swish function from https://youtu.be/Mn_9W1nCFLo?si=LT6puSAfzgpP6ydz&t=3973
            // With beta = 1 this is the SiLU activation used by llama2.
            x / (1.0 + (-beta * x).exp())
        }

        s.hb.zip_apply(&s.hb2, |h, h2| *h = h2 * swish(*h, 1.0));

        // Final matmul to get the output of the feed-forward net.
        matmul(&mut s.xb2, &s.hb, &wl.ffn_down);
    }
}