1use anyhow::Result;
4use sapient_backends_cpu::kernels::{self, attention, layernorm, matmul, quant, rope};
5use sapient_core::error::SapientError;
6use sapient_core::{DType, Shape, Tensor};
7
8fn map_err<T>(result: std::result::Result<T, SapientError>) -> Result<T> {
9 result.map_err(|e| anyhow::anyhow!("{e}"))
10}
11
12pub fn should_quantize_online(name: &str, t: &Tensor) -> bool {
22 let dims = t.shape().dims();
23 if dims.len() != 2 {
24 return false;
25 }
26 let numel = dims[0] * dims[1];
27 if numel < 32 || numel % 32 != 0 {
28 return false;
29 }
30 let skip = ["norm", "bias", "embed", "lm_head"];
32 if skip.iter().any(|s| name.contains(s)) {
33 return false;
34 }
35 matches!(t.dtype(), DType::F16 | DType::BF16)
36}
37
38pub fn quantize_tensor_to_q8_0(t: Tensor) -> Tensor {
44 let shape = t.shape().dims().to_vec();
45 let numel = shape[0] * shape[1];
46 debug_assert_eq!(numel % 32, 0);
47
48 let f32_data = t.to_f32_vec(); let n_blocks = numel / 32;
50 let mut q8_bytes = Vec::with_capacity(n_blocks * 34);
51 for block in f32_data.chunks_exact(32) {
52 q8_bytes.extend_from_slice(&quant::quantize_q8_0_block(block));
53 }
54
55 Tensor::from_quant_bytes(&q8_bytes, shape, DType::Q8_0).unwrap_or(t)
56}
57
58pub fn embed_tokens(weight: &Tensor, input_ids: &[u32]) -> Result<Tensor> {
60 let hidden = weight.shape().dims()[1];
61 let seq_len = input_ids.len();
62 let w_cow = weight.to_f32_cow();
64 let w = w_cow.as_ref();
65 let mut out = vec![0.0f32; seq_len * hidden];
66
67 for (i, &id) in input_ids.iter().enumerate() {
68 let row = id as usize * hidden;
69 if row + hidden > w.len() {
70 anyhow::bail!("token id {id} out of vocab range");
71 }
72 out[i * hidden..(i + 1) * hidden].copy_from_slice(&w[row..row + hidden]);
73 }
74
75 Tensor::from_f32(&out, Shape::new([1, seq_len, hidden])).map_err(|e| anyhow::anyhow!("{e}"))
76}
77
78pub fn linear_3d(x: &Tensor, weight: &Tensor) -> Result<Tensor> {
80 let dims = x.shape().dims();
81 if dims.len() != 3 {
82 anyhow::bail!("linear_3d expects [batch, seq, hidden]");
83 }
84 let (batch, seq, in_dim) = (dims[0], dims[1], dims[2]);
85 let w_dims = weight.shape().dims();
86 if w_dims.len() != 2 {
87 anyhow::bail!("linear weight must be 2-D");
88 }
89 let out_dim = w_dims[0];
90 if w_dims[1] != in_dim {
91 anyhow::bail!("linear weight in_dim mismatch: {} vs {in_dim}", w_dims[1]);
92 }
93
94 let x2d = map_err(x.reshape(vec![batch * seq, in_dim]))?;
95 let y2d = map_err(matmul::matmul_nt(&x2d, weight))?;
98 map_err(y2d.reshape(vec![batch, seq, out_dim]))
99}
100
101pub fn split_heads(x: &Tensor, n_heads: usize, head_dim: usize) -> Result<Tensor> {
103 let seq = x.shape().dims()[1];
104 permute(
105 &map_err(x.reshape(vec![1, seq, n_heads, head_dim]))?,
106 &[0, 2, 1, 3],
107 )
108}
109
110pub fn merge_heads(x: &Tensor) -> Result<Tensor> {
112 let d = x.shape().dims();
113 let (n_heads, seq, head_dim) = (d[1], d[2], d[3]);
114 permute(x, &[0, 2, 1, 3])?
115 .reshape(vec![1, seq, n_heads * head_dim])
116 .map_err(|e| anyhow::anyhow!("{e}"))
117}
118
119pub fn permute(x: &Tensor, order: &[usize]) -> Result<Tensor> {
120 let dims = x.shape().dims();
121 if order.len() != dims.len() {
122 anyhow::bail!("permute rank mismatch");
123 }
124 let new_dims: Vec<usize> = order.iter().map(|&i| dims[i]).collect();
125 let src = x.as_f32_slice();
126 let mut out = vec![0.0f32; src.len()];
127
128 #[allow(clippy::too_many_arguments)]
129 fn recurse(
130 dims: &[usize],
131 order: &[usize],
132 src: &[f32],
133 out: &mut [f32],
134 src_strides: &[usize],
135 dst_strides: &[usize],
136 idx: &mut [usize],
137 depth: usize,
138 ) {
139 if depth == dims.len() {
140 let src_off: usize = idx
141 .iter()
142 .zip(src_strides.iter())
143 .map(|(&i, &s)| i * s)
144 .sum();
145 let dst_off: usize = order
146 .iter()
147 .enumerate()
148 .map(|(dst_ax, &src_ax)| idx[src_ax] * dst_strides[dst_ax])
149 .sum();
150 out[dst_off] = src[src_off];
151 return;
152 }
153 for i in 0..dims[depth] {
154 idx[depth] = i;
155 recurse(
156 dims,
157 order,
158 src,
159 out,
160 src_strides,
161 dst_strides,
162 idx,
163 depth + 1,
164 );
165 }
166 }
167
168 let src_strides = strides_for(dims);
169 let dst_strides = strides_for(&new_dims);
170 let mut idx = vec![0usize; dims.len()];
171 recurse(
172 dims,
173 order,
174 src,
175 &mut out,
176 &src_strides,
177 &dst_strides,
178 &mut idx,
179 0,
180 );
181 Tensor::from_f32(&out, Shape::new(new_dims)).map_err(|e| anyhow::anyhow!("{e}"))
182}
183
/// Computes row-major (last axis fastest) element strides for `dims`.
///
/// The last stride is 1 and every earlier stride is the product of all later
/// extents. An empty `dims` yields an empty vector.
fn strides_for(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; dims.len()];
    let mut running = 1usize;
    // Pair each stride slot, except the last, with the extent one axis to its
    // right, walking from the back so `running` accumulates the product.
    let slots = strides.iter_mut().rev().skip(1);
    for (slot, &extent) in slots.zip(dims.iter().rev()) {
        running *= extent;
        *slot = running;
    }
    strides
}
191
192#[inline]
195fn quantize_f32_to_q8_0_block(data: &[f32]) -> [u8; 34] {
196 debug_assert_eq!(data.len(), 32, "Q8_0 block must have exactly 32 elements");
197 let max_abs = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
198 let scale = max_abs / 127.0;
199 let d = half::f16::from_f32(scale);
200 let inv_scale = if scale > 0.0 { 1.0 / scale } else { 0.0 };
201 let mut block = [0u8; 34];
202 block[0..2].copy_from_slice(&d.to_le_bytes());
203 for (i, &v) in data.iter().enumerate() {
204 block[2 + i] = (v * inv_scale).round().clamp(-127.0, 127.0) as i8 as u8;
205 }
206 block
207}
208
209pub fn update_kv_cache(
215 cache: &mut Tensor,
216 current_seq_len: usize,
217 new_k: &Tensor,
218) -> Result<Tensor> {
219 let cd = cache.shape().dims().to_vec();
220 let nd = new_k.shape().dims().to_vec();
221
222 if cd.len() != 4 || nd.len() != 4 {
223 anyhow::bail!("update_kv_cache expects 4-D tensors");
224 }
225 if cd[0] != nd[0] || cd[1] != nd[1] || cd[3] != nd[3] {
226 anyhow::bail!("update_kv_cache shape mismatch");
227 }
228
229 if cache.dtype() == DType::Q8_0 {
231 return update_kv_cache_q8(cache, &cd, &nd, current_seq_len, new_k);
232 }
233
234 let max_seq = cd[2];
235 let new_seq = nd[2];
236
237 if new_seq > max_seq {
238 anyhow::bail!("new tokens {} exceeds max cache size {}", new_seq, max_seq);
239 }
240
241 let mut total_seq = current_seq_len + new_seq;
242 let shift = total_seq.saturating_sub(max_seq);
243
244 let (b_sz, h, hd) = (cd[0], cd[1], cd[3]);
245 let new_k_slice = new_k.as_f32_slice();
246 let cache_strides = cache.strides().to_vec();
247
248 {
249 let cache_slice = cache.as_f32_slice_mut()?;
250
251 if shift > 0 {
253 let keep_seq = current_seq_len - shift;
254 for bi in 0..b_sz {
255 for hi in 0..h {
256 let cache_base = bi * cache_strides[0] + hi * cache_strides[1];
257 for si in 0..keep_seq {
258 let src_idx = cache_base + (si + shift) * cache_strides[2];
259 let dst_idx = cache_base + si * cache_strides[2];
260 cache_slice.copy_within(src_idx..src_idx + hd, dst_idx);
261 }
262 }
263 }
264 }
265
266 let insert_pos = if shift > 0 {
268 current_seq_len - shift
269 } else {
270 current_seq_len
271 };
272 for bi in 0..b_sz {
273 for hi in 0..h {
274 let cache_base =
275 bi * cache_strides[0] + hi * cache_strides[1] + insert_pos * cache_strides[2];
276 let new_base = ((bi * h + hi) * new_seq) * hd; for si in 0..new_seq {
279 let c_idx = cache_base + si * cache_strides[2];
280 let n_idx = new_base + si * hd;
281
282 cache_slice[c_idx..c_idx + hd].copy_from_slice(&new_k_slice[n_idx..n_idx + hd]);
284 }
285 }
286 }
287 }
288
289 if shift > 0 {
290 total_seq = max_seq;
291 }
292
293 cache
295 .slice_axis(2, 0, total_seq)
296 .map_err(|e| anyhow::anyhow!("{e}"))
297}
298
299fn update_kv_cache_q8(
309 cache: &mut Tensor,
310 cd: &[usize],
311 nd: &[usize],
312 current_seq_len: usize,
313 new_k: &Tensor,
314) -> Result<Tensor> {
315 let (b_sz, h, max_seq, hd) = (cd[0], cd[1], cd[2], cd[3]);
316 let new_seq = nd[2];
317
318 if new_seq > max_seq {
319 anyhow::bail!("new tokens {} exceeds max cache size {}", new_seq, max_seq);
320 }
321
322 let blocks_per_head = hd / 32;
323 let bytes_per_pos = blocks_per_head * 34;
324 let mut total_seq = current_seq_len + new_seq;
325 let shift = total_seq.saturating_sub(max_seq);
326
327 let pos_off = |bi: usize, hi: usize, si: usize| -> usize {
328 (bi * h * max_seq + hi * max_seq + si) * bytes_per_pos
329 };
330
331 let cache_bytes = cache.as_bytes_mut()?;
333
334 if shift > 0 {
335 let keep_seq = current_seq_len - shift;
336 for bi in 0..b_sz {
337 for hi in 0..h {
338 for si in 0..keep_seq {
339 let src = pos_off(bi, hi, si + shift);
340 let dst = pos_off(bi, hi, si);
341 cache_bytes.copy_within(src..src + bytes_per_pos, dst);
342 }
343 }
344 }
345 }
346
347 let insert_pos = if shift > 0 { current_seq_len - shift } else { current_seq_len };
348 let new_k_f32 = new_k.to_f32_vec();
349
350 for bi in 0..b_sz {
351 for hi in 0..h {
352 for si in 0..new_seq {
353 let dst_start = pos_off(bi, hi, insert_pos + si);
354 let src_f32_start = (bi * h * new_seq + hi * new_seq + si) * hd;
355 let src_f32 = &new_k_f32[src_f32_start..src_f32_start + hd];
356 for blk in 0..blocks_per_head {
357 let encoded = quantize_f32_to_q8_0_block(&src_f32[blk * 32..(blk + 1) * 32]);
358 cache_bytes[dst_start + blk * 34..dst_start + blk * 34 + 34]
359 .copy_from_slice(&encoded);
360 }
361 }
362 }
363 }
364
365 if shift > 0 {
366 total_seq = max_seq;
367 }
368
369 let cache_ro = cache.as_bytes();
372 let out_numel = b_sz * h * total_seq * hd;
373 let mut out_f32 = vec![0.0f32; out_numel];
374
375 for bi in 0..b_sz {
376 for hi in 0..h {
377 for si in 0..total_seq {
378 let src_start = pos_off(bi, hi, si);
379 let dst_f32_start = (bi * h * total_seq + hi * total_seq + si) * hd;
380 for blk in 0..blocks_per_head {
381 let bb = &cache_ro[src_start + blk * 34..src_start + blk * 34 + 34];
382 let d = half::f16::from_le_bytes([bb[0], bb[1]]).to_f32();
383 for j in 0..32 {
384 out_f32[dst_f32_start + blk * 32 + j] = bb[2 + j] as i8 as f32 * d;
385 }
386 }
387 }
388 }
389 }
390
391 Tensor::from_f32_vec(out_f32, Shape::new(vec![b_sz, h, total_seq, hd]))
392 .map_err(|e| anyhow::anyhow!("{e}"))
393}
394
395pub fn apply_rope_positions(x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
396 map_err(rope::apply_rope(x, positions, base))
397}
398
399pub fn apply_rope_partial(
401 x: &Tensor,
402 positions: &[usize],
403 base: f32,
404 rotary_dim: usize,
405) -> Result<Tensor> {
406 map_err(rope::apply_rope_partial(x, positions, base, rotary_dim))
407}
408
409pub fn add_bias_last_dim(y: &Tensor, bias: &Tensor) -> Result<Tensor> {
412 let dims = y.shape().dims().to_vec();
413 let n = *dims.last().ok_or_else(|| anyhow::anyhow!("empty tensor"))?;
414 let bias_cow = bias.to_f32_cow();
415 let b = bias_cow.as_ref();
416 if b.len() != n {
417 anyhow::bail!("bias length {} does not match last dim {n}", b.len());
418 }
419 let mut data = y.as_f32_slice().to_vec();
420 for (i, v) in data.iter_mut().enumerate() {
421 *v += b[i % n];
422 }
423 map_err(Tensor::from_f32(&data, Shape::new(dims)))
424}
425
426pub fn rms_norm(x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
427 map_err(layernorm::rms_norm(x, Some(weight), eps))
428}
429
430pub fn layer_norm(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>, eps: f32) -> Result<Tensor> {
431 map_err(layernorm::layer_norm(x, Some(weight), bias, -1, eps))
432}
433
434pub fn silu(x: &Tensor) -> Result<Tensor> {
435 map_err(kernels::elementwise::silu(x))
436}
437
438pub fn gelu(x: &Tensor) -> Result<Tensor> {
439 map_err(kernels::elementwise::gelu(x))
440}
441
442pub fn add(a: &Tensor, b: &Tensor) -> Result<Tensor> {
443 map_err(kernels::elementwise::add(a, b))
444}
445
446pub fn mul(a: &Tensor, b: &Tensor) -> Result<Tensor> {
447 map_err(kernels::elementwise::mul(a, b))
448}
449
450pub fn gqa_attention(
451 q: &Tensor,
452 k: &Tensor,
453 v: &Tensor,
454 n_kv_heads: usize,
455 causal: bool,
456) -> Result<Tensor> {
457 let mask = if causal {
458 let sq = q.shape().dims()[2];
459 let sk = k.shape().dims()[2];
460 Some(attention::causal_mask(sq, sk))
461 } else {
462 None
463 };
464 map_err(attention::scaled_dot_product_attention(
465 q,
466 k,
467 v,
468 mask.as_ref(),
469 None,
470 n_kv_heads,
471 ))
472}
473
474pub fn all_logits_from_hidden(hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<Vec<f32>>> {
477 let dims = hidden.shape().dims();
478 let hidden_size = dims[2];
479 let seq = dims[1];
480 let vocab_size = lm_head.shape().dims()[0];
481 let h = hidden.as_f32_slice();
482 let h_all =
483 Tensor::from_f32(h, Shape::new([seq, hidden_size])).map_err(|e| anyhow::anyhow!("{e}"))?;
484 let logits_flat = map_err(matmul::matmul_nt(&h_all, lm_head))?;
485 let flat = logits_flat.as_f32_slice();
486 let mut all = Vec::with_capacity(seq);
487 for i in 0..seq {
488 all.push(flat[i * vocab_size..(i + 1) * vocab_size].to_vec());
489 }
490 Ok(all)
491}
492
493pub fn logits_from_hidden(hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
494 let dims = hidden.shape().dims();
496 let hidden_size = dims[2];
497 let seq = dims[1];
498 let h = hidden.as_f32_slice();
499 let last = &h[(seq - 1) * hidden_size..seq * hidden_size];
500 let h_last =
501 Tensor::from_f32(last, Shape::new([1, hidden_size])).map_err(|e| anyhow::anyhow!("{e}"))?;
502 let logits = map_err(matmul::matmul_nt(&h_last, lm_head))?;
504 Ok(logits.as_f32_slice().to_vec())
505}
506
507pub fn mean_pool_hidden(hidden: &Tensor) -> Result<Vec<f32>> {
508 let dims = hidden.shape().dims();
509 let (seq, hidden_size) = (dims[1], dims[2]);
510 let h = hidden.as_f32_slice();
511 let mut out = vec![0.0f32; hidden_size];
512 for t in 0..seq {
513 for i in 0..hidden_size {
514 out[i] += h[t * hidden_size + i];
515 }
516 }
517 let n = seq as f32;
518 for v in &mut out {
519 *v /= n;
520 }
521 Ok(out)
522}