Skip to main content

sapient_models/forward/
backend.rs

1//! Backend dispatch for native LLM forward passes.
2
3use anyhow::Result;
4use sapient_core::Tensor;
5
6use super::common;
7
/// Backend requested by the caller for running generation.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum LlmBackendKind {
    /// Use the CPU backend.
    Cpu,
    /// Use the Metal (MLX) backend.
    Metal,
    /// Let the dispatcher pick a backend; this is the `Default` value.
    #[default]
    Auto,
}
15
/// Report describing whether Mac GPU execution can be used by this build.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MacGpuSupport {
    /// `true` when the GPU path is usable.
    pub available: bool,
    /// Short backend identifier (`"mlx"` or `"cpu"` for the reports built in this file).
    pub backend: &'static str,
    /// Human-readable explanation of the `available` verdict.
    pub reason: &'static str,
}
22
23pub fn mac_gpu_support() -> MacGpuSupport {
24    MlxLlmOps::support()
25}
26
27impl std::str::FromStr for LlmBackendKind {
28    type Err = anyhow::Error;
29
30    fn from_str(value: &str) -> Result<Self> {
31        match value.to_ascii_lowercase().as_str() {
32            "cpu" => Ok(Self::Cpu),
33            "metal" => Ok(Self::Metal),
34            "auto" => Ok(Self::Auto),
35            other => anyhow::bail!("unsupported generation backend '{other}'"),
36        }
37    }
38}
39
40impl std::fmt::Display for LlmBackendKind {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        match self {
43            Self::Cpu => write!(f, "cpu"),
44            Self::Metal => write!(f, "metal"),
45            Self::Auto => write!(f, "auto"),
46        }
47    }
48}
49
50pub trait LlmBackend: Send + Sync {
51    fn name(&self) -> &'static str;
52    fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor>;
53
54    /// Linear projection with an optional bias added over the last dimension.
55    /// Backend-agnostic: computes `linear_3d` then folds in the bias on the host,
56    /// so every backend gets correct bias handling for free.
57    fn linear_3d_bias(&self, x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Result<Tensor> {
58        let y = self.linear_3d(x, weight)?;
59        match bias {
60            None => Ok(y),
61            Some(b) => common::add_bias_last_dim(&y, b),
62        }
63    }
64
65    fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor>;
66    fn layer_norm(
67        &self,
68        x: &Tensor,
69        weight: &Tensor,
70        bias: Option<&Tensor>,
71        eps: f32,
72    ) -> Result<Tensor>;
73    fn silu(&self, x: &Tensor) -> Result<Tensor>;
74    fn gelu(&self, x: &Tensor) -> Result<Tensor>;
75    fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor>;
76    fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor>;
77    fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor>;
78
79    /// RoPE over only the first `rotary_dim` channels (Phi partial rotary).
80    /// Computed on the CPU reference kernel for all backends — it is cheap and
81    /// avoids backend-specific partial-rotary support.
82    fn apply_rope_partial(
83        &self,
84        x: &Tensor,
85        positions: &[usize],
86        base: f32,
87        rotary_dim: usize,
88    ) -> Result<Tensor> {
89        common::apply_rope_partial(x, positions, base, rotary_dim)
90    }
91
92    fn gqa_attention(
93        &self,
94        q: &Tensor,
95        k: &Tensor,
96        v: &Tensor,
97        n_kv_heads: usize,
98        causal: bool,
99    ) -> Result<Tensor>;
100    fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>>;
101
102    /// Compute logits for ALL positions in the sequence.
103    /// Returns `seq_len` vectors each of length `vocab_size`.
104    /// Default impl delegates to the CPU reference kernel; backends may override.
105    fn all_logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<Vec<f32>>> {
106        common::all_logits_from_hidden(hidden, lm_head)
107    }
108}
109
110#[derive(Debug, Default, Clone)]
111pub struct CpuLlmBackend;
112
113impl LlmBackend for CpuLlmBackend {
114    fn name(&self) -> &'static str {
115        "cpu"
116    }
117
118    fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor> {
119        common::linear_3d(x, weight)
120    }
121
122    fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
123        common::rms_norm(x, weight, eps)
124    }
125
126    fn layer_norm(
127        &self,
128        x: &Tensor,
129        weight: &Tensor,
130        bias: Option<&Tensor>,
131        eps: f32,
132    ) -> Result<Tensor> {
133        common::layer_norm(x, weight, bias, eps)
134    }
135
136    fn silu(&self, x: &Tensor) -> Result<Tensor> {
137        common::silu(x)
138    }
139
140    fn gelu(&self, x: &Tensor) -> Result<Tensor> {
141        common::gelu(x)
142    }
143
144    fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
145        common::add(a, b)
146    }
147
148    fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
149        common::mul(a, b)
150    }
151
152    fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
153        common::apply_rope_positions(x, positions, base)
154    }
155
156    fn gqa_attention(
157        &self,
158        q: &Tensor,
159        k: &Tensor,
160        v: &Tensor,
161        n_kv_heads: usize,
162        causal: bool,
163    ) -> Result<Tensor> {
164        common::gqa_attention(q, k, v, n_kv_heads, causal)
165    }
166
167    fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
168        common::logits_from_hidden(hidden, lm_head)
169    }
170}
171
172#[derive(Debug, Default, Clone)]
173pub struct MetalLlmBackend {
174    cpu: CpuLlmBackend,
175    mlx: MlxLlmOps,
176}
177
178impl MetalLlmBackend {
179    pub fn is_available() -> bool {
180        MlxLlmOps::is_available()
181    }
182}
183
184impl LlmBackend for MetalLlmBackend {
185    fn name(&self) -> &'static str {
186        "metal"
187    }
188
189    fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor> {
190        self.mlx
191            .linear_3d(x, weight)
192            .or_else(|e| {
193                tracing::warn!(op = "linear_3d", error = %e, "MLX op failed; using CPU reference kernel");
194                self.cpu.linear_3d(x, weight)
195            })
196    }
197
198    fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
199        self.mlx.rms_norm(x, weight, eps).or_else(|e| {
200            tracing::warn!(op = "rms_norm", error = %e, "MLX op failed; using CPU reference kernel");
201            self.cpu.rms_norm(x, weight, eps)
202        })
203    }
204
205    fn layer_norm(
206        &self,
207        x: &Tensor,
208        weight: &Tensor,
209        bias: Option<&Tensor>,
210        eps: f32,
211    ) -> Result<Tensor> {
212        self.mlx.layer_norm(x, weight, bias, eps).or_else(|e| {
213            tracing::warn!(op = "layer_norm", error = %e, "MLX op failed; using CPU reference kernel");
214            self.cpu.layer_norm(x, weight, bias, eps)
215        })
216    }
217
218    fn silu(&self, x: &Tensor) -> Result<Tensor> {
219        self.mlx.silu(x).or_else(|e| {
220            tracing::warn!(op = "silu", error = %e, "MLX op failed; using CPU reference kernel");
221            self.cpu.silu(x)
222        })
223    }
224
225    fn gelu(&self, x: &Tensor) -> Result<Tensor> {
226        self.mlx.gelu(x).or_else(|e| {
227            tracing::warn!(op = "gelu", error = %e, "MLX op failed; using CPU reference kernel");
228            self.cpu.gelu(x)
229        })
230    }
231
232    fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
233        self.mlx.add(a, b).or_else(|e| {
234            tracing::warn!(op = "add", error = %e, "MLX op failed; using CPU reference kernel");
235            self.cpu.add(a, b)
236        })
237    }
238
239    fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
240        self.mlx.mul(a, b).or_else(|e| {
241            tracing::warn!(op = "mul", error = %e, "MLX op failed; using CPU reference kernel");
242            self.cpu.mul(a, b)
243        })
244    }
245
246    fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
247        self.mlx
248            .apply_rope_positions(x, positions, base)
249            .or_else(|e| {
250                tracing::warn!(op = "rope", error = %e, "MLX op failed; using CPU reference kernel");
251                self.cpu.apply_rope_positions(x, positions, base)
252            })
253    }
254
255    fn gqa_attention(
256        &self,
257        q: &Tensor,
258        k: &Tensor,
259        v: &Tensor,
260        n_kv_heads: usize,
261        causal: bool,
262    ) -> Result<Tensor> {
263        // Run on the Metal GPU. MlxLlmOps::gqa_attention builds an explicit
264        // causal mask for prefill (seq_q > 1) so the KV-cache offset case
265        // (seq_q < seq_k at decode) is handled correctly.
266        self.mlx
267            .gqa_attention(q, k, v, n_kv_heads, causal)
268            .or_else(|e| {
269                tracing::warn!(op = "gqa_attention", error = %e,
270                    "MLX attention failed; falling back to CPU");
271                self.cpu.gqa_attention(q, k, v, n_kv_heads, causal)
272            })
273    }
274
275    fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
276        self.mlx.logits_from_hidden(hidden, lm_head).or_else(|e| {
277            tracing::warn!(op = "logits", error = %e, "MLX op failed; using CPU reference kernel");
278            self.cpu.logits_from_hidden(hidden, lm_head)
279        })
280    }
281}
282
/// Shared map from a tensor's byte-buffer address
/// (`tensor.as_bytes().as_ptr() as usize`) to the `mlx_rs::Array` already
/// built from that tensor.
///
/// `MlxLlmOps::to_array` consults this map so that a tensor whose buffer
/// address stays the same across calls (weights are the intended case) is
/// converted once and then reused, instead of being re-converted on every
/// call. The map is behind `Arc<Mutex<..>>`, so clones of the alias share
/// one cache.
///
/// NOTE(review): entries are never removed in this file, and an address is
/// not a unique identity once a buffer has been freed; confirm the lifetime
/// of cached tensors against the callers.
#[cfg(feature = "mlx")]
type MlxWeightCache =
    std::sync::Arc<parking_lot::Mutex<std::collections::HashMap<usize, mlx_rs::Array>>>;
291
/// MLX operation set used by `MetalLlmBackend`.
///
/// Without the `mlx` feature the struct has no fields.
#[derive(Clone)]
struct MlxLlmOps {
    /// Converted-array cache keyed by buffer address. It is reference-counted,
    /// so every clone of `MlxLlmOps` reads and fills the same map.
    #[cfg(feature = "mlx")]
    cache: MlxWeightCache,
}
299
300impl std::fmt::Debug for MlxLlmOps {
301    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302        f.debug_struct("MlxLlmOps").finish()
303    }
304}
305
306// Can't use #[derive(Default)] because the `cache` field is cfg-gated on
307// the `mlx` feature and derive doesn't understand that.
308#[allow(clippy::derivable_impls)]
309impl Default for MlxLlmOps {
310    fn default() -> Self {
311        Self {
312            #[cfg(feature = "mlx")]
313            cache: std::sync::Arc::new(parking_lot::Mutex::new(std::collections::HashMap::new())),
314        }
315    }
316}
317
318#[cfg(target_os = "macos")]
319impl MlxLlmOps {
320    fn support() -> MacGpuSupport {
321        #[cfg(all(target_os = "macos", target_arch = "aarch64", feature = "mlx"))]
322        {
323            MacGpuSupport {
324                available: true,
325                backend: "mlx",
326                reason: "Apple Silicon with MLX feature enabled",
327            }
328        }
329
330        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(feature = "mlx")))]
331        {
332            MacGpuSupport {
333                available: false,
334                backend: "cpu",
335                reason: "Apple Silicon detected, but Sapient was built without the sapient-models/mlx feature",
336            }
337        }
338
339        #[cfg(all(target_os = "macos", not(target_arch = "aarch64")))]
340        {
341            MacGpuSupport {
342                available: false,
343                backend: "cpu",
344                reason: "MLX GPU execution requires Apple Silicon; Intel Macs use CPU",
345            }
346        }
347
348        #[cfg(not(target_os = "macos"))]
349        {
350            MacGpuSupport {
351                available: false,
352                backend: "cpu",
353                reason: "MLX GPU execution is only available on macOS",
354            }
355        }
356    }
357
358    fn is_available() -> bool {
359        Self::support().available
360    }
361
362    #[cfg(feature = "mlx")]
363    fn to_shape(dims: &[usize]) -> Result<Vec<i32>> {
364        dims.iter()
365            .map(|&d| {
366                i32::try_from(d)
367                    .map_err(|_| anyhow::anyhow!("shape dimension too large for MLX: {d}"))
368            })
369            .collect()
370    }
371
372    /// Convert a `Tensor` to an `mlx_rs::Array`.
373    ///
374    /// For tensors with stable buffer addresses (weights stored in `Arc<CpuBuffer>`) this
375    /// returns a cached copy — the upload to GPU happens only on the first call.  Activations
376    /// have a fresh allocation each decode step so they are converted without caching.
377    #[cfg(feature = "mlx")]
378    fn to_array(&self, tensor: &Tensor) -> Result<mlx_rs::Array> {
379        let ptr_key = tensor.as_bytes().as_ptr() as usize;
380
381        // Fast path: already uploaded.
382        {
383            let guard = self.cache.lock();
384            if let Some(arr) = guard.get(&ptr_key) {
385                return Ok(arr.clone());
386            }
387        }
388
389        // Slow path: convert and cache.
390        let shape = Self::to_shape(tensor.shape().dims())?;
391        let data = tensor.to_f32_cow();
392        let arr = mlx_rs::Array::from_slice(data.as_ref(), &shape);
393
394        // Only cache when the tensor looks like a weight (> 1 KiB and numel > 256).
395        // This avoids caching tiny scalars or ephemeral activation buffers that happen
396        // to share the same size.
397        let numel = tensor.numel();
398        if numel > 256 {
399            self.cache.lock().insert(ptr_key, arr.clone());
400        }
401
402        Ok(arr)
403    }
404
405    /// Convert without caching — for activation tensors created fresh each step.
406    ///
407    /// `to_f32_cow` on a non-contiguous tensor (e.g. KV-cache slices from
408    /// `slice_axis`) returns the full backing buffer, which is far larger than the
409    /// tensor's logical `numel`. `Array::from_slice` asserts `data.len == shape.product`
410    /// and panics. We limit the slice to `numel` elements to prevent the assert.
411    ///
412    /// For contiguous tensors (the common case for activations and weights) the
413    /// buffer length already equals `numel` so this limit is a no-op.
414    #[cfg(feature = "mlx")]
415    fn to_array_uncached(tensor: &Tensor) -> Result<mlx_rs::Array> {
416        let shape = Self::to_shape(tensor.shape().dims())?;
417        let numel = tensor.numel();
418        let cow = tensor.to_f32_cow();
419        // Limit to the logical element count so non-contiguous view tensors (KV
420        // cache slices) don't overflow the MLX assert.
421        let data = &cow[..numel.min(cow.len())];
422        Ok(mlx_rs::Array::from_slice(data, &shape))
423    }
424
425    #[cfg(feature = "mlx")]
426    fn to_tensor(array: mlx_rs::Array) -> Result<Tensor> {
427        let shape: Vec<usize> = array
428            .shape()
429            .iter()
430            .map(|&d| {
431                usize::try_from(d).map_err(|_| anyhow::anyhow!("negative MLX shape dimension: {d}"))
432            })
433            .collect::<Result<Vec<_>>>()?;
434        let data = array.as_slice::<f32>().to_vec();
435        Tensor::from_f32(&data, shape).map_err(|e| anyhow::anyhow!("{e}"))
436    }
437
438    fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor> {
439        #[cfg(not(feature = "mlx"))]
440        {
441            let _ = (x, weight);
442            anyhow::bail!("{}", Self::support().reason);
443        }
444
445        #[cfg(feature = "mlx")]
446        {
447            let dims = x.shape().dims();
448            if dims.len() != 3 {
449                anyhow::bail!("linear_3d expects [batch, seq, hidden]");
450            }
451            let (batch, seq, in_dim) = (dims[0], dims[1], dims[2]);
452            let w_dims = weight.shape().dims();
453            if w_dims.len() != 2 {
454                anyhow::bail!("linear weight must be 2-D");
455            }
456            let out_dim = w_dims[0];
457            if w_dims[1] != in_dim {
458                anyhow::bail!("linear weight in_dim mismatch: {} vs {in_dim}", w_dims[1]);
459            }
460
461            // x is a fresh activation (ephemeral); weight is a stable weight (cache it).
462            let x_arr =
463                Self::to_array_uncached(x)?.reshape(&Self::to_shape(&[batch * seq, in_dim])?)?;
464            let w_arr = self.to_array(weight)?.transpose()?;
465            let y = x_arr.matmul(&w_arr)?;
466            Self::to_tensor(y.reshape(&Self::to_shape(&[batch, seq, out_dim])?)?)
467        }
468    }
469
470    fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
471        #[cfg(not(feature = "mlx"))]
472        {
473            let _ = (x, weight, eps);
474            anyhow::bail!("{}", Self::support().reason);
475        }
476
477        #[cfg(feature = "mlx")]
478        {
479            let x = Self::to_array_uncached(x)?;
480            let weight = self.to_array(weight)?;
481            Self::to_tensor(mlx_rs::fast::rms_norm(&x, &weight, eps)?)
482        }
483    }
484
485    fn layer_norm(
486        &self,
487        x: &Tensor,
488        weight: &Tensor,
489        bias: Option<&Tensor>,
490        eps: f32,
491    ) -> Result<Tensor> {
492        #[cfg(not(feature = "mlx"))]
493        {
494            let _ = (x, weight, bias, eps);
495            anyhow::bail!("{}", Self::support().reason);
496        }
497
498        #[cfg(feature = "mlx")]
499        {
500            let x = Self::to_array_uncached(x)?;
501            let weight = self.to_array(weight)?;
502            let bias = bias.map(|b| self.to_array(b)).transpose()?;
503            Self::to_tensor(mlx_rs::fast::layer_norm(
504                &x,
505                Some(&weight),
506                bias.as_ref(),
507                eps,
508            )?)
509        }
510    }
511
512    fn silu(&self, x: &Tensor) -> Result<Tensor> {
513        #[cfg(not(feature = "mlx"))]
514        {
515            let _ = x;
516            anyhow::bail!("{}", Self::support().reason);
517        }
518
519        #[cfg(feature = "mlx")]
520        {
521            let x = Self::to_array_uncached(x)?;
522            Self::to_tensor(mlx_rs::nn::silu(&x)?)
523        }
524    }
525
526    fn gelu(&self, x: &Tensor) -> Result<Tensor> {
527        #[cfg(not(feature = "mlx"))]
528        {
529            let _ = x;
530            anyhow::bail!("{}", Self::support().reason);
531        }
532
533        #[cfg(feature = "mlx")]
534        {
535            let x = Self::to_array_uncached(x)?;
536            Self::to_tensor(mlx_rs::nn::gelu(&x)?)
537        }
538    }
539
540    fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
541        #[cfg(not(feature = "mlx"))]
542        {
543            let _ = (a, b);
544            anyhow::bail!("{}", Self::support().reason);
545        }
546
547        #[cfg(feature = "mlx")]
548        {
549            let a = Self::to_array_uncached(a)?;
550            let b = Self::to_array_uncached(b)?;
551            Self::to_tensor(a.add(&b)?)
552        }
553    }
554
555    fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
556        #[cfg(not(feature = "mlx"))]
557        {
558            let _ = (a, b);
559            anyhow::bail!("{}", Self::support().reason);
560        }
561
562        #[cfg(feature = "mlx")]
563        {
564            let a = Self::to_array_uncached(a)?;
565            let b = Self::to_array_uncached(b)?;
566            Self::to_tensor(a.multiply(&b)?)
567        }
568    }
569
570    fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
571        #[cfg(not(feature = "mlx"))]
572        {
573            let _ = (x, positions, base);
574            anyhow::bail!("{}", Self::support().reason);
575        }
576
577        #[cfg(feature = "mlx")]
578        {
579            let dims = x.shape().dims();
580            if dims.len() != 4 {
581                anyhow::bail!("RoPE expects [batch, heads, seq, head_dim]");
582            }
583            if positions.is_empty() {
584                anyhow::bail!("RoPE positions cannot be empty");
585            }
586            let offset = i32::try_from(positions[0])
587                .map_err(|_| anyhow::anyhow!("RoPE position too large for MLX"))?;
588            let contiguous = positions
589                .iter()
590                .enumerate()
591                .all(|(i, &p)| p == positions[0] + i);
592            if !contiguous {
593                anyhow::bail!("MLX RoPE requires contiguous positions");
594            }
595
596            let x = Self::to_array_uncached(x)?;
597            Self::to_tensor(mlx_rs::fast::rope(
598                &x,
599                dims[3] as i32,
600                // `traditional = false` → rotate-half (NeoX/HF) convention, which
601                // matches how Llama/Qwen/Phi weights are trained and what the CPU
602                // kernel does. `true` (interleaved/GPT-J) produces garbage here.
603                false,
604                Some(base),
605                1.0,
606                offset,
607                None::<&mlx_rs::Array>,
608            )?)
609        }
610    }
611
612    fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
613        #[cfg(not(feature = "mlx"))]
614        {
615            let _ = (hidden, lm_head);
616            anyhow::bail!("{}", Self::support().reason);
617        }
618
619        #[cfg(feature = "mlx")]
620        {
621            let dims = hidden.shape().dims();
622            let hidden_size = dims[2];
623            let seq = dims[1];
624            let h = hidden.to_f32_cow();
625            let last = &h[(seq - 1) * hidden_size..seq * hidden_size];
626            let h_last = mlx_rs::Array::from_slice(last, &[1, hidden_size as i32]);
627            let head = self.to_array(lm_head)?.transpose()?;
628            let logits = h_last.matmul(&head)?;
629            Ok(logits.as_slice::<f32>().to_vec())
630        }
631    }
632
633    /// Grouped-query attention (GQA) on the Metal GPU via MLX.
634    ///
635    /// MLX's `fast::scaled_dot_product_attention` dispatches to an optimised Metal
636    /// kernel when `seq_q = 1` (every decode step).  At prefill we build an explicit
637    /// additive causal mask in order to correctly handle `seq_q < seq_k` (cached
638    /// prefix) — MLX's built-in `"causal"` mode would assume square attention.
639    fn gqa_attention(
640        &self,
641        q: &Tensor,
642        k: &Tensor,
643        v: &Tensor,
644        n_kv_heads: usize,
645        causal: bool,
646    ) -> Result<Tensor> {
647        #[cfg(not(feature = "mlx"))]
648        {
649            let _ = (q, k, v, n_kv_heads, causal);
650            anyhow::bail!("{}", Self::support().reason);
651        }
652
653        #[cfg(feature = "mlx")]
654        {
655            let qs = q.shape().dims().to_vec();
656            let ks = k.shape().dims().to_vec();
657            let (batch, n_heads, seq_q, head_dim) = (qs[0], qs[1], qs[2], qs[3]);
658            let seq_k = ks[2];
659            let scale = 1.0 / (head_dim as f32).sqrt();
660
661            // GQA (n_heads > n_kv_heads): mlx_rs 0.25.3's fast::scaled_dot_product_attention
662            // does not correctly handle grouped-query attention when query head count ≠
663            // key/value head count — it produces garbage logits. Fall back to the
664            // verified CPU reference kernel for all GQA models (Qwen2.5, Llama 3.x,
665            // Mistral). Standard MHA (n_heads == n_kv_heads) can use MLX directly.
666            if n_heads != n_kv_heads {
667                anyhow::bail!(
668                    "GQA (n_heads={n_heads} ≠ n_kv_heads={n_kv_heads}): using CPU attention"
669                );
670            }
671
672            // q: [batch, n_heads, seq_q, head_dim]
673            // k: [batch, n_kv_heads, seq_k, head_dim]
674            // MLX SDPA expects [batch, heads, seq, dim].
675            let q_arr = Self::to_array_uncached(q)?;
676            let k_arr = Self::to_array_uncached(k)?;
677            let v_arr = Self::to_array_uncached(v)?;
678
679            // Build the causal mask when needed.
680            // - seq_q = 1 (decode): every cached key is in the past → no masking needed.
681            // - seq_q > 1 (prefill): need a [seq_q, seq_k] upper-triangular -inf mask.
682            let mask_arr: Option<mlx_rs::Array> = if causal && seq_q > 1 {
683                let offset = seq_k.saturating_sub(seq_q);
684                let mut data = vec![0.0f32; seq_q * seq_k];
685                for qi in 0..seq_q {
686                    for ki in 0..seq_k {
687                        if ki > qi + offset {
688                            data[qi * seq_k + ki] = f32::NEG_INFINITY;
689                        }
690                    }
691                }
692                Some(mlx_rs::Array::from_slice(
693                    &data,
694                    &[seq_q as i32, seq_k as i32],
695                ))
696            } else {
697                None
698            };
699
700            // IntoOption<ScaledDotProductAttentionMask> is implemented for
701            // Option<ScaledDotProductAttentionMask>, so wrap our optional mask.
702            let mlx_mask = mask_arr
703                .as_ref()
704                .map(mlx_rs::fast::ScaledDotProductAttentionMask::from);
705            let out_arr = mlx_rs::fast::scaled_dot_product_attention(
706                &q_arr, &k_arr, &v_arr, scale, mlx_mask,
707            )?;
708
709            // out_arr: [batch, n_heads, seq_q, head_dim]
710            Self::to_tensor(out_arr)
711        }
712    }
713}
714
715#[cfg(not(target_os = "macos"))]
716impl MlxLlmOps {
717    fn support() -> MacGpuSupport {
718        MacGpuSupport {
719            available: false,
720            backend: "cpu",
721            reason: "MLX GPU execution is only available on macOS",
722        }
723    }
724
725    fn is_available() -> bool {
726        false
727    }
728
729    fn linear_3d(&self, _x: &Tensor, _weight: &Tensor) -> Result<Tensor> {
730        anyhow::bail!("MLX is only available on macOS")
731    }
732
733    fn rms_norm(&self, _x: &Tensor, _weight: &Tensor, _eps: f32) -> Result<Tensor> {
734        anyhow::bail!("MLX is only available on macOS")
735    }
736
737    fn layer_norm(
738        &self,
739        _x: &Tensor,
740        _weight: &Tensor,
741        _bias: Option<&Tensor>,
742        _eps: f32,
743    ) -> Result<Tensor> {
744        anyhow::bail!("MLX is only available on macOS")
745    }
746
747    fn silu(&self, _x: &Tensor) -> Result<Tensor> {
748        anyhow::bail!("MLX is only available on macOS")
749    }
750
751    fn gelu(&self, _x: &Tensor) -> Result<Tensor> {
752        anyhow::bail!("MLX is only available on macOS")
753    }
754
755    fn add(&self, _a: &Tensor, _b: &Tensor) -> Result<Tensor> {
756        anyhow::bail!("MLX is only available on macOS")
757    }
758
759    fn mul(&self, _a: &Tensor, _b: &Tensor) -> Result<Tensor> {
760        anyhow::bail!("MLX is only available on macOS")
761    }
762
763    fn apply_rope_positions(
764        &self,
765        _x: &Tensor,
766        _positions: &[usize],
767        _base: f32,
768    ) -> Result<Tensor> {
769        anyhow::bail!("MLX is only available on macOS")
770    }
771
772    fn gqa_attention(
773        &self,
774        _q: &Tensor,
775        _k: &Tensor,
776        _v: &Tensor,
777        _n_kv_heads: usize,
778        _causal: bool,
779    ) -> Result<Tensor> {
780        anyhow::bail!("MLX is only available on macOS")
781    }
782
783    fn logits_from_hidden(&self, _hidden: &Tensor, _lm_head: &Tensor) -> Result<Vec<f32>> {
784        anyhow::bail!("MLX is only available on macOS")
785    }
786}
787
788#[derive(Debug, Clone)]
789pub enum LlmBackendDispatch {
790    Cpu(CpuLlmBackend),
791    Metal(MetalLlmBackend),
792}
793
794impl LlmBackendDispatch {
795    /// Returns `true` when the backend is CPU-only (thread-safe for concurrent
796    /// compute calls). Returns `false` for GPU backends (Metal/MLX) whose
797    /// command buffers do not support concurrent encoding from multiple threads.
798    pub fn is_cpu(&self) -> bool {
799        matches!(self, Self::Cpu(_))
800    }
801
802    pub fn from_kind(kind: LlmBackendKind) -> Result<Self> {
803        Self::from_kind_with_model_bytes(kind, 0)
804    }
805
806    /// Select a backend, optionally accounting for the model's weight footprint.
807    ///
808    /// `model_bytes` is the total weight size in bytes (0 = unknown).  On Apple
809    /// Silicon (unified memory), Metal is chosen for Auto when the model fits
810    /// with a 1.5× KV-cache headroom factor; otherwise CPU is used to avoid
811    /// swapping GPU memory which kills throughput.
812    pub fn from_kind_with_model_bytes(kind: LlmBackendKind, model_bytes: u64) -> Result<Self> {
813        match kind {
814            LlmBackendKind::Cpu => Ok(Self::Cpu(CpuLlmBackend)),
815            LlmBackendKind::Auto if MetalLlmBackend::is_available() => {
816                let fits = metal_memory_fits(model_bytes);
817                if fits {
818                    tracing::debug!(
819                        model_bytes,
820                        "auto-backend: Metal (model fits in unified memory)"
821                    );
822                    Ok(Self::Metal(MetalLlmBackend::default()))
823                } else {
824                    tracing::info!(
825                        model_bytes,
826                        "auto-backend: CPU (model too large for Metal GPU memory headroom — \
827                         use --backend metal to force GPU anyway)"
828                    );
829                    Ok(Self::Cpu(CpuLlmBackend))
830                }
831            }
832            LlmBackendKind::Auto => Ok(Self::Cpu(CpuLlmBackend)),
833            LlmBackendKind::Metal if MetalLlmBackend::is_available() => {
834                Ok(Self::Metal(MetalLlmBackend::default()))
835            }
836            LlmBackendKind::Metal => {
837                let support = mac_gpu_support();
838                anyhow::bail!(
839                    "Metal/MLX generation backend is unavailable: {}",
840                    support.reason
841                )
842            }
843        }
844    }
845}
846
/// Returns true when `model_bytes`, multiplied by a 1.5× headroom factor for
/// KV cache and activations, fits in total system RAM minus a 2 GiB reserve
/// for the OS and the application.
///
/// Returns true when the answer cannot be computed, so that models are not
/// blocked on missing information:
/// * `model_bytes == 0` (size unknown), or
/// * `total_system_ram_bytes()` returns 0 (RAM size unknown). Previously this
///   case yielded a usable budget of 0 bytes and rejected every model,
///   contradicting the documented "unknown → Metal allowed" behaviour.
fn metal_memory_fits(model_bytes: u64) -> bool {
    /// Memory set aside for the OS and application overhead.
    const OS_RESERVE_BYTES: u64 = 2 * 1024 * 1024 * 1024;
    /// Multiplier applied to the weight size.
    const HEADROOM_FACTOR: f64 = 1.5;

    if model_bytes == 0 {
        return true;
    }
    let total_ram = total_system_ram_bytes();
    if total_ram == 0 {
        return true;
    }
    let usable = total_ram.saturating_sub(OS_RESERVE_BYTES);
    model_bytes as f64 * HEADROOM_FACTOR <= usable as f64
}

/// Total system RAM in bytes, read on macOS by running `sysctl -n hw.memsize`.
///
/// Returns 0 when the size is unknown: on any platform other than macOS
/// (no probe is implemented there), or when `sysctl` cannot be run, exits
/// with a failure status, or prints something that is not an unsigned integer.
pub fn total_system_ram_bytes() -> u64 {
    #[cfg(target_os = "macos")]
    {
        let parsed = std::process::Command::new("sysctl")
            .args(["-n", "hw.memsize"])
            .output()
            .ok()
            // Ignore whatever was printed if the command itself failed.
            .filter(|out| out.status.success())
            .and_then(|out| String::from_utf8(out.stdout).ok())
            .and_then(|text| text.trim().parse::<u64>().ok());
        if let Some(bytes) = parsed {
            return bytes;
        }
    }
    0
}
879
880impl LlmBackend for LlmBackendDispatch {
881    fn name(&self) -> &'static str {
882        match self {
883            Self::Cpu(b) => b.name(),
884            Self::Metal(b) => b.name(),
885        }
886    }
887
888    fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor> {
889        match self {
890            Self::Cpu(b) => b.linear_3d(x, weight),
891            Self::Metal(b) => b.linear_3d(x, weight),
892        }
893    }
894
895    fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
896        match self {
897            Self::Cpu(b) => b.rms_norm(x, weight, eps),
898            Self::Metal(b) => b.rms_norm(x, weight, eps),
899        }
900    }
901
902    fn layer_norm(
903        &self,
904        x: &Tensor,
905        weight: &Tensor,
906        bias: Option<&Tensor>,
907        eps: f32,
908    ) -> Result<Tensor> {
909        match self {
910            Self::Cpu(b) => b.layer_norm(x, weight, bias, eps),
911            Self::Metal(b) => b.layer_norm(x, weight, bias, eps),
912        }
913    }
914
915    fn silu(&self, x: &Tensor) -> Result<Tensor> {
916        match self {
917            Self::Cpu(b) => b.silu(x),
918            Self::Metal(b) => b.silu(x),
919        }
920    }
921
922    fn gelu(&self, x: &Tensor) -> Result<Tensor> {
923        match self {
924            Self::Cpu(b) => b.gelu(x),
925            Self::Metal(b) => b.gelu(x),
926        }
927    }
928
929    fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
930        match self {
931            Self::Cpu(backend) => backend.add(a, b),
932            Self::Metal(backend) => backend.add(a, b),
933        }
934    }
935
936    fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
937        match self {
938            Self::Cpu(backend) => backend.mul(a, b),
939            Self::Metal(backend) => backend.mul(a, b),
940        }
941    }
942
943    fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
944        match self {
945            Self::Cpu(b) => b.apply_rope_positions(x, positions, base),
946            Self::Metal(b) => b.apply_rope_positions(x, positions, base),
947        }
948    }
949
950    fn gqa_attention(
951        &self,
952        q: &Tensor,
953        k: &Tensor,
954        v: &Tensor,
955        n_kv_heads: usize,
956        causal: bool,
957    ) -> Result<Tensor> {
958        match self {
959            Self::Cpu(b) => b.gqa_attention(q, k, v, n_kv_heads, causal),
960            Self::Metal(b) => b.gqa_attention(q, k, v, n_kv_heads, causal),
961        }
962    }
963
964    fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
965        match self {
966            Self::Cpu(b) => b.logits_from_hidden(hidden, lm_head),
967            Self::Metal(b) => b.logits_from_hidden(hidden, lm_head),
968        }
969    }
970
971    fn all_logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<Vec<f32>>> {
972        match self {
973            Self::Cpu(b) => b.all_logits_from_hidden(hidden, lm_head),
974            Self::Metal(b) => b.all_logits_from_hidden(hidden, lm_head),
975        }
976    }
977}