use crate::gguf::Gguf;
use crate::ops::{
    BatchedMultiqueryAttention, BatchedMultiqueryAttentionParams, LayerNorm, UnaryOp,
};
use nalgebra::{DMatrix, DVector, Dyn, OMatrix, OVector, Vector4};

pub struct Transformer;

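// Aliases for nalgebra's dynamic dimension `Dyn`; they only document which
// logical axis (vocabulary, sequence, heads, ...) a matrix dimension represents,
// since every size here is chosen at runtime from the GGUF header.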
type VocabSize = Dyn;
type SeqLen = Dyn;
type NumHeads = Dyn;
type Dim = Dyn;
type Attn = Dyn;
type HiddenDim = Dyn;

pub struct Gpt2Params {
    pub n_vocab: usize, // vocabulary size
    pub n_seq: usize,   // maximum context length
    pub n_embd: usize,  // embedding (model) width
    pub n_head: usize,  // number of attention heads
    pub n_layer: usize, // number of transformer blocks
    pub ff_len: usize,  // feed-forward hidden width
    pub attn_b: usize,  // fused QKV bias length (3 * n_embd)
    pub ftype: usize,   // GGUF file type (quantization scheme)
}

impl Gpt2Params {
    pub fn from_gguf(gguf: &Gguf) -> Self {
        Self {
            n_vocab: gguf.metadata["tokenizer.ggml.tokens"].unwrap_array_len(),
            n_seq: gguf.metadata["gpt2.context_length"].unwrap_u32() as usize,
            n_embd: gguf.metadata["gpt2.embedding_length"].unwrap_u32() as usize,
            n_head: gguf.metadata["gpt2.attention.head_count"].unwrap_u32() as usize,
            n_layer: gguf.metadata["gpt2.block_count"].unwrap_u32() as usize,
            ftype: gguf.metadata["general.file_type"].unwrap_u32() as usize,
            ff_len: gguf.metadata["gpt2.feed_forward_length"].unwrap_u32() as usize,
            attn_b: gguf.tensors["blk.0.attn_qkv.bias"].dimensions()[0] as usize,
        }
    }
}

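// The defaults correspond to GPT-2 small (124M parameters).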
impl Default for Gpt2Params {
    fn default() -> Self {
        Self {
            n_vocab: 50257,
            n_seq: 1024,
            n_embd: 768,
            n_head: 12,
            n_layer: 12,
            attn_b: 2304,
            ftype: 1,
            ff_len: 3072,
        }
    }
}

pub struct Gpt2Layer {
    // layer norms (scale g, shift b)
    pub(crate) ln_1_g: OVector<f32, Dim>,
    pub(crate) ln_1_b: OVector<f32, Dim>,
    pub(crate) ln_2_g: OVector<f32, Dim>,
    pub(crate) ln_2_b: OVector<f32, Dim>,

    // attention: fused QKV projection and output projection
    pub(crate) c_attn_attn_w: OMatrix<f32, Attn, Dim>,
    pub(crate) c_attn_attn_b: OVector<f32, Attn>,
    pub(crate) c_attn_proj_w: OMatrix<f32, Dim, Attn>,
    pub(crate) c_attn_proj_b: OVector<f32, Dim>,

    // key/value cache, one column per sequence position
    pub(crate) key_cache: OMatrix<f32, Dim, SeqLen>,
    pub(crate) value_cache: OMatrix<f32, Dim, SeqLen>,

    // feed-forward (MLP) up- and down-projections
    pub(crate) c_mlp_fc_w: OMatrix<f32, HiddenDim, Dim>,
    pub(crate) c_mlp_fc_b: OVector<f32, HiddenDim>,
    pub(crate) c_mlp_proj_w: OMatrix<f32, Dim, HiddenDim>,
    pub(crate) c_mlp_proj_b: OVector<f32, Dim>,
}

pub struct Gpt2Model {
    // final layer norm
    pub(crate) ln_f_g: OVector<f32, Dim>,
    pub(crate) ln_f_b: OVector<f32, Dim>,

    pub(crate) wte: OMatrix<f32, Dim, VocabSize>, // token embeddings
    pub(crate) wpe: OMatrix<f32, Dim, SeqLen>,    // position embeddings
    pub(crate) lm_head: OMatrix<f32, VocabSize, Dim>, // tied output projection
    pub(crate) layers: Vec<Gpt2Layer>,

    // pre-allocated scratch buffers, reused across forward passes
    memory_q: OVector<f32, Dim>,
    memory_att: OMatrix<f32, SeqLen, NumHeads>,
    layer_input: DVector<f32>,
    curr_768: DVector<f32>,
    curr_768_b: DVector<f32>,
    curr_2304: DVector<f32>,
    curr_3072: DVector<f32>,
    curr_vocab: DVector<f32>,
}

impl Gpt2Model {
    pub fn from_gguf(gguf: &Gguf) -> (Self, Gpt2Params) {
        let params = Gpt2Params::from_gguf(gguf);
        let mut layers = vec![];

        for i_layer in 0..params.n_layer {
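            // GGUF tensor names for this block; they correspond to the original
            // GPT-2 checkpoint tensors (ln_1/ln_2, c_attn, c_proj, mlp.c_fc, mlp.c_proj).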
            let ln_1_g = format!("blk.{}.attn_norm.weight", i_layer);
            let ln_1_b = format!("blk.{}.attn_norm.bias", i_layer);
            let ln_2_g = format!("blk.{}.ffn_norm.weight", i_layer);
            let ln_2_b = format!("blk.{}.ffn_norm.bias", i_layer);
            let c_attn_attn_w = format!("blk.{}.attn_qkv.weight", i_layer);
            let c_attn_attn_b = format!("blk.{}.attn_qkv.bias", i_layer);
            let c_attn_proj_w = format!("blk.{}.attn_output.weight", i_layer);
            let c_attn_proj_b = format!("blk.{}.attn_output.bias", i_layer);

            let c_mlp_fc_w = format!("blk.{}.ffn_up.weight", i_layer);
            let c_mlp_fc_b = format!("blk.{}.ffn_up.bias", i_layer);
            let c_mlp_proj_w = format!("blk.{}.ffn_down.weight", i_layer);
            let c_mlp_proj_b = format!("blk.{}.ffn_down.bias", i_layer);

            let ln_1_g = gguf.tensors[&ln_1_g].data().as_f32().unwrap();
            let ln_1_b = gguf.tensors[&ln_1_b].data().as_f32().unwrap();
            let ln_2_g = gguf.tensors[&ln_2_g].data().as_f32().unwrap();
            let ln_2_b = gguf.tensors[&ln_2_b].data().as_f32().unwrap();
            let c_attn_attn_w = &gguf.tensors[&c_attn_attn_w].data().dequantize().unwrap();
            let c_attn_attn_b = gguf.tensors[&c_attn_attn_b].data().as_f32().unwrap();
            let c_attn_proj_w = &gguf.tensors[&c_attn_proj_w].data().dequantize().unwrap();
            let c_attn_proj_b = gguf.tensors[&c_attn_proj_b].data().as_f32().unwrap();
            let c_mlp_fc_w = &gguf.tensors[&c_mlp_fc_w].data().dequantize().unwrap();
            let c_mlp_fc_b = gguf.tensors[&c_mlp_fc_b].data().as_f32().unwrap();
            let c_mlp_proj_w = &gguf.tensors[&c_mlp_proj_w].data().dequantize().unwrap();
            let c_mlp_proj_b = gguf.tensors[&c_mlp_proj_b].data().as_f32().unwrap();

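            // Wrap the raw f32 slices in nalgebra vectors and matrices;
            // from_row_slice fixes the (row-major) layout the weights use here.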
            let ln_1_g = DVector::from_row_slice(ln_1_g);
            let ln_1_b = DVector::from_row_slice(ln_1_b);
            let ln_2_g = DVector::from_row_slice(ln_2_g);
            let ln_2_b = DVector::from_row_slice(ln_2_b);

            let c_attn_attn_w =
                DMatrix::from_row_slice(params.attn_b, params.n_embd, c_attn_attn_w);
            let c_attn_attn_b = DVector::from_row_slice(c_attn_attn_b);
            let c_attn_proj_w =
                DMatrix::from_row_slice(params.n_embd, params.n_embd, c_attn_proj_w);
            let c_attn_proj_b = DVector::from_row_slice(c_attn_proj_b);
            let c_mlp_fc_w = DMatrix::from_row_slice(params.ff_len, params.n_embd, c_mlp_fc_w);
            let c_mlp_fc_b = DVector::from_row_slice(c_mlp_fc_b);
            let c_mlp_proj_w = DMatrix::from_row_slice(params.n_embd, params.ff_len, c_mlp_proj_w);
            let c_mlp_proj_b = DVector::from_row_slice(c_mlp_proj_b);

            let layer = Gpt2Layer {
                ln_1_g,
                ln_1_b,
                ln_2_g,
                ln_2_b,
                c_attn_attn_w,
                c_attn_attn_b,
                c_attn_proj_w,
                c_attn_proj_b,
                c_mlp_fc_w,
                c_mlp_fc_b,
                c_mlp_proj_w,
                c_mlp_proj_b,
                key_cache: DMatrix::zeros(params.n_embd, params.n_seq),
                value_cache: DMatrix::zeros(params.n_embd, params.n_seq),
            };
            layers.push(layer);
        }

        let ln_f_g = gguf.tensors["output_norm.weight"].data().as_f32().unwrap();
        let ln_f_b = gguf.tensors["output_norm.bias"].data().as_f32().unwrap();
        let wte = gguf.tensors["token_embd.weight"]
            .data()
            .dequantize()
            .unwrap();
        let wpe = &gguf.tensors["position_embd.weight"]
            .data()
            .dequantize()
            .unwrap();

        let ln_f_g = DVector::from_row_slice(ln_f_g);
        let ln_f_b = DVector::from_row_slice(ln_f_b);
        let wte = DMatrix::from_column_slice(params.n_embd, params.n_vocab, &wte);
        let wpe = DMatrix::from_column_slice(params.n_embd, params.n_seq, wpe);
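        // GPT-2 ties its output projection to the token embedding, so the
        // language-model head is simply the transpose of wte.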
        let lm_head = wte.transpose();

        let model = Self {
            ln_f_b,
            ln_f_g,
            wte,
            wpe,
            layers,
            lm_head,
            memory_q: DVector::zeros(params.n_embd),
            memory_att: DMatrix::zeros(params.n_seq, params.n_head),
            layer_input: DVector::zeros(params.n_embd),
            curr_768: DVector::zeros(params.n_embd),
            curr_768_b: DVector::zeros(params.n_embd),
            curr_2304: DVector::zeros(params.attn_b),
            curr_3072: DVector::zeros(params.ff_len),
            curr_vocab: DVector::zeros(params.n_vocab),
        };

        (model, params)
    }

    /// Mutable view of the logits written by the most recent `Transformer::forward` call.
    pub fn logits_mut(&mut self) -> &mut DVector<f32> {
        &mut self.curr_vocab
    }
}
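
// A minimal decoding-loop sketch (illustrative only: `Gguf::load`, `prompt`, and
// `argmax` are assumed helpers, not part of this file):
//
//     let gguf = Gguf::load("model.gguf")?;
//     let (mut model, params) = Gpt2Model::from_gguf(&gguf);
//     for (pos, &token) in prompt.iter().enumerate() {
//         Transformer::forward(&params, &mut model, token, pos);
//     }
//     let next_token = argmax(model.logits_mut());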

impl Transformer {
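    /// Runs one decoding step: the token with id `embd` is processed at position
    /// `pos`, updating each layer's key/value cache and leaving the next-token
    /// logits in `curr_vocab`.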
    pub fn forward(params: &Gpt2Params, model: &mut Gpt2Model, embd: usize, pos: usize) {
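        // input embedding: token embedding plus learned position embedding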
        model.layer_input.copy_from(&model.wte.column(embd));
        model.layer_input += &model.wpe.column(pos);

        for layer in model.layers.iter_mut() {
            {
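                // pre-attention norm: curr_768 = ln_1_g * layernorm(layer_input) + ln_1_b (element-wise)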
                LayerNorm::run_cpu(&mut model.curr_768, &model.layer_input);

                model.curr_768.component_mul_assign(&layer.ln_1_g);
                model.curr_768 += &layer.ln_1_b;
            }

            {
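                // fused QKV projection: a single matvec yields Q, K and V
                // stacked into one 3 * n_embd vector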
                model
                    .curr_2304
                    .gemv(1.0, &layer.c_attn_attn_w, &model.curr_768, 0.0);
                model.curr_2304 += &layer.c_attn_attn_b;
            }

            {
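                // split the fused QKV vector; this position's K and V are
                // written directly into the layer's cache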
                let mut k_cache = layer.key_cache.column_mut(pos);
                let mut v_cache = layer.value_cache.column_mut(pos);

                model
                    .memory_q
                    .copy_from(&model.curr_2304.rows(0, params.n_embd));
                k_cache.copy_from(&model.curr_2304.rows(params.n_embd, params.n_embd));
                v_cache.copy_from(&model.curr_2304.rows(2 * params.n_embd, params.n_embd));

                let head_size = params.n_embd / params.n_head;
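                // (= 768 / 12 = 64 with the default hyper-parameters)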
                let attn_params = BatchedMultiqueryAttentionParams {
                    seq_len: params.n_seq as u32,
                    kv_dim: params.n_embd as u32,
                    kv_mul: 1,
                    n_heads: params.n_head as u32,
                    head_size: head_size as u32,
                    pos: pos as u32,
                };

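                // causal attention over positions 0..=pos: reads the cached
                // keys/values and writes the merged head outputs to curr_768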
                BatchedMultiqueryAttention::run_cpu(
                    &attn_params,
                    &model.memory_q,
                    &layer.key_cache,
                    &layer.value_cache,
                    &mut model.memory_att,
                    &mut model.curr_768,
                );
            }

            {
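                // attention output projection back to model width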
                model
                    .curr_768_b
                    .gemv(1.0, &layer.c_attn_proj_w, &model.curr_768, 0.0);
                model.curr_768_b += &layer.c_attn_proj_b;
            }

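            // residual connection around attention; the sum also becomes the
            // input of the feed-forward sub-block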
            model.curr_768_b += &model.layer_input;

            model.layer_input.copy_from(&model.curr_768_b);

            {
                {
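                    // pre-MLP norm: curr_768 = ln_2_g * layernorm(curr_768_b) + ln_2_b (element-wise)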
                    LayerNorm::run_cpu(&mut model.curr_768, &model.curr_768_b);

                    model.curr_768.component_mul_assign(&layer.ln_2_g);
                    model.curr_768 += &layer.ln_2_b;
                }

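                // feed-forward up-projection: n_embd -> ff_len (768 -> 3072 by default)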
                model
                    .curr_3072
                    .gemv(1.0, &layer.c_mlp_fc_w, &model.curr_768, 0.0);
                model.curr_3072 += &layer.c_mlp_fc_b;

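                // element-wise GELU activation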
                model
                    .curr_3072
                    .apply(|x| *x = UnaryOp::Gelu.eval(*x, Vector4::zeros()));

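                // feed-forward down-projection: ff_len -> n_embd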
                model
                    .curr_768
                    .gemv(1.0, &layer.c_mlp_proj_w, &model.curr_3072, 0.0);
                model.curr_768 += &layer.c_mlp_proj_b;
            }

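            // residual connection around the feed-forward block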
            model.layer_input += &model.curr_768;
        }

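        // final norm: curr_768 = ln_f_g * layernorm(layer_input) + ln_f_b (element-wise)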
        {
            LayerNorm::run_cpu(&mut model.curr_768, &model.layer_input);

            model.curr_768.component_mul_assign(&model.ln_f_g);
            model.curr_768 += &model.ln_f_b;
        }

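        // project onto the vocabulary with the tied lm_head to produce logits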
        model
            .curr_vocab
            .gemv(1.0, &model.lm_head, &model.curr_768, 0.0);
    }
}