Skip to main content

sapient_backends_cpu/
backend.rs

1//! `ExecutionBackend` trait and `CpuBackend` implementation.
2
3use std::collections::HashMap;
4
5use tracing::{debug, instrument};
6
7use sapient_core::buffer::{BufferHandle, CpuBuffer};
8use sapient_core::error::{Result, SapientError};
9use sapient_core::{DType, Tensor};
10use sapient_ir::graph::Graph;
11use sapient_ir::node::{Node, NodeId};
12use sapient_ir::op::OpType;
13
14use crate::kernels;
15use crate::pool::PoolAllocator;
16
17// ── ExecutionBackend trait ────────────────────────────────────────────────────
18
19/// The unified backend interface every hardware target must implement.
20///
21/// Backends may be selected at runtime via `Box<dyn ExecutionBackend>` or at
22/// compile time via generics.
23pub trait ExecutionBackend: Send + Sync {
24    /// Short name for logging / CLI display.
25    fn name(&self) -> &str;
26
27    /// Allocate an uninitialised buffer for the given shape and dtype.
28    fn allocate(&self, shape: &[usize], dtype: DType) -> Result<BufferHandle>;
29
30    /// Execute the graph, returning output tensors in the order of
31    /// `graph.outputs`.
32    fn execute(&self, graph: &Graph, inputs: HashMap<String, Tensor>) -> Result<Vec<Tensor>>;
33
34    /// Whether this backend can execute the given op natively.
35    fn supports_op(&self, op: &OpType) -> bool;
36
37    /// True if this backend is available on the current system.
38    fn is_available() -> bool
39    where
40        Self: Sized,
41    {
42        true
43    }
44}
45
46// ── CpuBackend ────────────────────────────────────────────────────────────────
47
48/// Pure-Rust CPU execution backend.
49pub struct CpuBackend {
50    pool: PoolAllocator,
51}
52
53impl std::fmt::Debug for CpuBackend {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.debug_struct("CpuBackend").finish()
56    }
57}
58
59impl Default for CpuBackend {
60    fn default() -> Self {
61        Self::new(256 * 1024 * 1024) // 256 MiB pool
62    }
63}
64
65impl CpuBackend {
66    /// Create a new CPU backend with the given pool capacity (bytes).
67    pub fn new(pool_bytes: usize) -> Self {
68        Self {
69            pool: PoolAllocator::new(pool_bytes),
70        }
71    }
72}
73
74impl ExecutionBackend for CpuBackend {
75    fn name(&self) -> &str {
76        "cpu"
77    }
78
79    fn allocate(&self, shape: &[usize], dtype: DType) -> Result<BufferHandle> {
80        let numel: usize = shape.iter().product();
81        // Try the pool first; fall back to a fresh allocation.
82        if let Some(handle) = self.pool.acquire(numel, dtype) {
83            return Ok(handle);
84        }
85        let buf = CpuBuffer::zeros(numel, dtype)?;
86        Ok(BufferHandle::new(buf))
87    }
88
89    #[instrument(skip_all, fields(graph = %graph.name))]
90    fn execute(&self, graph: &Graph, inputs: HashMap<String, Tensor>) -> Result<Vec<Tensor>> {
91        // Topological execution order.
92        let order = graph.topological_order()?;
93
94        // Value map: (NodeId, output_index) → Tensor.
95        let mut values: HashMap<(NodeId, usize), Tensor> = HashMap::new();
96
97        // Seed inputs.
98        for id in &graph.inputs {
99            if let Some(Node::Input { name, .. }) = graph.get(*id) {
100                if let Some(t) = inputs.get(name) {
101                    values.insert((*id, 0), t.clone());
102                }
103            }
104        }
105
106        for id in &order {
107            match graph.get(*id) {
108                Some(Node::Constant { value, .. }) => {
109                    values.insert((*id, 0), value.clone());
110                }
111                Some(Node::Input { .. }) => {
112                    // Already seeded above.
113                }
114                Some(Node::Operator {
115                    op,
116                    inputs: inp_ids,
117                    num_outputs,
118                    ..
119                }) => {
120                    let op = op.clone();
121                    let inp_ids = inp_ids.clone();
122                    let _num_outputs = *num_outputs;
123
124                    // Collect input tensors.
125                    let input_tensors: Vec<Tensor> = inp_ids
126                        .iter()
127                        .map(|&inp| {
128                            values.get(&(inp, 0)).cloned().ok_or_else(|| {
129                                SapientError::internal(format!("missing value for node {inp}"))
130                            })
131                        })
132                        .collect::<Result<Vec<_>>>()?;
133
134                    // Dispatch to kernel.
135                    let outputs = self.dispatch(&op, &input_tensors)?;
136
137                    for (i, t) in outputs.into_iter().enumerate() {
138                        values.insert((*id, i), t);
139                    }
140                }
141                Some(Node::Output { source, .. }) => {
142                    // Alias output to its source.
143                    if let Some(t) = values.get(&(*source, 0)).cloned() {
144                        values.insert((*id, 0), t);
145                    }
146                }
147                None => {}
148            }
149        }
150
151        // Collect graph outputs in order.
152        let out_tensors: Vec<Tensor> = graph
153            .outputs
154            .iter()
155            .map(|&oid| {
156                values
157                    .get(&(oid, 0))
158                    .cloned()
159                    .ok_or_else(|| SapientError::internal(format!("output {oid} not computed")))
160            })
161            .collect::<Result<Vec<_>>>()?;
162
163        debug!(
164            outputs = out_tensors.len(),
165            "CpuBackend: execution complete"
166        );
167        Ok(out_tensors)
168    }
169
170    fn supports_op(&self, op: &OpType) -> bool {
171        matches!(
172            op,
173            OpType::MatMul | OpType::Gemm { .. }
174            | OpType::Add | OpType::Sub | OpType::Mul | OpType::Div | OpType::Pow
175            | OpType::Neg | OpType::Abs | OpType::Sqrt | OpType::Exp | OpType::Log
176            | OpType::Relu | OpType::Sigmoid | OpType::Tanh | OpType::Gelu
177            | OpType::LeakyRelu { .. } | OpType::Silu | OpType::HardSwish
178            | OpType::Softmax { .. } | OpType::LogSoftmax { .. }
179            | OpType::LayerNorm { .. } | OpType::RmsNorm { .. }
180            | OpType::Conv2d { .. }
181            | OpType::Reshape | OpType::Transpose { .. } | OpType::Flatten { .. }
182            | OpType::Concat { .. }
183            | OpType::ReduceSum { .. } | OpType::ReduceMean { .. }
184            | OpType::ReduceMax { .. } | OpType::ReduceMin { .. }
185            | OpType::Identity | OpType::Clip { .. }
186            | OpType::Erf | OpType::Floor | OpType::Ceil | OpType::Round
187            // LLM ops
188            | OpType::Embedding { .. }
189            | OpType::MultiHeadAttention { .. }
190            | OpType::GroupedQueryAttention { .. }
191            | OpType::RotaryEmbedding { .. }
192            | OpType::CausalMask
193            | OpType::KVCacheConcat
194            | OpType::RepeatKV { .. }
195        )
196    }
197}
198
199impl CpuBackend {
200    /// Dispatch an op to its kernel.
201    fn dispatch(&self, op: &OpType, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
202        let out = match op {
203            // ── Linear algebra ────────────────────────────────────────────
204            OpType::MatMul => {
205                let a = inputs
206                    .get(0)
207                    .ok_or_else(|| SapientError::internal("MatMul: missing a"))?;
208                let b = inputs
209                    .get(1)
210                    .ok_or_else(|| SapientError::internal("MatMul: missing b"))?;
211                vec![kernels::matmul::matmul(a, b)?]
212            }
213            OpType::Gemm {
214                alpha,
215                beta,
216                trans_a,
217                trans_b,
218            } => {
219                let a = inputs
220                    .get(0)
221                    .ok_or_else(|| SapientError::internal("Gemm: missing a"))?;
222                let b = inputs
223                    .get(1)
224                    .ok_or_else(|| SapientError::internal("Gemm: missing b"))?;
225                let c = inputs.get(2);
226                vec![kernels::matmul::gemm(
227                    a,
228                    b,
229                    c,
230                    alpha.0 as f32,
231                    beta.0 as f32,
232                    *trans_a,
233                    *trans_b,
234                )?]
235            }
236
237            // ── Element-wise ──────────────────────────────────────────────
238            OpType::Add => vec![kernels::elementwise::add(
239                inputs.get(0).unwrap(),
240                inputs.get(1).unwrap(),
241            )?],
242            OpType::Sub => vec![kernels::elementwise::sub(
243                inputs.get(0).unwrap(),
244                inputs.get(1).unwrap(),
245            )?],
246            OpType::Mul => vec![kernels::elementwise::mul(
247                inputs.get(0).unwrap(),
248                inputs.get(1).unwrap(),
249            )?],
250            OpType::Div => vec![kernels::elementwise::div(
251                inputs.get(0).unwrap(),
252                inputs.get(1).unwrap(),
253            )?],
254            OpType::Pow => vec![kernels::elementwise::pow(
255                inputs.get(0).unwrap(),
256                inputs.get(1).unwrap(),
257            )?],
258            OpType::Neg => vec![kernels::elementwise::neg(inputs.get(0).unwrap())?],
259            OpType::Abs => vec![kernels::elementwise::abs(inputs.get(0).unwrap())?],
260            OpType::Sqrt => vec![kernels::elementwise::sqrt(inputs.get(0).unwrap())?],
261            OpType::Exp => vec![kernels::elementwise::exp(inputs.get(0).unwrap())?],
262            OpType::Log => vec![kernels::elementwise::log(inputs.get(0).unwrap())?],
263            OpType::Erf => vec![kernels::elementwise::erf(inputs.get(0).unwrap())?],
264            OpType::Floor => vec![kernels::elementwise::floor(inputs.get(0).unwrap())?],
265            OpType::Ceil => vec![kernels::elementwise::ceil(inputs.get(0).unwrap())?],
266            OpType::Round => vec![kernels::elementwise::round(inputs.get(0).unwrap())?],
267
268            // ── Activations ───────────────────────────────────────────────
269            OpType::Relu => vec![kernels::elementwise::relu(inputs.get(0).unwrap())?],
270            OpType::Sigmoid => vec![kernels::elementwise::sigmoid(inputs.get(0).unwrap())?],
271            OpType::Tanh => vec![kernels::elementwise::tanh_act(inputs.get(0).unwrap())?],
272            OpType::Gelu => vec![kernels::elementwise::gelu(inputs.get(0).unwrap())?],
273            OpType::Silu => vec![kernels::elementwise::silu(inputs.get(0).unwrap())?],
274            OpType::HardSwish => vec![kernels::elementwise::hard_swish(inputs.get(0).unwrap())?],
275            OpType::LeakyRelu { alpha } => {
276                vec![kernels::elementwise::leaky_relu(
277                    inputs.get(0).unwrap(),
278                    alpha.0 as f32,
279                )?]
280            }
281            OpType::Clip { min, max } => {
282                vec![kernels::elementwise::clip(
283                    inputs.get(0).unwrap(),
284                    min.map(|v| v.0 as f32),
285                    max.map(|v| v.0 as f32),
286                )?]
287            }
288
289            // ── Normalisation ──────────────────────────────────────────────
290            OpType::Softmax { axis } => {
291                vec![kernels::softmax::softmax(inputs.get(0).unwrap(), *axis)?]
292            }
293            OpType::LogSoftmax { axis } => {
294                vec![kernels::softmax::log_softmax(
295                    inputs.get(0).unwrap(),
296                    *axis,
297                )?]
298            }
299            OpType::LayerNorm { axis, epsilon } => {
300                let weight = inputs.get(1);
301                let bias = inputs.get(2);
302                vec![kernels::layernorm::layer_norm(
303                    inputs.get(0).unwrap(),
304                    weight,
305                    bias,
306                    *axis,
307                    epsilon.0 as f32,
308                )?]
309            }
310            OpType::RmsNorm { epsilon } => {
311                let weight = inputs.get(1);
312                vec![kernels::layernorm::rms_norm(
313                    inputs.get(0).unwrap(),
314                    weight,
315                    epsilon.0 as f32,
316                )?]
317            }
318
319            // ── Convolution ────────────────────────────────────────────────
320            OpType::Conv2d {
321                kernel_shape,
322                pads,
323                strides,
324                dilations,
325                groups,
326            } => {
327                let x = inputs.get(0).unwrap();
328                let w = inputs.get(1).unwrap();
329                let b = inputs.get(2);
330                vec![kernels::conv2d::conv2d(
331                    x,
332                    w,
333                    b,
334                    *kernel_shape,
335                    *pads,
336                    *strides,
337                    *dilations,
338                    *groups,
339                )?]
340            }
341
342            // ── Shape ops ─────────────────────────────────────────────────
343            OpType::Reshape => {
344                let x = inputs.get(0).unwrap();
345                // The new shape comes from the second input (if present) or is
346                // determined at shape-inference time.
347                // For now, identity (shape already baked in by runtime).
348                vec![x.clone()]
349            }
350            OpType::Identity => vec![inputs.get(0).unwrap().clone()],
351
352            // ── Reduce ────────────────────────────────────────────────────
353            OpType::ReduceSum { axes, keep_dims } => {
354                vec![kernels::reduce::reduce_sum(
355                    inputs.get(0).unwrap(),
356                    axes,
357                    *keep_dims,
358                )?]
359            }
360            OpType::ReduceMean { axes, keep_dims } => {
361                vec![kernels::reduce::reduce_mean(
362                    inputs.get(0).unwrap(),
363                    axes,
364                    *keep_dims,
365                )?]
366            }
367            OpType::ReduceMax { axes, keep_dims } => {
368                vec![kernels::reduce::reduce_max(
369                    inputs.get(0).unwrap(),
370                    axes,
371                    *keep_dims,
372                )?]
373            }
374
375            // ── LLM ops ───────────────────────────────────────────────────
376
377            // Embedding lookup: weight `[vocab, hidden]` at inputs[0], token ids at inputs[1].
378            OpType::Embedding { .. } => {
379                let weight = inputs
380                    .get(0)
381                    .ok_or_else(|| SapientError::internal("Embedding: missing weight"))?;
382                let ids_t = inputs
383                    .get(1)
384                    .ok_or_else(|| SapientError::internal("Embedding: missing input_ids"))?;
385                let dims = ids_t.shape().dims();
386                let seq_len: usize = dims.iter().product();
387                let hidden = weight.shape().dims()[1];
388                // Use to_f32_vec() to transparently handle BF16/F16 weights
389                let w = weight.to_f32_vec();
390                let ids: Vec<u32> = if ids_t.dtype() == DType::F32 {
391                    ids_t.as_f32_slice().iter().map(|&v| v as u32).collect()
392                } else {
393                    ids_t
394                        .as_bytes()
395                        .chunks_exact(4)
396                        .map(|c| u32::from_le_bytes(c.try_into().unwrap()))
397                        .collect()
398                };
399                let mut out = vec![0.0f32; seq_len * hidden];
400                for (i, &id) in ids.iter().enumerate() {
401                    let row = id as usize * hidden;
402                    out[i * hidden..(i + 1) * hidden].copy_from_slice(&w[row..row + hidden]);
403                }
404                let batch = if dims.len() >= 2 { dims[0] } else { 1 };
405                let seq = if dims.len() >= 2 { dims[1] } else { seq_len };
406                vec![Tensor::from_f32(&out, vec![batch, seq, hidden])
407                    .map_err(|e| SapientError::internal(e.to_string()))?]
408            }
409
410            // Grouped-Query Attention — calls the attention kernel.
411            OpType::GroupedQueryAttention {
412                n_heads: _,
413                n_kv_heads,
414                head_dim: _,
415                causal,
416            } => {
417                let q = inputs
418                    .get(0)
419                    .ok_or_else(|| SapientError::internal("GQA: missing Q"))?;
420                let k = inputs
421                    .get(1)
422                    .ok_or_else(|| SapientError::internal("GQA: missing K"))?;
423                let v = inputs
424                    .get(2)
425                    .ok_or_else(|| SapientError::internal("GQA: missing V"))?;
426                let mask = if *causal {
427                    let seq_q = q.shape().dims().get(2).copied().unwrap_or(1);
428                    let seq_k = k.shape().dims().get(2).copied().unwrap_or(1);
429                    Some(kernels::attention::causal_mask(seq_q, seq_k))
430                } else {
431                    None
432                };
433                vec![kernels::attention::scaled_dot_product_attention(
434                    q,
435                    k,
436                    v,
437                    mask.as_ref(),
438                    None,
439                    *n_kv_heads,
440                )?]
441            }
442
443            // Multi-head attention (non-GQA, n_kv_heads = n_heads).
444            OpType::MultiHeadAttention {
445                num_heads,
446                head_dim: _,
447                causal,
448                scale,
449            } => {
450                let q = inputs
451                    .get(0)
452                    .ok_or_else(|| SapientError::internal("MHA: missing Q"))?;
453                let k = inputs
454                    .get(1)
455                    .ok_or_else(|| SapientError::internal("MHA: missing K"))?;
456                let v = inputs
457                    .get(2)
458                    .ok_or_else(|| SapientError::internal("MHA: missing V"))?;
459                let mask = if *causal {
460                    let sq = q.shape().dims().get(2).copied().unwrap_or(1);
461                    let sk = k.shape().dims().get(2).copied().unwrap_or(1);
462                    Some(kernels::attention::causal_mask(sq, sk))
463                } else {
464                    None
465                };
466                vec![kernels::attention::scaled_dot_product_attention(
467                    q,
468                    k,
469                    v,
470                    mask.as_ref(),
471                    scale.map(|s| s.0 as f32),
472                    *num_heads,
473                )?]
474            }
475
476            // RoPE — apply rotary embeddings to Q or K.
477            OpType::RotaryEmbedding { base, dim: _ } => {
478                let x = inputs
479                    .get(0)
480                    .ok_or_else(|| SapientError::internal("RoPE: missing input"))?;
481                let seq_len = x.shape().dims().get(2).copied().unwrap_or(1);
482                let positions: Vec<usize> = (0..seq_len).collect();
483                vec![kernels::rope::apply_rope(x, &positions, base.0 as f32)?]
484            }
485
486            // Causal mask generation.
487            OpType::CausalMask => {
488                let seq = inputs
489                    .get(0)
490                    .map(|t| t.shape().dims().get(1).copied().unwrap_or(1))
491                    .unwrap_or(1);
492                vec![kernels::attention::causal_mask(seq, seq)]
493            }
494
495            // KV cache concat — identity for now (cache is managed by Pipeline).
496            OpType::KVCacheConcat | OpType::RepeatKV { .. } => {
497                vec![inputs.get(0).unwrap().clone()]
498            }
499
500            // MoE gate / dispatch — identity (scheduler handles expert routing).
501            OpType::MoEGate { .. } | OpType::ScaledDotProductAttention { .. } => {
502                vec![inputs.get(0).unwrap().clone()]
503            }
504
505            // ALiBi — return a zero tensor (will be added to attention logits).
506            OpType::ALiBi { .. } => {
507                vec![Tensor::zeros(vec![1], DType::F32).unwrap()]
508            }
509
510            // ── Fallback ──────────────────────────────────────────────────
511            other => {
512                return Err(SapientError::unsupported_op("cpu", &other.to_string()));
513            }
514        };
515        Ok(out)
516    }
517}