1use crate::gguf::Gguf;
4use nalgebra::{
5 vector, DMatrix, DVector, DVectorViewMut, Dyn, OMatrix, OVector, Rotation2, Storage,
6 StorageMut, Vector,
7};
8use std::ffi::c_int;
9
// Semantic aliases over nalgebra's runtime dimension `Dyn`: every model
// dimension is only known after the config is loaded, so all of these are
// dynamically sized. The names exist purely to document intent in signatures.
type Dim = Dyn;
type HiddenDim = Dyn;
type NumHeads = Dyn;
type SeqLen = Dyn;
14
/// Raw model hyperparameter header: seven consecutive C `int`s read directly
/// from the start of a checkpoint file (presumably the llama2.c export
/// format — see `Llama2Config::read`). A negative `vocab_size` flags that the
/// output classifier does NOT share weights with the token embedding
/// (see `From<RawConfig> for Llama2Config`).
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
pub struct RawConfig {
    // Transformer embedding width.
    dim: c_int,
    // Feed-forward (FFN) inner width.
    hidden_dim: c_int,
    // Number of transformer layers.
    n_layers: c_int,
    // Number of query attention heads.
    n_q_heads: c_int,
    // Number of key/value heads (< n_q_heads under grouped-query attention).
    n_kv_heads: c_int,
    // Vocabulary size; sign encodes the shared-weights flag (see above).
    vocab_size: c_int,
    // Maximum sequence (context) length.
    seq_len: c_int,
}
35
/// Parsed model hyperparameters with native-width fields, built either from a
/// `RawConfig` header or from GGUF metadata (`Llama2Config::from_gguf`).
#[derive(Copy, Clone, Debug)]
pub struct Llama2Config {
    /// Transformer embedding width.
    pub dim: usize,
    /// Feed-forward (FFN) inner width.
    pub hidden_dim: usize,
    /// Number of transformer layers.
    pub n_layers: usize,
    /// Number of query attention heads.
    pub n_q_heads: usize,
    /// Number of key/value heads (grouped-query attention when < n_q_heads).
    pub n_kv_heads: usize,
    /// Vocabulary size (always positive here; the sign flag lives in
    /// `shared_weights`).
    pub vocab_size: usize,
    /// Maximum sequence (context) length.
    pub seq_len: usize,
    /// Whether the output classifier reuses the token-embedding matrix.
    pub shared_weights: bool,
}
63
64impl Llama2Config {
65 pub fn read(bytes: &[u8]) -> Self {
66 let elts: &[RawConfig] = bytemuck::cast_slice(&bytes[..std::mem::size_of::<RawConfig>()]);
67 elts[0].into()
68 }
69
70 pub fn from_gguf(gguf: &Gguf) -> Self {
71 Self {
72 dim: gguf.metadata["llama.embedding_length"].unwrap_u32() as usize,
73 hidden_dim: gguf.metadata["llama.feed_forward_length"].unwrap_u32() as usize,
74 n_layers: gguf.metadata["llama.block_count"].unwrap_u32() as usize,
75 n_q_heads: gguf.metadata["llama.attention.head_count"].unwrap_u32() as usize,
76 n_kv_heads: gguf.metadata["llama.attention.head_count_kv"].unwrap_u32() as usize,
77 vocab_size: gguf.metadata["tokenizer.ggml.tokens"].unwrap_array_len(),
78 seq_len: gguf.metadata["llama.context_length"].unwrap_u32() as usize,
79 shared_weights: true, }
81 }
82}
83
84impl From<RawConfig> for Llama2Config {
85 fn from(c: RawConfig) -> Self {
86 Self {
87 dim: c.dim as usize,
88 hidden_dim: c.hidden_dim as usize,
89 n_layers: c.n_layers as usize,
90 n_q_heads: c.n_q_heads as usize,
91 n_kv_heads: c.n_kv_heads as usize,
92 vocab_size: c.vocab_size.unsigned_abs() as usize,
93 seq_len: c.seq_len as usize,
94 shared_weights: c.vocab_size > 0,
95 }
96 }
97}
98
/// Per-layer weight tensors. Matrix shapes below are as constructed by
/// `TransformerWeights::from_gguf`, with `head_size = dim / n_q_heads` and
/// `kv_rows = n_kv_heads * head_size`.
pub struct TransformerLayerWeights {
    /// Key projection, `kv_rows x dim`.
    pub attn_k: DMatrix<f32>,
    /// Pre-attention RMSNorm gain vector.
    pub attn_norm: DVector<f32>,
    /// Query projection, `dim x dim`.
    pub attn_q: DMatrix<f32>,
    /// Value projection, `kv_rows x dim`.
    pub attn_v: DMatrix<f32>,
    /// FFN down projection, `dim x hidden_dim`.
    pub ffn_down: DMatrix<f32>,
    /// FFN gate projection, `hidden_dim x dim`.
    pub ffn_gate: DMatrix<f32>,
    /// Pre-FFN RMSNorm gain vector.
    pub ffn_norm: DVector<f32>,
    /// FFN up projection, `hidden_dim x dim`.
    pub ffn_up: DMatrix<f32>,
    /// Attention output projection, `dim x dim`.
    pub attn_output: DMatrix<f32>,
}
110
/// All model weights: one `TransformerLayerWeights` per layer plus the
/// embedding, classifier, and final-norm tensors.
pub struct TransformerWeights {
    pub layers: Vec<TransformerLayerWeights>,
    /// Token embeddings, `dim x vocab_size` — one column per token id.
    pub token_embd: DMatrix<f32>,
    /// Output classifier, `vocab_size x dim`; equals `token_embd.transpose()`
    /// when the model ties embeddings (no `output.weight` tensor in the file).
    pub output: DMatrix<f32>,
    /// Final RMSNorm gain applied before the classifier.
    pub output_norm: DVector<f32>,
}
117
impl TransformerWeights {
    /// Loads and dequantizes all weight tensors from a GGUF file.
    ///
    /// Quantized tensors go through `dequantize()`; the norm vectors are
    /// expected to already be f32 (`as_f32`). Tensor names follow the GGUF
    /// llama convention (`blk.{i}.attn_q.weight`, ...).
    ///
    /// # Panics
    /// Panics if a required tensor is missing or fails to dequantize.
    pub fn from_gguf(config: &Llama2Config, gguf: &Gguf) -> Self {
        let head_size = config.dim / config.n_q_heads;
        // Row count of the K and V projections; smaller than `dim` under
        // grouped-query attention (n_kv_heads < n_q_heads).
        let num_kv_heads_times_head_size = config.n_kv_heads * head_size;

        let mut layers = vec![];

        for i_layer in 0..config.n_layers {
            // Tensor names for this layer.
            let attn_q = format!("blk.{}.attn_q.weight", i_layer);
            let attn_k = format!("blk.{}.attn_k.weight", i_layer);
            let attn_v = format!("blk.{}.attn_v.weight", i_layer);
            let attn_output = format!("blk.{}.attn_output.weight", i_layer);
            let ffn_down = format!("blk.{}.ffn_down.weight", i_layer);
            let ffn_gate = format!("blk.{}.ffn_gate.weight", i_layer);
            let ffn_up = format!("blk.{}.ffn_up.weight", i_layer);
            let ffn_norm = format!("blk.{}.ffn_norm.weight", i_layer);
            let attn_norm = format!("blk.{}.attn_norm.weight", i_layer);

            // Raw f32 data (shadowing the name strings above).
            let attn_q = &gguf.tensors[&attn_q].data().dequantize().unwrap();
            let attn_k = &gguf.tensors[&attn_k].data().dequantize().unwrap();
            let attn_v = &gguf.tensors[&attn_v].data().dequantize().unwrap();
            let attn_output = &gguf.tensors[&attn_output].data().dequantize().unwrap();
            let ffn_down = &gguf.tensors[&ffn_down].data().dequantize().unwrap();
            let ffn_gate = &gguf.tensors[&ffn_gate].data().dequantize().unwrap();
            let ffn_up = &gguf.tensors[&ffn_up].data().dequantize().unwrap();
            let ffn_norm = gguf.tensors[&ffn_norm].data().as_f32().unwrap();
            let attn_norm = gguf.tensors[&attn_norm].data().as_f32().unwrap();

            let ffn_norm = DVector::from_row_slice(ffn_norm);
            let attn_norm = DVector::from_row_slice(attn_norm);

            // Projection matrices are built row-major from the flat data —
            // presumably matching GGUF's on-disk row ordering (TODO confirm
            // against the gguf reader's dimension metadata).
            let attn_q = DMatrix::from_row_slice(config.dim, config.dim, attn_q);
            let attn_k = DMatrix::from_row_slice(num_kv_heads_times_head_size, config.dim, attn_k);
            let attn_v = DMatrix::from_row_slice(num_kv_heads_times_head_size, config.dim, attn_v);
            let attn_output = DMatrix::from_row_slice(config.dim, config.dim, attn_output);
            let ffn_down = DMatrix::from_row_slice(config.dim, config.hidden_dim, ffn_down);
            let ffn_gate = DMatrix::from_row_slice(config.hidden_dim, config.dim, ffn_gate);
            let ffn_up = DMatrix::from_row_slice(config.hidden_dim, config.dim, ffn_up);

            layers.push(TransformerLayerWeights {
                attn_q,
                attn_k,
                attn_v,
                attn_output,
                ffn_down,
                ffn_gate,
                ffn_up,
                ffn_norm,
                attn_norm,
            });
        }

        let token_embd = "token_embd.weight";
        let output = "output.weight";
        let output_norm = "output_norm.weight";

        let token_embd = &gguf.tensors[token_embd].data().dequantize().unwrap();
        // `output.weight` is optional: absent for tied-embedding models.
        let output = gguf
            .tensors
            .get(output)
            .map(|v| v.data().dequantize().unwrap());
        let output_norm = gguf.tensors[output_norm].data().as_f32().unwrap();

        // Column-major here, so each token id selects a `dim`-long column.
        let token_embd = DMatrix::from_column_slice(config.dim, config.vocab_size, token_embd);
        // Tied embeddings: fall back to the transposed embedding matrix when
        // there is no dedicated classifier tensor.
        let output = output
            .map(|data| DMatrix::from_row_slice(config.vocab_size, config.dim, &data))
            .unwrap_or_else(|| token_embd.transpose());
        let output_norm = DVector::from_row_slice(output_norm);

        Self {
            layers,
            token_embd,
            output,
            output_norm,
        }
    }
}
195
/// Scratch buffers for one forward pass, allocated once in `RunState::new`
/// and reused across tokens.
struct RunState {
    // Residual stream activation for the current token (length `dim`).
    x: OVector<f32, Dim>,
    // Scratch: normed activation / attention output accumulator.
    xb: OVector<f32, Dim>,
    // Scratch: output of the attention projection and of the FFN.
    xb2: OVector<f32, Dim>,
    // FFN gate branch (length `hidden_dim`).
    hb: OVector<f32, HiddenDim>,
    // FFN up branch (length `hidden_dim`).
    hb2: OVector<f32, HiddenDim>,
    // Query vector for the current token (length `dim`).
    q: OVector<f32, Dim>,
    // Attention scores: one `seq_len`-long column per query head.
    att: OMatrix<f32, SeqLen, NumHeads>,
    // Output logits. NOTE(review): typed with the `SeqLen` alias but
    // allocated with length `vocab_size` (see `RunState::new`); both aliases
    // are `Dyn` so it compiles, but `VocabSize` would be the honest name.
    logits: OVector<f32, SeqLen>,
    // Per-layer key cache, `kv_dim x seq_len` — one column per position.
    key_cache: Vec<OMatrix<f32, Dim, SeqLen>>,
    // Per-layer value cache, same shape as `key_cache`.
    value_cache: Vec<OMatrix<f32, Dim, SeqLen>>,
}
218
/// A loaded llama2-style model: hyperparameters, weights, and the mutable
/// scratch state used by `forward`.
pub struct Transformer {
    config: Llama2Config,
    weights: TransformerWeights,
    state: RunState,
}
227
228impl Transformer {
229 pub fn new(config: Llama2Config, weights: TransformerWeights) -> Self {
230 Self {
231 state: RunState::new(&config),
232 config,
233 weights,
234 }
235 }
236
237 pub fn logits_mut(&mut self) -> &mut OVector<f32, SeqLen> {
238 &mut self.state.logits
239 }
240}
241
242impl RunState {
243 pub fn new(config: &Llama2Config) -> Self {
244 let kv_dim = (config.dim * config.n_kv_heads) / config.n_q_heads;
245 Self {
246 x: DVector::zeros(config.dim),
247 xb: DVector::zeros(config.dim),
248 xb2: DVector::zeros(config.dim),
249 hb: DVector::zeros(config.hidden_dim),
250 hb2: DVector::zeros(config.hidden_dim),
251 q: DVector::zeros(config.dim),
252 key_cache: (0..config.n_layers)
254 .map(|_| DMatrix::zeros(kv_dim, config.seq_len))
255 .collect(),
256 value_cache: (0..config.n_layers)
257 .map(|_| DMatrix::zeros(kv_dim, config.seq_len))
258 .collect(),
259 att: DMatrix::zeros(config.seq_len, config.n_q_heads),
260 logits: DVector::zeros(config.vocab_size),
261 }
262 }
263}
264
265fn rms_norm<SW: Storage<f32, Dyn>>(
277 out: &mut DVector<f32>,
278 a: &DVector<f32>,
279 w: &Vector<f32, Dyn, SW>,
280) {
281 const NUDGE_FACTOR: f32 = 1.0e-5;
282 let rms = 1.0 / (a.norm_squared() / (a.nrows() as f32) + NUDGE_FACTOR).sqrt();
283 out.zip_zip_apply(a, w, |o, a, w| *o = (a * rms) * w);
284}
285
286pub fn softmax<S: StorageMut<f32, Dyn>>(vals: &mut Vector<f32, Dyn, S>) {
291 let max_val = vals.max();
295 let mut sum = 0.0;
296
297 vals.apply(|x| {
298 *x = (*x - max_val).exp();
299 sum += *x;
300 });
301
302 *vals /= sum;
303}
304
305fn matmul<SOut: StorageMut<f32, Dyn>>(
309 out: &mut Vector<f32, Dyn, SOut>,
310 x: &DVector<f32>,
311 w: &DMatrix<f32>,
312) {
313 out.gemv(1.0, w, x, 0.0);
316}
317
impl Transformer {
    /// Runs one forward pass for `token` at position `pos`, leaving the
    /// next-token logits in `self.state.logits` (readable via `logits_mut`).
    ///
    /// Pre-norm decoder layout: per layer, RMSNorm -> attention -> residual
    /// add, then RMSNorm -> SwiGLU FFN -> residual add; finally RMSNorm and
    /// the classifier projection.
    pub fn forward(&mut self, token: usize, pos: usize) {
        let config = &self.config;
        let w = &self.weights;
        let s = &mut self.state;
        let dim = config.dim;
        // Width of the cached K/V columns; smaller than `dim` when
        // grouped-query attention uses fewer KV heads than query heads.
        let kv_dim = (config.dim * config.n_kv_heads) / config.n_q_heads;
        let head_size = dim / config.n_q_heads;

        // Start the residual stream from this token's embedding column.
        s.x.copy_from(&w.token_embd.column(token));

        for l in 0..config.n_layers {
            let wl = &w.layers[l];

            // Pre-attention RMSNorm: xb = norm(x) * attn_norm.
            rms_norm(&mut s.xb, &s.x, &wl.attn_norm);

            // K and V for this position are written straight into the caches.
            let mut k_cache = s.key_cache[l].column_mut(pos);
            let mut v_cache = s.value_cache[l].column_mut(pos);

            matmul(&mut s.q, &s.xb, &wl.attn_q);
            matmul(&mut k_cache, &s.xb, &wl.attn_k);
            matmul(&mut v_cache, &s.xb, &wl.attn_v);

            // RoPE rotates the full q and the freshly written k in place.
            Self::rotary_positional_encoding(&mut s.q, &mut k_cache, head_size, dim, kv_dim, pos);

            // Multi-head attention over cache entries 0..=pos; result in xb2.
            Self::attention(config, s, w, pos, l);

            // Residual connection around attention.
            s.x += &s.xb2;

            // Pre-FFN RMSNorm, then SwiGLU FFN (result in xb2).
            rms_norm(&mut s.xb, &s.x, &wl.ffn_norm);

            Self::ffn_silu(s, wl);

            // Residual connection around the FFN.
            s.x += &s.xb2;
        }

        // Final norm and projection to vocabulary logits.
        rms_norm(&mut s.xb, &s.x, &w.output_norm);

        matmul(&mut s.logits, &s.xb, &w.output);
    }

    /// Applies rotary position embeddings (RoPE) to `q` and `k` in place.
    ///
    /// Each consecutive pair `(q[i], q[i+1])` is rotated by the angle
    /// `pos * 10000^-( (i % head_size) / head_size )`, so the frequency
    /// schedule repeats identically within every head. Only the first
    /// `kv_dim` entries of `k` exist (grouped-query attention), hence the
    /// `i < kv_dim` guard on the key rotation.
    pub fn rotary_positional_encoding(
        q: &mut DVector<f32>,
        k: &mut DVectorViewMut<f32>,
        head_size: usize,
        dim: usize,
        kv_dim: usize,
        pos: usize,
    ) {
        for i in (0..dim).step_by(2) {
            // Position of this pair within its head, as the frequency index.
            let head_dim = (i % head_size) as f32;
            let theta = 10000.0_f32.powf(-head_dim / head_size as f32);
            let m_theta = pos as f32 * theta;
            let rot = Rotation2::new(m_theta);

            let qi = vector![q[i], q[i + 1]];
            let mut out_q = q.fixed_rows_mut::<2>(i);
            out_q.copy_from(&(rot * qi));

            if i < kv_dim {
                let ki = vector![k[i], k[i + 1]];
                let mut out_k = k.fixed_rows_mut::<2>(i);
                out_k.copy_from(&(rot * ki));
            }
        }
    }

    /// Scaled dot-product attention for layer `l`: reads K/V from the caches,
    /// accumulates the per-head weighted values into `s.xb`, and writes the
    /// output projection into `s.xb2`.
    fn attention(
        config: &Llama2Config,
        s: &mut RunState,
        w: &TransformerWeights,
        pos: usize,
        l: usize,
    ) {
        let head_size = config.dim / config.n_q_heads;
        // Number of query heads that share each KV head (GQA group size).
        let kv_mul = config.n_q_heads / config.n_kv_heads;

        for h in 0..config.n_q_heads {
            let q = s.q.rows(h * head_size, head_size);
            let mut att = s.att.column_mut(h);

            // Score against every cached key up to the current position;
            // never reading t > pos is what implements the causal mask.
            for t in 0..=pos {
                let k = s.key_cache[l].column(t);
                // `h / kv_mul` maps a query head to its shared KV head.
                let k_head = k.rows((h / kv_mul) * head_size, head_size);

                let mut score = q.dot(&k_head);
                // Standard 1/sqrt(d_head) temperature scaling.
                score /= (head_size as f32).sqrt();
                att[t] = score;
            }

            // Normalize only the populated prefix of the score column.
            softmax(&mut att.rows_mut(0, pos + 1));

            // Weighted sum of cached values, accumulated into this head's
            // slice of xb.
            let mut xb = s.xb.rows_mut(h * head_size, head_size);
            xb.fill(0.0);
            for t in 0..=pos {
                let v = s.value_cache[l].column(t);
                let v_head = v.rows((h / kv_mul) * head_size, head_size);
                xb.axpy(att[t], &v_head, 1.0);
            }
        }

        // Final attention output projection: xb2 = attn_output * xb.
        matmul(&mut s.xb2, &s.xb, &w.layers[l].attn_output);
    }

    /// SwiGLU feed-forward block:
    /// `xb2 = ffn_down * ( silu(ffn_gate * xb) .* (ffn_up * xb) )`.
    fn ffn_silu(s: &mut RunState, wl: &TransformerLayerWeights) {
        s.hb.gemv(1.0, &wl.ffn_gate, &s.xb, 0.0);
        s.hb2.gemv(1.0, &wl.ffn_up, &s.xb, 0.0);

        // swish(x, 1.0) is SiLU: x * sigmoid(x), written as x / (1 + e^-x).
        fn swish(x: f32, beta: f32) -> f32 {
            x / (1.0 + (-beta * x).exp())
        }

        // Gate the up-projection elementwise with the activated gate branch.
        s.hb.zip_apply(&s.hb2, |h, h2| *h = h2 * swish(*h, 1.0));

        matmul(&mut s.xb2, &s.hb, &wl.ffn_down);
    }
}