Skip to main content

sapient_models/forward/
common.rs

1//! Shared tensor ops for transformer forward passes.
2
3use anyhow::Result;
4use sapient_backends_cpu::kernels::{self, attention, layernorm, matmul, quant, rope};
5use sapient_core::error::SapientError;
6use sapient_core::{DType, Shape, Tensor};
7
8fn map_err<T>(result: std::result::Result<T, SapientError>) -> Result<T> {
9    result.map_err(|e| anyhow::anyhow!("{e}"))
10}
11
12// ── Online F16 → Q8_0 quantization at load time ──────────────────────────────
13
14/// Returns true if a weight tensor should be quantized online to Q8_0.
15///
16/// Criteria:
17/// - Must be a 2-D matrix with at least 32 elements (one Q8_0 block).
18/// - Must have dtype F16 or BF16 (safetensors weight matrices).
19/// - Must not be a norm weight, bias, embedding table, or lm_head
20///   (these have different access patterns or are tiny).
21pub fn should_quantize_online(name: &str, t: &Tensor) -> bool {
22    let dims = t.shape().dims();
23    if dims.len() != 2 {
24        return false;
25    }
26    let numel = dims[0] * dims[1];
27    if numel < 32 || numel % 32 != 0 {
28        return false;
29    }
30    // Skip small helper tensors and anything already quantized.
31    let skip = ["norm", "bias", "embed", "lm_head"];
32    if skip.iter().any(|s| name.contains(s)) {
33        return false;
34    }
35    matches!(t.dtype(), DType::F16 | DType::BF16)
36}
37
38/// Quantize a 2-D F16/BF16 weight tensor to Q8_0 in one pass.
39///
40/// The F16→F32 dequantization happens once here at load time; all subsequent
41/// decode steps use the already-NEON-optimized Q8_0 kernel (~1 byte/weight vs
42/// 2 bytes/weight for F16, and avoids the per-step F16→F32 conversion cost).
43pub fn quantize_tensor_to_q8_0(t: Tensor) -> Tensor {
44    let shape = t.shape().dims().to_vec();
45    let numel = shape[0] * shape[1];
46    debug_assert_eq!(numel % 32, 0);
47
48    let f32_data = t.to_f32_vec(); // one-time dequantization
49    let n_blocks = numel / 32;
50    let mut q8_bytes = Vec::with_capacity(n_blocks * 34);
51    for block in f32_data.chunks_exact(32) {
52        q8_bytes.extend_from_slice(&quant::quantize_q8_0_block(block));
53    }
54
55    Tensor::from_quant_bytes(&q8_bytes, shape, DType::Q8_0).unwrap_or(t)
56}
57
58/// Gather token embeddings: weight `[vocab, hidden]`, ids `[seq]` → `[1, seq, hidden]`.
59pub fn embed_tokens(weight: &Tensor, input_ids: &[u32]) -> Result<Tensor> {
60    let hidden = weight.shape().dims()[1];
61    let seq_len = input_ids.len();
62    // Embedding tables are commonly stored in F16/BF16; convert on the fly.
63    let w_cow = weight.to_f32_cow();
64    let w = w_cow.as_ref();
65    let mut out = vec![0.0f32; seq_len * hidden];
66
67    for (i, &id) in input_ids.iter().enumerate() {
68        let row = id as usize * hidden;
69        if row + hidden > w.len() {
70            anyhow::bail!("token id {id} out of vocab range");
71        }
72        out[i * hidden..(i + 1) * hidden].copy_from_slice(&w[row..row + hidden]);
73    }
74
75    Tensor::from_f32(&out, Shape::new([1, seq_len, hidden])).map_err(|e| anyhow::anyhow!("{e}"))
76}
77
78/// Linear on 3-D activations: `[1, seq, in] @ W^T` where W is `[out, in]`.
79pub fn linear_3d(x: &Tensor, weight: &Tensor) -> Result<Tensor> {
80    let dims = x.shape().dims();
81    if dims.len() != 3 {
82        anyhow::bail!("linear_3d expects [batch, seq, hidden]");
83    }
84    let (batch, seq, in_dim) = (dims[0], dims[1], dims[2]);
85    let w_dims = weight.shape().dims();
86    if w_dims.len() != 2 {
87        anyhow::bail!("linear weight must be 2-D");
88    }
89    let out_dim = w_dims[0];
90    if w_dims[1] != in_dim {
91        anyhow::bail!("linear weight in_dim mismatch: {} vs {in_dim}", w_dims[1]);
92    }
93
94    let x2d = map_err(x.reshape(vec![batch * seq, in_dim]))?;
95    // weight is [out, in] (PyTorch nn.Linear layout); matmul_nt computes x @ weightᵀ
96    // directly, honouring the layout and any F16/BF16 weight dtype.
97    let y2d = map_err(matmul::matmul_nt(&x2d, weight))?;
98    map_err(y2d.reshape(vec![batch, seq, out_dim]))
99}
100
101/// Reshape `[1, seq, n_heads * head_dim]` → `[1, n_heads, seq, head_dim]`.
102pub fn split_heads(x: &Tensor, n_heads: usize, head_dim: usize) -> Result<Tensor> {
103    let seq = x.shape().dims()[1];
104    permute(
105        &map_err(x.reshape(vec![1, seq, n_heads, head_dim]))?,
106        &[0, 2, 1, 3],
107    )
108}
109
110/// Merge heads back: `[1, n_heads, seq, head_dim]` → `[1, seq, n_heads * head_dim]`.
111pub fn merge_heads(x: &Tensor) -> Result<Tensor> {
112    let d = x.shape().dims();
113    let (n_heads, seq, head_dim) = (d[1], d[2], d[3]);
114    permute(x, &[0, 2, 1, 3])?
115        .reshape(vec![1, seq, n_heads * head_dim])
116        .map_err(|e| anyhow::anyhow!("{e}"))
117}
118
119pub fn permute(x: &Tensor, order: &[usize]) -> Result<Tensor> {
120    let dims = x.shape().dims();
121    if order.len() != dims.len() {
122        anyhow::bail!("permute rank mismatch");
123    }
124    let new_dims: Vec<usize> = order.iter().map(|&i| dims[i]).collect();
125    let src = x.as_f32_slice();
126    let mut out = vec![0.0f32; src.len()];
127
128    #[allow(clippy::too_many_arguments)]
129    fn recurse(
130        dims: &[usize],
131        order: &[usize],
132        src: &[f32],
133        out: &mut [f32],
134        src_strides: &[usize],
135        dst_strides: &[usize],
136        idx: &mut [usize],
137        depth: usize,
138    ) {
139        if depth == dims.len() {
140            let src_off: usize = idx
141                .iter()
142                .zip(src_strides.iter())
143                .map(|(&i, &s)| i * s)
144                .sum();
145            let dst_off: usize = order
146                .iter()
147                .enumerate()
148                .map(|(dst_ax, &src_ax)| idx[src_ax] * dst_strides[dst_ax])
149                .sum();
150            out[dst_off] = src[src_off];
151            return;
152        }
153        for i in 0..dims[depth] {
154            idx[depth] = i;
155            recurse(
156                dims,
157                order,
158                src,
159                out,
160                src_strides,
161                dst_strides,
162                idx,
163                depth + 1,
164            );
165        }
166    }
167
168    let src_strides = strides_for(dims);
169    let dst_strides = strides_for(&new_dims);
170    let mut idx = vec![0usize; dims.len()];
171    recurse(
172        dims,
173        order,
174        src,
175        &mut out,
176        &src_strides,
177        &dst_strides,
178        &mut idx,
179        0,
180    );
181    Tensor::from_f32(&out, Shape::new(new_dims)).map_err(|e| anyhow::anyhow!("{e}"))
182}
183
/// Row-major (last axis fastest) element strides for a tensor of shape `dims`.
///
/// The last axis always has stride 1; each earlier axis has the product of
/// all later extents. An empty shape yields an empty stride list.
fn strides_for(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; dims.len()];
    let mut running = 1usize;
    // Fill from the second-to-last slot backwards, pairing each slot with the
    // extent of the axis immediately after it. The first extent never
    // contributes to any stride, so it is never multiplied in.
    for (slot, &extent) in strides.iter_mut().rev().skip(1).zip(dims.iter().rev()) {
        running *= extent;
        *slot = running;
    }
    strides
}
191
192/// Quantize 32 `f32` values into a single Q8_0 block (2-byte f16 scale + 32 × i8).
193/// Returns the 34-byte block in ggml layout.
194#[inline]
195fn quantize_f32_to_q8_0_block(data: &[f32]) -> [u8; 34] {
196    debug_assert_eq!(data.len(), 32, "Q8_0 block must have exactly 32 elements");
197    let max_abs = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
198    let scale = max_abs / 127.0;
199    let d = half::f16::from_f32(scale);
200    let inv_scale = if scale > 0.0 { 1.0 / scale } else { 0.0 };
201    let mut block = [0u8; 34];
202    block[0..2].copy_from_slice(&d.to_le_bytes());
203    for (i, &v) in data.iter().enumerate() {
204        block[2 + i] = (v * inv_scale).round().clamp(-127.0, 127.0) as i8 as u8;
205    }
206    block
207}
208
209/// Update the pre-allocated KV cache in place and return a view of length `seq_len + new_seq`.
210///
211/// When the cache holds Q8_0 blocks (quantized KV cache), the new F32 values are
212/// quantized on write and the returned tensor is a freshly-allocated F32 tensor
213/// (dequantized from the cache). When the cache is F32, the old in-place path is used.
214pub fn update_kv_cache(
215    cache: &mut Tensor,
216    current_seq_len: usize,
217    new_k: &Tensor,
218) -> Result<Tensor> {
219    let cd = cache.shape().dims().to_vec();
220    let nd = new_k.shape().dims().to_vec();
221
222    if cd.len() != 4 || nd.len() != 4 {
223        anyhow::bail!("update_kv_cache expects 4-D tensors");
224    }
225    if cd[0] != nd[0] || cd[1] != nd[1] || cd[3] != nd[3] {
226        anyhow::bail!("update_kv_cache shape mismatch");
227    }
228
229    // Dispatch to the Q8_0-quantized path when the cache holds packed blocks.
230    if cache.dtype() == DType::Q8_0 {
231        return update_kv_cache_q8(cache, &cd, &nd, current_seq_len, new_k);
232    }
233
234    let max_seq = cd[2];
235    let new_seq = nd[2];
236
237    if new_seq > max_seq {
238        anyhow::bail!("new tokens {} exceeds max cache size {}", new_seq, max_seq);
239    }
240
241    let mut total_seq = current_seq_len + new_seq;
242    let shift = total_seq.saturating_sub(max_seq);
243
244    let (b_sz, h, hd) = (cd[0], cd[1], cd[3]);
245    let new_k_slice = new_k.as_f32_slice();
246    let cache_strides = cache.strides().to_vec();
247
248    {
249        let cache_slice = cache.as_f32_slice_mut()?;
250
251        // If we need to shift, move existing elements left
252        if shift > 0 {
253            let keep_seq = current_seq_len - shift;
254            for bi in 0..b_sz {
255                for hi in 0..h {
256                    let cache_base = bi * cache_strides[0] + hi * cache_strides[1];
257                    for si in 0..keep_seq {
258                        let src_idx = cache_base + (si + shift) * cache_strides[2];
259                        let dst_idx = cache_base + si * cache_strides[2];
260                        cache_slice.copy_within(src_idx..src_idx + hd, dst_idx);
261                    }
262                }
263            }
264        }
265
266        // Now append the new tokens
267        let insert_pos = if shift > 0 {
268            current_seq_len - shift
269        } else {
270            current_seq_len
271        };
272        for bi in 0..b_sz {
273            for hi in 0..h {
274                let cache_base =
275                    bi * cache_strides[0] + hi * cache_strides[1] + insert_pos * cache_strides[2];
276                let new_base = ((bi * h + hi) * new_seq) * hd; // new_k is assumed contiguous from split_heads
277
278                for si in 0..new_seq {
279                    let c_idx = cache_base + si * cache_strides[2];
280                    let n_idx = new_base + si * hd;
281
282                    // Copy head_dim elements
283                    cache_slice[c_idx..c_idx + hd].copy_from_slice(&new_k_slice[n_idx..n_idx + hd]);
284                }
285            }
286        }
287    }
288
289    if shift > 0 {
290        total_seq = max_seq;
291    }
292
293    // Return a sliced view of the cache from 0 to total_seq
294    cache
295        .slice_axis(2, 0, total_seq)
296        .map_err(|e| anyhow::anyhow!("{e}"))
297}
298
299/// Q8_0-quantized KV cache update.
300///
301/// Writes new F32 tokens as Q8_0 blocks into the packed cache buffer by copying the
302/// existing bytes, mutating them, then swapping the cache tensor in place.
303/// Returns a freshly-allocated contiguous F32 tensor (dequantized from the live prefix)
304/// suitable for the attention kernel.
305///
306/// Buffer layout: flat row-major over [b, h, seq_pos], each position is
307/// `blocks_per_head * 34` bytes (one Q8_0 block per 32 head_dim elements).
308fn update_kv_cache_q8(
309    cache: &mut Tensor,
310    cd: &[usize],
311    nd: &[usize],
312    current_seq_len: usize,
313    new_k: &Tensor,
314) -> Result<Tensor> {
315    let (b_sz, h, max_seq, hd) = (cd[0], cd[1], cd[2], cd[3]);
316    let new_seq = nd[2];
317
318    if new_seq > max_seq {
319        anyhow::bail!("new tokens {} exceeds max cache size {}", new_seq, max_seq);
320    }
321
322    let blocks_per_head = hd / 32;
323    let bytes_per_pos = blocks_per_head * 34;
324    let mut total_seq = current_seq_len + new_seq;
325    let shift = total_seq.saturating_sub(max_seq);
326
327    let pos_off = |bi: usize, hi: usize, si: usize| -> usize {
328        (bi * h * max_seq + hi * max_seq + si) * bytes_per_pos
329    };
330
331    // In-place mutation via as_bytes_mut — zero allocation, zero copy.
332    let cache_bytes = cache.as_bytes_mut()?;
333
334    if shift > 0 {
335        let keep_seq = current_seq_len - shift;
336        for bi in 0..b_sz {
337            for hi in 0..h {
338                for si in 0..keep_seq {
339                    let src = pos_off(bi, hi, si + shift);
340                    let dst = pos_off(bi, hi, si);
341                    cache_bytes.copy_within(src..src + bytes_per_pos, dst);
342                }
343            }
344        }
345    }
346
347    let insert_pos = if shift > 0 { current_seq_len - shift } else { current_seq_len };
348    let new_k_f32 = new_k.to_f32_vec();
349
350    for bi in 0..b_sz {
351        for hi in 0..h {
352            for si in 0..new_seq {
353                let dst_start = pos_off(bi, hi, insert_pos + si);
354                let src_f32_start = (bi * h * new_seq + hi * new_seq + si) * hd;
355                let src_f32 = &new_k_f32[src_f32_start..src_f32_start + hd];
356                for blk in 0..blocks_per_head {
357                    let encoded = quantize_f32_to_q8_0_block(&src_f32[blk * 32..(blk + 1) * 32]);
358                    cache_bytes[dst_start + blk * 34..dst_start + blk * 34 + 34]
359                        .copy_from_slice(&encoded);
360                }
361            }
362        }
363    }
364
365    if shift > 0 {
366        total_seq = max_seq;
367    }
368
369    // Dequantize the live prefix to F32 for the attention kernel.
370    // Read directly from the (now-updated) in-place cache.
371    let cache_ro = cache.as_bytes();
372    let out_numel = b_sz * h * total_seq * hd;
373    let mut out_f32 = vec![0.0f32; out_numel];
374
375    for bi in 0..b_sz {
376        for hi in 0..h {
377            for si in 0..total_seq {
378                let src_start = pos_off(bi, hi, si);
379                let dst_f32_start = (bi * h * total_seq + hi * total_seq + si) * hd;
380                for blk in 0..blocks_per_head {
381                    let bb = &cache_ro[src_start + blk * 34..src_start + blk * 34 + 34];
382                    let d = half::f16::from_le_bytes([bb[0], bb[1]]).to_f32();
383                    for j in 0..32 {
384                        out_f32[dst_f32_start + blk * 32 + j] = bb[2 + j] as i8 as f32 * d;
385                    }
386                }
387            }
388        }
389    }
390
391    Tensor::from_f32_vec(out_f32, Shape::new(vec![b_sz, h, total_seq, hd]))
392        .map_err(|e| anyhow::anyhow!("{e}"))
393}
394
395pub fn apply_rope_positions(x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
396    map_err(rope::apply_rope(x, positions, base))
397}
398
399/// RoPE applied to only the first `rotary_dim` channels (Phi partial rotary).
400pub fn apply_rope_partial(
401    x: &Tensor,
402    positions: &[usize],
403    base: f32,
404    rotary_dim: usize,
405) -> Result<Tensor> {
406    map_err(rope::apply_rope_partial(x, positions, base, rotary_dim))
407}
408
409/// Add a per-feature bias `[n]` broadcast over the last dimension of `y`
410/// (shape `[.., n]`). `y` must be F32; `bias` may be F16/BF16.
411pub fn add_bias_last_dim(y: &Tensor, bias: &Tensor) -> Result<Tensor> {
412    let dims = y.shape().dims().to_vec();
413    let n = *dims.last().ok_or_else(|| anyhow::anyhow!("empty tensor"))?;
414    let bias_cow = bias.to_f32_cow();
415    let b = bias_cow.as_ref();
416    if b.len() != n {
417        anyhow::bail!("bias length {} does not match last dim {n}", b.len());
418    }
419    let mut data = y.as_f32_slice().to_vec();
420    for (i, v) in data.iter_mut().enumerate() {
421        *v += b[i % n];
422    }
423    map_err(Tensor::from_f32(&data, Shape::new(dims)))
424}
425
426pub fn rms_norm(x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
427    map_err(layernorm::rms_norm(x, Some(weight), eps))
428}
429
430pub fn layer_norm(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>, eps: f32) -> Result<Tensor> {
431    map_err(layernorm::layer_norm(x, Some(weight), bias, -1, eps))
432}
433
434pub fn silu(x: &Tensor) -> Result<Tensor> {
435    map_err(kernels::elementwise::silu(x))
436}
437
438pub fn gelu(x: &Tensor) -> Result<Tensor> {
439    map_err(kernels::elementwise::gelu(x))
440}
441
442pub fn add(a: &Tensor, b: &Tensor) -> Result<Tensor> {
443    map_err(kernels::elementwise::add(a, b))
444}
445
446pub fn mul(a: &Tensor, b: &Tensor) -> Result<Tensor> {
447    map_err(kernels::elementwise::mul(a, b))
448}
449
450pub fn gqa_attention(
451    q: &Tensor,
452    k: &Tensor,
453    v: &Tensor,
454    n_kv_heads: usize,
455    causal: bool,
456) -> Result<Tensor> {
457    let mask = if causal {
458        let sq = q.shape().dims()[2];
459        let sk = k.shape().dims()[2];
460        Some(attention::causal_mask(sq, sk))
461    } else {
462        None
463    };
464    map_err(attention::scaled_dot_product_attention(
465        q,
466        k,
467        v,
468        mask.as_ref(),
469        None,
470        n_kv_heads,
471    ))
472}
473
474/// Compute logits for ALL positions in the sequence. Used by speculative
475/// decoding to verify K draft tokens in a single target-model forward pass.
476pub fn all_logits_from_hidden(hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<Vec<f32>>> {
477    let dims = hidden.shape().dims();
478    let hidden_size = dims[2];
479    let seq = dims[1];
480    let vocab_size = lm_head.shape().dims()[0];
481    let h = hidden.as_f32_slice();
482    let h_all =
483        Tensor::from_f32(h, Shape::new([seq, hidden_size])).map_err(|e| anyhow::anyhow!("{e}"))?;
484    let logits_flat = map_err(matmul::matmul_nt(&h_all, lm_head))?;
485    let flat = logits_flat.as_f32_slice();
486    let mut all = Vec::with_capacity(seq);
487    for i in 0..seq {
488        all.push(flat[i * vocab_size..(i + 1) * vocab_size].to_vec());
489    }
490    Ok(all)
491}
492
493pub fn logits_from_hidden(hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
494    // hidden: [1, seq, hidden], take last position
495    let dims = hidden.shape().dims();
496    let hidden_size = dims[2];
497    let seq = dims[1];
498    let h = hidden.as_f32_slice();
499    let last = &h[(seq - 1) * hidden_size..seq * hidden_size];
500    let h_last =
501        Tensor::from_f32(last, Shape::new([1, hidden_size])).map_err(|e| anyhow::anyhow!("{e}"))?;
502    // lm_head is [vocab, hidden]; matmul_nt computes h_last @ lm_headᵀ directly.
503    let logits = map_err(matmul::matmul_nt(&h_last, lm_head))?;
504    Ok(logits.as_f32_slice().to_vec())
505}
506
507pub fn mean_pool_hidden(hidden: &Tensor) -> Result<Vec<f32>> {
508    let dims = hidden.shape().dims();
509    let (seq, hidden_size) = (dims[1], dims[2]);
510    let h = hidden.as_f32_slice();
511    let mut out = vec![0.0f32; hidden_size];
512    for t in 0..seq {
513        for i in 0..hidden_size {
514            out[i] += h[t * hidden_size + i];
515        }
516    }
517    let n = seq as f32;
518    for v in &mut out {
519        *v /= n;
520    }
521    Ok(out)
522}