wgml/models/gpt2/cpu.rs

1// Gpt-2 transformer, ported from ggml/examples/gpt-2/main-backend.cpp
2
3use crate::gguf::Gguf;
4use crate::ops::{
5    BatchedMultiqueryAttention, BatchedMultiqueryAttentionParams, LayerNorm, UnaryOp,
6};
7use nalgebra::{DMatrix, DVector, Dyn, OMatrix, OVector, Vector4};
8
9pub struct Transformer;
10
11type VocabSize = Dyn;
12type SeqLen = Dyn;
13type NumHeads = Dyn;
14type Dim = Dyn;
15type Attn = Dyn; // 2304
16type HiddenDim = Dyn; // 3072?
17
/// Hyperparameters of a GPT-2 model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Gpt2Params {
    /// Size of the vocabulary.
    pub n_vocab: usize,
    /// Max sequence length.
    pub n_seq: usize,
    /// Token embedding length.
    pub n_embd: usize,
    /// Number of heads.
    pub n_head: usize,
    /// Number of layers.
    pub n_layer: usize,
    /// Feed-forward (hidden) length.
    pub ff_len: usize,
    /// Length of the fused query/key/value projection output (and of its bias).
    ///
    /// The forward pass splits it into query, key and value slices of `n_embd` elements each.
    pub attn_b: usize,
    /// Value of the GGUF `general.file_type` metadata entry.
    ///
    /// Not read by the CPU forward pass.
    pub ftype: usize,
}
34
35impl Gpt2Params {
36    pub fn from_gguf(gguf: &Gguf) -> Self {
37        Self {
38            n_vocab: gguf.metadata["tokenizer.ggml.tokens"].unwrap_array_len(),
39            n_seq: gguf.metadata["gpt2.context_length"].unwrap_u32() as usize,
40            n_embd: gguf.metadata["gpt2.embedding_length"].unwrap_u32() as usize,
41            n_head: gguf.metadata["gpt2.attention.head_count"].unwrap_u32() as usize,
42            n_layer: gguf.metadata["gpt2.block_count"].unwrap_u32() as usize,
43            ftype: gguf.metadata["general.file_type"].unwrap_u32() as usize,
44            ff_len: gguf.metadata["gpt2.feed_forward_length"].unwrap_u32() as usize,
45            attn_b: gguf.tensors["blk.0.attn_qkv.bias"].dimensions()[0] as usize,
46        }
47    }
48}
49
50impl Default for Gpt2Params {
51    fn default() -> Self {
52        // Default params for GPT-2 117M
53        Self {
54            n_vocab: 50257,
55            n_seq: 1024,
56            n_embd: 768,
57            n_head: 12,
58            n_layer: 12,
59            attn_b: 2304,
60            ftype: 1,
61            ff_len: 3072,
62        }
63    }
64}
65
66pub struct Gpt2Layer {
67    // Normalization.
68    pub(crate) ln_1_g: OVector<f32, Dim>,
69    pub(crate) ln_1_b: OVector<f32, Dim>,
70    pub(crate) ln_2_g: OVector<f32, Dim>,
71    pub(crate) ln_2_b: OVector<f32, Dim>,
72
73    // attention
74    pub(crate) c_attn_attn_w: OMatrix<f32, Attn, Dim>,
75    pub(crate) c_attn_attn_b: OVector<f32, Attn>,
76    pub(crate) c_attn_proj_w: OMatrix<f32, Dim, Attn>,
77    pub(crate) c_attn_proj_b: OVector<f32, Dim>,
78
79    // KV cache
80    pub(crate) key_cache: OMatrix<f32, Dim, SeqLen>,
81    pub(crate) value_cache: OMatrix<f32, Dim, SeqLen>,
82
83    // mlp
84    pub(crate) c_mlp_fc_w: OMatrix<f32, HiddenDim, Dim>,
85    pub(crate) c_mlp_fc_b: OVector<f32, HiddenDim>,
86    pub(crate) c_mlp_proj_w: OMatrix<f32, Dim, HiddenDim>,
87    pub(crate) c_mlp_proj_b: OVector<f32, Dim>,
88}
89
90pub struct Gpt2Model {
91    // Normalization
92    pub(crate) ln_f_g: OVector<f32, Dim>,
93    pub(crate) ln_f_b: OVector<f32, Dim>,
94
95    pub(crate) wte: OMatrix<f32, Dim, VocabSize>, // token embedding
96    pub(crate) wpe: OMatrix<f32, Dim, SeqLen>,    // position embedding
97    pub(crate) lm_head: OMatrix<f32, VocabSize, Dim>, // language model head
98
99    pub(crate) layers: Vec<Gpt2Layer>,
100
101    // scratch memory
102    memory_q: OVector<f32, Dim>,
103    memory_att: OMatrix<f32, SeqLen, NumHeads>,
104    layer_input: DVector<f32>,
105    curr_768: DVector<f32>,
106    curr_768_b: DVector<f32>,
107    curr_2304: DVector<f32>,
108    curr_3072: DVector<f32>,
109    curr_vocab: DVector<f32>,
110}
111
112impl Gpt2Model {
113    pub fn from_gguf(gguf: &Gguf) -> (Self, Gpt2Params) {
114        let params = Gpt2Params::from_gguf(gguf);
115        let mut layers = vec![];
116
117        for i_layer in 0..params.n_layer {
118            let ln_1_g = format!("blk.{}.attn_norm.weight", i_layer);
119            let ln_1_b = format!("blk.{}.attn_norm.bias", i_layer);
120            let ln_2_g = format!("blk.{}.ffn_norm.weight", i_layer);
121            let ln_2_b = format!("blk.{}.ffn_norm.bias", i_layer);
122            let c_attn_attn_w = format!("blk.{}.attn_qkv.weight", i_layer);
123            let c_attn_attn_b = format!("blk.{}.attn_qkv.bias", i_layer);
124            let c_attn_proj_w = format!("blk.{}.attn_output.weight", i_layer);
125            let c_attn_proj_b = format!("blk.{}.attn_output.bias", i_layer);
126
127            let c_mlp_fc_w = format!("blk.{}.ffn_up.weight", i_layer);
128            let c_mlp_fc_b = format!("blk.{}.ffn_up.bias", i_layer);
129            let c_mlp_proj_w = format!("blk.{}.ffn_down.weight", i_layer);
130            let c_mlp_proj_b = format!("blk.{}.ffn_down.bias", i_layer);
131
132            let ln_1_g = gguf.tensors[&ln_1_g].data().as_f32().unwrap();
133            let ln_1_b = gguf.tensors[&ln_1_b].data().as_f32().unwrap();
134            let ln_2_g = gguf.tensors[&ln_2_g].data().as_f32().unwrap();
135            let ln_2_b = gguf.tensors[&ln_2_b].data().as_f32().unwrap();
136            let c_attn_attn_w = &gguf.tensors[&c_attn_attn_w].data().dequantize().unwrap();
137            let c_attn_attn_b = gguf.tensors[&c_attn_attn_b].data().as_f32().unwrap();
138            let c_attn_proj_w = &gguf.tensors[&c_attn_proj_w].data().dequantize().unwrap();
139            let c_attn_proj_b = gguf.tensors[&c_attn_proj_b].data().as_f32().unwrap();
140            let c_mlp_fc_w = &gguf.tensors[&c_mlp_fc_w].data().dequantize().unwrap();
141            let c_mlp_fc_b = gguf.tensors[&c_mlp_fc_b].data().as_f32().unwrap();
142            let c_mlp_proj_w = &gguf.tensors[&c_mlp_proj_w].data().dequantize().unwrap();
143            let c_mlp_proj_b = gguf.tensors[&c_mlp_proj_b].data().as_f32().unwrap();
144
145            let ln_1_g = DVector::from_row_slice(ln_1_g);
146            let ln_1_b = DVector::from_row_slice(ln_1_b);
147            let ln_2_g = DVector::from_row_slice(ln_2_g);
148            let ln_2_b = DVector::from_row_slice(ln_2_b);
149
150            let c_attn_attn_w =
151                DMatrix::from_row_slice(params.attn_b, params.n_embd, c_attn_attn_w);
152            let c_attn_attn_b = DVector::from_row_slice(c_attn_attn_b);
153            let c_attn_proj_w =
154                DMatrix::from_row_slice(params.n_embd, params.n_embd, c_attn_proj_w);
155            let c_attn_proj_b = DVector::from_row_slice(c_attn_proj_b);
156            let c_mlp_fc_w = DMatrix::from_row_slice(params.ff_len, params.n_embd, c_mlp_fc_w);
157            let c_mlp_fc_b = DVector::from_row_slice(c_mlp_fc_b);
158            let c_mlp_proj_w = DMatrix::from_row_slice(params.n_embd, params.ff_len, c_mlp_proj_w);
159            let c_mlp_proj_b = DVector::from_row_slice(c_mlp_proj_b);
160
161            let layer = Gpt2Layer {
162                ln_1_g,
163                ln_1_b,
164                ln_2_g,
165                ln_2_b,
166                c_attn_attn_w,
167                c_attn_attn_b,
168                c_attn_proj_w,
169                c_attn_proj_b,
170                c_mlp_fc_w,
171                c_mlp_fc_b,
172                c_mlp_proj_w,
173                c_mlp_proj_b,
174                key_cache: DMatrix::zeros(params.n_embd, params.n_seq),
175                value_cache: DMatrix::zeros(params.n_embd, params.n_seq),
176            };
177            layers.push(layer);
178        }
179
180        let ln_f_g = gguf.tensors["output_norm.weight"].data().as_f32().unwrap();
181        let ln_f_b = gguf.tensors["output_norm.bias"].data().as_f32().unwrap();
182        let wte = gguf.tensors["token_embd.weight"]
183            .data()
184            .dequantize()
185            .unwrap();
186        let wpe = &gguf.tensors["position_embd.weight"]
187            .data()
188            .dequantize()
189            .unwrap();
190
191        let ln_f_g = DVector::from_row_slice(ln_f_g);
192        let ln_f_b = DVector::from_row_slice(ln_f_b);
193        let wte = DMatrix::from_column_slice(params.n_embd, params.n_vocab, &wte);
194        let wpe = DMatrix::from_column_slice(params.n_embd, params.n_seq, wpe);
195        // NOTE: GPT2 shares the lm_head tensor with wte.
196        let lm_head = wte.transpose();
197
198        let model = Self {
199            ln_f_b,
200            ln_f_g,
201            wte,
202            wpe,
203            layers,
204            lm_head,
205            memory_q: DVector::zeros(params.n_embd),
206            memory_att: DMatrix::zeros(params.n_seq, params.n_head),
207            layer_input: DVector::zeros(params.n_embd),
208            curr_768: DVector::zeros(params.n_embd),
209            curr_768_b: DVector::zeros(params.n_embd),
210            curr_2304: DVector::zeros(params.attn_b),
211            curr_3072: DVector::zeros(params.ff_len),
212            curr_vocab: DVector::zeros(params.n_vocab),
213        };
214
215        (model, params)
216    }
217
218    pub fn logits_mut(&mut self) -> &mut DVector<f32> {
219        &mut self.curr_vocab
220    }
221}
222
223impl Transformer {
224    pub fn forward(params: &Gpt2Params, model: &mut Gpt2Model, embd: usize, pos: usize) {
225        // Positional encoding.
226        model.layer_input.copy_from(&model.wte.column(embd));
227        model.layer_input += &model.wpe.column(pos);
228
229        for layer in model.layers.iter_mut() {
230            // Layer norm.
231            {
232                // NOTE: in this implementation, we always have N = 1
233                // [ 768, N]
234                LayerNorm::run_cpu(&mut model.curr_768, &model.layer_input);
235
236                // cur = ln_1_g*cur + ln_1_b
237                // [ 768, N]
238                model.curr_768.component_mul_assign(&layer.ln_1_g);
239                model.curr_768 += &layer.ln_1_b;
240            }
241
242            // attn
243            // [2304, 768] - model.layers[il].c_attn_attn_w
244            // [2304,   1] - model.layers[il].c_attn_attn_b
245            // [ 768,   N] - cur (in)
246            // [2304,   N] - cur (out)
247            //
248            // cur = attn_w*cur + attn_b
249            // [2304, N]
250            {
251                model
252                    .curr_2304
253                    .gemv(1.0, &layer.c_attn_attn_w, &model.curr_768, 0.0);
254                model.curr_2304 += &layer.c_attn_attn_b;
255            }
256
257            // self-attention
258            // TODO: refactor this so that both llama2 and gpt2 share the attn code with KV cache.
259            // TODO: implement flash attention.
260            {
261                // [2304, 1]
262                let mut k_cache = layer.key_cache.column_mut(pos);
263                let mut v_cache = layer.value_cache.column_mut(pos);
264
265                model
266                    .memory_q
267                    .copy_from(&model.curr_2304.rows(0, params.n_embd));
268                k_cache.copy_from(&model.curr_2304.rows(params.n_embd, params.n_embd));
269                v_cache.copy_from(&model.curr_2304.rows(2 * params.n_embd, params.n_embd));
270
271                let head_size = params.n_embd / params.n_head;
272                let attn_params = BatchedMultiqueryAttentionParams {
273                    seq_len: params.n_seq as u32,
274                    kv_dim: params.n_embd as u32,
275                    kv_mul: 1,
276                    n_heads: params.n_head as u32,
277                    head_size: head_size as u32,
278                    pos: pos as u32,
279                };
280
281                BatchedMultiqueryAttention::run_cpu(
282                    &attn_params,
283                    &model.memory_q,
284                    &layer.key_cache,
285                    &layer.value_cache,
286                    &mut model.memory_att,
287                    &mut model.curr_768,
288                );
289            }
290
291            // projection
292            // [ 768, 768] - model.layers[il].c_attn_proj_w
293            // [ 768,   1] - model.layers[il].c_attn_proj_b
294            // [ 768,   N] - cur (in)
295            // [ 768,   N] - cur (out)
296            //
297            // cur = proj_w*cur + proj_b
298            // [768, N]
299            {
300                model
301                    .curr_768_b
302                    .gemv(1.0, &layer.c_attn_proj_w, &model.curr_768, 0.0);
303                model.curr_768_b += &layer.c_attn_proj_b;
304            }
305
306            // add the input
307            model.curr_768_b += &model.layer_input;
308
309            // prep input for next layer
310            model.layer_input.copy_from(&model.curr_768_b);
311
312            // feed-forward network
313            {
314                // norm
315                {
316                    LayerNorm::run_cpu(&mut model.curr_768, &model.curr_768_b);
317
318                    // cur = ln_2_g*cur + ln_2_b
319                    // [ 768, N]
320                    model.curr_768.component_mul_assign(&layer.ln_2_g);
321                    model.curr_768 += &layer.ln_2_b;
322                }
323
324                // fully connected
325                // [3072, 768] - model.layers[il].c_mlp_fc_w
326                // [3072,   1] - model.layers[il].c_mlp_fc_b
327                // [ 768,   N] - cur (in)
328                // [3072,   N] - cur (out)
329                //
330                // cur = fc_w*cur + fc_b
331                // [3072, N]
332                model
333                    .curr_3072
334                    .gemv(1.0, &layer.c_mlp_fc_w, &model.curr_768, 0.0);
335                model.curr_3072 += &layer.c_mlp_fc_b;
336
337                // GELU activation
338                // [3072, N]
339                model
340                    .curr_3072
341                    .apply(|x| *x = UnaryOp::Gelu.eval(*x, Vector4::zeros()));
342
343                // projection
344                // [ 768, 3072] - model.layers[il].c_mlp_proj_w
345                // [ 768,    1] - model.layers[il].c_mlp_proj_b
346                // [3072,    N] - cur (in)
347                // [ 768,    N] - cur (out)
348                //
349                // cur = proj_w*cur + proj_b
350                // [768, N]
351                model
352                    .curr_768
353                    .gemv(1.0, &layer.c_mlp_proj_w, &model.curr_3072, 0.0);
354                model.curr_768 += &layer.c_mlp_proj_b;
355            }
356
357            // finalize input for next layer
358            model.layer_input += &model.curr_768;
359        }
360
361        // norm
362        {
363            // [ 768, N]
364            LayerNorm::run_cpu(&mut model.curr_768, &model.layer_input);
365
366            // inpL = ln_f_g*inpL + ln_f_b
367            // [ 768, N]
368            model.curr_768.component_mul_assign(&model.ln_f_g);
369            model.curr_768 += &model.ln_f_b;
370        }
371
372        // inpL = WTE * inpL
373        // [ 768, 50257] - model.lm_head
374        // [ 768, N]     - inpL
375        model
376            .curr_vocab
377            .gemv(1.0, &model.lm_head, &model.curr_768, 0.0);
378
379        // // logits -> probs
380        // SoftMax::run_cpu(&mut curr2); // NOTE: done by the sampler
381    }
382}