1use anyhow::Result;
4use sapient_core::Tensor;
5
6use super::common;
7
/// Which execution backend text generation should use.
///
/// `Auto` is the default: Metal/MLX is chosen when it is available (and, when a model size
/// is known, when the model fits), otherwise CPU. `Metal` demands the GPU path and backend
/// selection fails when it is unavailable; `Cpu` always uses the reference kernels.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LlmBackendKind {
    /// Always run on the CPU reference kernels.
    Cpu,
    /// Require the Metal/MLX backend.
    Metal,
    /// Pick Metal/MLX when possible, else CPU.
    #[default]
    Auto,
}
15
/// Compile-time description of whether MLX GPU execution is usable in this build.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MacGpuSupport {
    /// `true` when the MLX backend can be used.
    pub available: bool,
    /// Name of the backend that will actually run (`"mlx"` or `"cpu"`).
    pub backend: &'static str,
    /// Human-readable explanation of the decision, suitable for error messages.
    pub reason: &'static str,
}
22
23pub fn mac_gpu_support() -> MacGpuSupport {
24 MlxLlmOps::support()
25}
26
27impl std::str::FromStr for LlmBackendKind {
28 type Err = anyhow::Error;
29
30 fn from_str(value: &str) -> Result<Self> {
31 match value.to_ascii_lowercase().as_str() {
32 "cpu" => Ok(Self::Cpu),
33 "metal" => Ok(Self::Metal),
34 "auto" => Ok(Self::Auto),
35 other => anyhow::bail!("unsupported generation backend '{other}'"),
36 }
37 }
38}
39
40impl std::fmt::Display for LlmBackendKind {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 match self {
43 Self::Cpu => write!(f, "cpu"),
44 Self::Metal => write!(f, "metal"),
45 Self::Auto => write!(f, "auto"),
46 }
47 }
48}
49
50pub trait LlmBackend: Send + Sync {
51 fn name(&self) -> &'static str;
52 fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor>;
53
54 fn linear_3d_bias(&self, x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Result<Tensor> {
58 let y = self.linear_3d(x, weight)?;
59 match bias {
60 None => Ok(y),
61 Some(b) => common::add_bias_last_dim(&y, b),
62 }
63 }
64
65 fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor>;
66 fn layer_norm(
67 &self,
68 x: &Tensor,
69 weight: &Tensor,
70 bias: Option<&Tensor>,
71 eps: f32,
72 ) -> Result<Tensor>;
73 fn silu(&self, x: &Tensor) -> Result<Tensor>;
74 fn gelu(&self, x: &Tensor) -> Result<Tensor>;
75 fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor>;
76 fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor>;
77 fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor>;
78
79 fn apply_rope_partial(
83 &self,
84 x: &Tensor,
85 positions: &[usize],
86 base: f32,
87 rotary_dim: usize,
88 ) -> Result<Tensor> {
89 common::apply_rope_partial(x, positions, base, rotary_dim)
90 }
91
92 fn gqa_attention(
93 &self,
94 q: &Tensor,
95 k: &Tensor,
96 v: &Tensor,
97 n_kv_heads: usize,
98 causal: bool,
99 ) -> Result<Tensor>;
100 fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>>;
101
102 fn all_logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<Vec<f32>>> {
106 common::all_logits_from_hidden(hidden, lm_head)
107 }
108}
109
/// Backend that runs every operation on the shared `common` CPU reference kernels.
#[derive(Clone, Debug, Default)]
pub struct CpuLlmBackend;
112
113impl LlmBackend for CpuLlmBackend {
114 fn name(&self) -> &'static str {
115 "cpu"
116 }
117
118 fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor> {
119 common::linear_3d(x, weight)
120 }
121
122 fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
123 common::rms_norm(x, weight, eps)
124 }
125
126 fn layer_norm(
127 &self,
128 x: &Tensor,
129 weight: &Tensor,
130 bias: Option<&Tensor>,
131 eps: f32,
132 ) -> Result<Tensor> {
133 common::layer_norm(x, weight, bias, eps)
134 }
135
136 fn silu(&self, x: &Tensor) -> Result<Tensor> {
137 common::silu(x)
138 }
139
140 fn gelu(&self, x: &Tensor) -> Result<Tensor> {
141 common::gelu(x)
142 }
143
144 fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
145 common::add(a, b)
146 }
147
148 fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
149 common::mul(a, b)
150 }
151
152 fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
153 common::apply_rope_positions(x, positions, base)
154 }
155
156 fn gqa_attention(
157 &self,
158 q: &Tensor,
159 k: &Tensor,
160 v: &Tensor,
161 n_kv_heads: usize,
162 causal: bool,
163 ) -> Result<Tensor> {
164 common::gqa_attention(q, k, v, n_kv_heads, causal)
165 }
166
167 fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
168 common::logits_from_hidden(hidden, lm_head)
169 }
170}
171
172#[derive(Debug, Default, Clone)]
173pub struct MetalLlmBackend {
174 cpu: CpuLlmBackend,
175 mlx: MlxLlmOps,
176}
177
178impl MetalLlmBackend {
179 pub fn is_available() -> bool {
180 MlxLlmOps::is_available()
181 }
182}
183
184impl LlmBackend for MetalLlmBackend {
185 fn name(&self) -> &'static str {
186 "metal"
187 }
188
189 fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor> {
190 self.mlx
191 .linear_3d(x, weight)
192 .or_else(|e| {
193 tracing::warn!(op = "linear_3d", error = %e, "MLX op failed; using CPU reference kernel");
194 self.cpu.linear_3d(x, weight)
195 })
196 }
197
198 fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
199 self.mlx.rms_norm(x, weight, eps).or_else(|e| {
200 tracing::warn!(op = "rms_norm", error = %e, "MLX op failed; using CPU reference kernel");
201 self.cpu.rms_norm(x, weight, eps)
202 })
203 }
204
205 fn layer_norm(
206 &self,
207 x: &Tensor,
208 weight: &Tensor,
209 bias: Option<&Tensor>,
210 eps: f32,
211 ) -> Result<Tensor> {
212 self.mlx.layer_norm(x, weight, bias, eps).or_else(|e| {
213 tracing::warn!(op = "layer_norm", error = %e, "MLX op failed; using CPU reference kernel");
214 self.cpu.layer_norm(x, weight, bias, eps)
215 })
216 }
217
218 fn silu(&self, x: &Tensor) -> Result<Tensor> {
219 self.mlx.silu(x).or_else(|e| {
220 tracing::warn!(op = "silu", error = %e, "MLX op failed; using CPU reference kernel");
221 self.cpu.silu(x)
222 })
223 }
224
225 fn gelu(&self, x: &Tensor) -> Result<Tensor> {
226 self.mlx.gelu(x).or_else(|e| {
227 tracing::warn!(op = "gelu", error = %e, "MLX op failed; using CPU reference kernel");
228 self.cpu.gelu(x)
229 })
230 }
231
232 fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
233 self.mlx.add(a, b).or_else(|e| {
234 tracing::warn!(op = "add", error = %e, "MLX op failed; using CPU reference kernel");
235 self.cpu.add(a, b)
236 })
237 }
238
239 fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
240 self.mlx.mul(a, b).or_else(|e| {
241 tracing::warn!(op = "mul", error = %e, "MLX op failed; using CPU reference kernel");
242 self.cpu.mul(a, b)
243 })
244 }
245
246 fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
247 self.mlx
248 .apply_rope_positions(x, positions, base)
249 .or_else(|e| {
250 tracing::warn!(op = "rope", error = %e, "MLX op failed; using CPU reference kernel");
251 self.cpu.apply_rope_positions(x, positions, base)
252 })
253 }
254
255 fn gqa_attention(
256 &self,
257 q: &Tensor,
258 k: &Tensor,
259 v: &Tensor,
260 n_kv_heads: usize,
261 causal: bool,
262 ) -> Result<Tensor> {
263 self.mlx
267 .gqa_attention(q, k, v, n_kv_heads, causal)
268 .or_else(|e| {
269 tracing::warn!(op = "gqa_attention", error = %e,
270 "MLX attention failed; falling back to CPU");
271 self.cpu.gqa_attention(q, k, v, n_kv_heads, causal)
272 })
273 }
274
275 fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
276 self.mlx.logits_from_hidden(hidden, lm_head).or_else(|e| {
277 tracing::warn!(op = "logits", error = %e, "MLX op failed; using CPU reference kernel");
278 self.cpu.logits_from_hidden(hidden, lm_head)
279 })
280 }
281}
282
/// Host tensors already converted to MLX arrays, keyed by the address of the tensor's
/// backing bytes (see `MlxLlmOps::to_array`).
///
/// The `Arc` makes clones of `MlxLlmOps` share one cache; the mutex serialises access.
#[cfg(feature = "mlx")]
type MlxWeightCache =
    ::std::sync::Arc<parking_lot::Mutex<::std::collections::HashMap<usize, mlx_rs::Array>>>;
291
/// MLX-backed kernels. Every op returns an error when MLX is not compiled in.
///
/// With the `mlx` feature it also owns a shared cache of converted weight arrays.
#[derive(Clone)]
struct MlxLlmOps {
    /// Converted-array cache shared between clones.
    #[cfg(feature = "mlx")]
    cache: MlxWeightCache,
}
299
300impl std::fmt::Debug for MlxLlmOps {
301 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 f.debug_struct("MlxLlmOps").finish()
303 }
304}
305
306#[allow(clippy::derivable_impls)]
309impl Default for MlxLlmOps {
310 fn default() -> Self {
311 Self {
312 #[cfg(feature = "mlx")]
313 cache: std::sync::Arc::new(parking_lot::Mutex::new(std::collections::HashMap::new())),
314 }
315 }
316}
317
#[cfg(target_os = "macos")]
impl MlxLlmOps {
    /// Reports whether MLX GPU execution is usable in this build.
    ///
    /// The answer is fixed at compile time by the target architecture and the `mlx` cargo
    /// feature. This impl is compiled only for macOS, so only the macOS cases appear here.
    fn support() -> MacGpuSupport {
        #[cfg(all(target_arch = "aarch64", feature = "mlx"))]
        {
            MacGpuSupport {
                available: true,
                backend: "mlx",
                reason: "Apple Silicon with MLX feature enabled",
            }
        }

        #[cfg(all(target_arch = "aarch64", not(feature = "mlx")))]
        {
            MacGpuSupport {
                available: false,
                backend: "cpu",
                reason: "Apple Silicon detected, but Sapient was built without the sapient-models/mlx feature",
            }
        }

        #[cfg(not(target_arch = "aarch64"))]
        {
            MacGpuSupport {
                available: false,
                backend: "cpu",
                reason: "MLX GPU execution requires Apple Silicon; Intel Macs use CPU",
            }
        }
    }

    /// Shorthand for `support().available`.
    fn is_available() -> bool {
        Self::support().available
    }

    /// Converts tensor dimensions to MLX's `i32` shape, rejecting dimensions that overflow.
    #[cfg(feature = "mlx")]
    fn to_shape(dims: &[usize]) -> Result<Vec<i32>> {
        dims.iter()
            .map(|&d| {
                i32::try_from(d)
                    .map_err(|_| anyhow::anyhow!("shape dimension too large for MLX: {d}"))
            })
            .collect()
    }

    /// Converts a long-lived tensor (typically a weight) to an MLX array, reusing a cached
    /// conversion when possible.
    ///
    /// The cache key is the address of the tensor's backing bytes. A cached entry is reused
    /// only when its shape still matches `tensor`. This guards against a freed buffer whose
    /// address was reused by a differently-shaped tensor, and against views sharing a base
    /// pointer.
    ///
    /// NOTE(review): a same-shaped tensor reallocated at a reused address would still hit a
    /// stale entry, and the cache never evicts. This is only safe if cached tensors (weights)
    /// outlive this `MlxLlmOps` — confirm with the model loader.
    ///
    /// Tensors with 256 or fewer elements are converted each call and not cached.
    #[cfg(feature = "mlx")]
    fn to_array(&self, tensor: &Tensor) -> Result<mlx_rs::Array> {
        let ptr_key = tensor.as_bytes().as_ptr() as usize;
        let shape = Self::to_shape(tensor.shape().dims())?;

        {
            let guard = self.cache.lock();
            if let Some(arr) = guard
                .get(&ptr_key)
                .filter(|arr| arr.shape() == shape.as_slice())
            {
                return Ok(arr.clone());
            }
        }

        // Same conversion as the uncached path, including trimming the data to `numel`.
        let arr = Self::to_array_uncached(tensor)?;
        if tensor.numel() > 256 {
            self.cache.lock().insert(ptr_key, arr.clone());
        }
        Ok(arr)
    }

    /// Converts a tensor (typically a per-call activation) to an MLX array without touching
    /// the cache.
    ///
    /// Only the first `numel` values of the f32 view are used, in case the view is longer
    /// than the logical shape.
    #[cfg(feature = "mlx")]
    fn to_array_uncached(tensor: &Tensor) -> Result<mlx_rs::Array> {
        let shape = Self::to_shape(tensor.shape().dims())?;
        let numel = tensor.numel();
        let cow = tensor.to_f32_cow();
        let data = &cow[..numel.min(cow.len())];
        Ok(mlx_rs::Array::from_slice(data, &shape))
    }

    /// Copies an MLX array back into an f32 `Tensor`, rejecting negative dimensions.
    #[cfg(feature = "mlx")]
    fn to_tensor(array: mlx_rs::Array) -> Result<Tensor> {
        let shape: Vec<usize> = array
            .shape()
            .iter()
            .map(|&d| {
                usize::try_from(d).map_err(|_| anyhow::anyhow!("negative MLX shape dimension: {d}"))
            })
            .collect::<Result<Vec<_>>>()?;
        let data = array.as_slice::<f32>().to_vec();
        Tensor::from_f32(&data, shape).map_err(|e| anyhow::anyhow!("{e}"))
    }

    /// Computes `x · weightᵀ` for `x: [batch, seq, in_dim]` and `weight: [out_dim, in_dim]`,
    /// returning `[batch, seq, out_dim]`.
    fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (x, weight);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let dims = x.shape().dims();
            if dims.len() != 3 {
                anyhow::bail!("linear_3d expects [batch, seq, hidden]");
            }
            let (batch, seq, in_dim) = (dims[0], dims[1], dims[2]);
            let w_dims = weight.shape().dims();
            if w_dims.len() != 2 {
                anyhow::bail!("linear weight must be 2-D");
            }
            let out_dim = w_dims[0];
            if w_dims[1] != in_dim {
                anyhow::bail!("linear weight in_dim mismatch: {} vs {in_dim}", w_dims[1]);
            }

            // Flatten batch and seq into rows so a single 2-D matmul does the projection.
            let x_arr =
                Self::to_array_uncached(x)?.reshape(&Self::to_shape(&[batch * seq, in_dim])?)?;
            let w_arr = self.to_array(weight)?.transpose()?;
            let y = x_arr.matmul(&w_arr)?;
            Self::to_tensor(y.reshape(&Self::to_shape(&[batch, seq, out_dim])?)?)
        }
    }

    /// RMS normalisation via `mlx_rs::fast::rms_norm` (the weight goes through the cache).
    fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (x, weight, eps);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let x = Self::to_array_uncached(x)?;
            let weight = self.to_array(weight)?;
            Self::to_tensor(mlx_rs::fast::rms_norm(&x, &weight, eps)?)
        }
    }

    /// Layer normalisation via `mlx_rs::fast::layer_norm` with an optional bias.
    fn layer_norm(
        &self,
        x: &Tensor,
        weight: &Tensor,
        bias: Option<&Tensor>,
        eps: f32,
    ) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (x, weight, bias, eps);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let x = Self::to_array_uncached(x)?;
            let weight = self.to_array(weight)?;
            let bias = bias.map(|b| self.to_array(b)).transpose()?;
            Self::to_tensor(mlx_rs::fast::layer_norm(
                &x,
                Some(&weight),
                bias.as_ref(),
                eps,
            )?)
        }
    }

    /// SiLU via `mlx_rs::nn::silu`.
    fn silu(&self, x: &Tensor) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = x;
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let x = Self::to_array_uncached(x)?;
            Self::to_tensor(mlx_rs::nn::silu(&x)?)
        }
    }

    /// GELU via `mlx_rs::nn::gelu`.
    fn gelu(&self, x: &Tensor) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = x;
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let x = Self::to_array_uncached(x)?;
            Self::to_tensor(mlx_rs::nn::gelu(&x)?)
        }
    }

    /// `a + b` on MLX (neither operand is cached).
    fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (a, b);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let a = Self::to_array_uncached(a)?;
            let b = Self::to_array_uncached(b)?;
            Self::to_tensor(a.add(&b)?)
        }
    }

    /// `a * b` on MLX (neither operand is cached).
    fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (a, b);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let a = Self::to_array_uncached(a)?;
            let b = Self::to_array_uncached(b)?;
            Self::to_tensor(a.multiply(&b)?)
        }
    }

    /// Non-traditional RoPE via `mlx_rs::fast::rope` over the full `head_dim`.
    ///
    /// MLX only supports a starting offset, so `positions` must hold exactly one
    /// consecutive position per sequence step. Any other layout returns an error so the
    /// caller can use the CPU reference kernel instead.
    fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (x, positions, base);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let dims = x.shape().dims();
            if dims.len() != 4 {
                anyhow::bail!("RoPE expects [batch, heads, seq, head_dim]");
            }
            if positions.is_empty() {
                anyhow::bail!("RoPE positions cannot be empty");
            }
            // MLX derives each step's position from `offset + index`, so a length mismatch
            // would silently rotate with the wrong positions.
            if positions.len() != dims[2] {
                anyhow::bail!(
                    "MLX RoPE expects one position per sequence step (got {} for seq {})",
                    positions.len(),
                    dims[2]
                );
            }
            let first = positions[0];
            let offset = i32::try_from(first)
                .map_err(|_| anyhow::anyhow!("RoPE position too large for MLX"))?;
            // `checked_sub` keeps the contiguity test overflow-free for huge positions.
            let contiguous = positions
                .iter()
                .enumerate()
                .all(|(i, &p)| p.checked_sub(first) == Some(i));
            if !contiguous {
                anyhow::bail!("MLX RoPE requires contiguous positions");
            }
            let head_dim = i32::try_from(dims[3])
                .map_err(|_| anyhow::anyhow!("RoPE head_dim too large for MLX: {}", dims[3]))?;

            let x = Self::to_array_uncached(x)?;
            Self::to_tensor(mlx_rs::fast::rope(
                &x,
                head_dim,
                false,
                Some(base),
                1.0,
                offset,
                None::<&mlx_rs::Array>,
            )?)
        }
    }

    /// Logits for the last sequence position: `hidden[0, seq - 1, :] · lm_headᵀ`.
    ///
    /// Input is expected as `[batch, seq, hidden]` with `seq > 0`. Other input returns an
    /// error (rather than panicking) so the CPU reference kernel can take over.
    /// NOTE(review): for `batch > 1` only batch row 0 is read.
    fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (hidden, lm_head);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let dims = hidden.shape().dims();
            if dims.len() != 3 {
                anyhow::bail!("logits_from_hidden expects [batch, seq, hidden]");
            }
            let (seq, hidden_size) = (dims[1], dims[2]);
            if seq == 0 {
                anyhow::bail!("logits_from_hidden requires at least one sequence position");
            }
            let h = hidden.to_f32_cow();
            let start = (seq - 1) * hidden_size;
            let last = h
                .get(start..start + hidden_size)
                .ok_or_else(|| anyhow::anyhow!("hidden tensor data is shorter than its shape"))?;
            let h_last = mlx_rs::Array::from_slice(last, &Self::to_shape(&[1, hidden_size])?);
            let head = self.to_array(lm_head)?.transpose()?;
            let logits = h_last.matmul(&head)?;
            Ok(logits.as_slice::<f32>().to_vec())
        }
    }

    /// Scaled dot-product attention via `mlx_rs::fast::scaled_dot_product_attention`.
    ///
    /// Expects rank-4 `[batch, heads, seq, head_dim]` q/k/v and scales scores by
    /// `1/sqrt(head_dim)`. Grouped-query layouts (`n_heads != n_kv_heads`) are refused with
    /// an error, so the caller uses the CPU kernel for them. With `causal` and more than one
    /// query, an additive mask lets query `qi` see keys up to `qi + (seq_k - seq_q)`.
    fn gqa_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        n_kv_heads: usize,
        causal: bool,
    ) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (q, k, v, n_kv_heads, causal);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let qs = q.shape().dims().to_vec();
            let ks = k.shape().dims().to_vec();
            // Validate rank up front: indexing a shorter shape would panic, and a panic
            // (unlike an error) bypasses the CPU fallback.
            if qs.len() != 4 || ks.len() != 4 || v.shape().dims().len() != 4 {
                anyhow::bail!("MLX attention expects rank-4 [batch, heads, seq, head_dim] q/k/v");
            }
            let (n_heads, seq_q, head_dim) = (qs[1], qs[2], qs[3]);
            let seq_k = ks[2];
            let scale = 1.0 / (head_dim as f32).sqrt();

            if n_heads != n_kv_heads {
                anyhow::bail!(
                    "GQA (n_heads={n_heads} ≠ n_kv_heads={n_kv_heads}): using CPU attention"
                );
            }

            let q_arr = Self::to_array_uncached(q)?;
            let k_arr = Self::to_array_uncached(k)?;
            let v_arr = Self::to_array_uncached(v)?;

            // A single query (decode step) attends to every cached key, so no mask is needed.
            let mask_arr: Option<mlx_rs::Array> = if causal && seq_q > 1 {
                // Keys before `offset` are the cached prefix every query may see.
                let offset = seq_k.saturating_sub(seq_q);
                let mut data = vec![0.0f32; seq_q * seq_k];
                for qi in 0..seq_q {
                    for ki in 0..seq_k {
                        if ki > qi + offset {
                            data[qi * seq_k + ki] = f32::NEG_INFINITY;
                        }
                    }
                }
                Some(mlx_rs::Array::from_slice(
                    &data,
                    &Self::to_shape(&[seq_q, seq_k])?,
                ))
            } else {
                None
            };

            let mlx_mask = mask_arr
                .as_ref()
                .map(mlx_rs::fast::ScaledDotProductAttentionMask::from);
            let out_arr = mlx_rs::fast::scaled_dot_product_attention(
                &q_arr, &k_arr, &v_arr, scale, mlx_mask,
            )?;

            Self::to_tensor(out_arr)
        }
    }
}
714
715#[cfg(not(target_os = "macos"))]
716impl MlxLlmOps {
717 fn support() -> MacGpuSupport {
718 MacGpuSupport {
719 available: false,
720 backend: "cpu",
721 reason: "MLX GPU execution is only available on macOS",
722 }
723 }
724
725 fn is_available() -> bool {
726 false
727 }
728
729 fn linear_3d(&self, _x: &Tensor, _weight: &Tensor) -> Result<Tensor> {
730 anyhow::bail!("MLX is only available on macOS")
731 }
732
733 fn rms_norm(&self, _x: &Tensor, _weight: &Tensor, _eps: f32) -> Result<Tensor> {
734 anyhow::bail!("MLX is only available on macOS")
735 }
736
737 fn layer_norm(
738 &self,
739 _x: &Tensor,
740 _weight: &Tensor,
741 _bias: Option<&Tensor>,
742 _eps: f32,
743 ) -> Result<Tensor> {
744 anyhow::bail!("MLX is only available on macOS")
745 }
746
747 fn silu(&self, _x: &Tensor) -> Result<Tensor> {
748 anyhow::bail!("MLX is only available on macOS")
749 }
750
751 fn gelu(&self, _x: &Tensor) -> Result<Tensor> {
752 anyhow::bail!("MLX is only available on macOS")
753 }
754
755 fn add(&self, _a: &Tensor, _b: &Tensor) -> Result<Tensor> {
756 anyhow::bail!("MLX is only available on macOS")
757 }
758
759 fn mul(&self, _a: &Tensor, _b: &Tensor) -> Result<Tensor> {
760 anyhow::bail!("MLX is only available on macOS")
761 }
762
763 fn apply_rope_positions(
764 &self,
765 _x: &Tensor,
766 _positions: &[usize],
767 _base: f32,
768 ) -> Result<Tensor> {
769 anyhow::bail!("MLX is only available on macOS")
770 }
771
772 fn gqa_attention(
773 &self,
774 _q: &Tensor,
775 _k: &Tensor,
776 _v: &Tensor,
777 _n_kv_heads: usize,
778 _causal: bool,
779 ) -> Result<Tensor> {
780 anyhow::bail!("MLX is only available on macOS")
781 }
782
783 fn logits_from_hidden(&self, _hidden: &Tensor, _lm_head: &Tensor) -> Result<Vec<f32>> {
784 anyhow::bail!("MLX is only available on macOS")
785 }
786}
787
788#[derive(Debug, Clone)]
789pub enum LlmBackendDispatch {
790 Cpu(CpuLlmBackend),
791 Metal(MetalLlmBackend),
792}
793
794impl LlmBackendDispatch {
795 pub fn from_kind(kind: LlmBackendKind) -> Result<Self> {
796 Self::from_kind_with_model_bytes(kind, 0)
797 }
798
799 pub fn from_kind_with_model_bytes(kind: LlmBackendKind, model_bytes: u64) -> Result<Self> {
806 match kind {
807 LlmBackendKind::Cpu => Ok(Self::Cpu(CpuLlmBackend)),
808 LlmBackendKind::Auto if MetalLlmBackend::is_available() => {
809 let fits = metal_memory_fits(model_bytes);
810 if fits {
811 tracing::debug!(
812 model_bytes,
813 "auto-backend: Metal (model fits in unified memory)"
814 );
815 Ok(Self::Metal(MetalLlmBackend::default()))
816 } else {
817 tracing::info!(
818 model_bytes,
819 "auto-backend: CPU (model too large for Metal GPU memory headroom — \
820 use --backend metal to force GPU anyway)"
821 );
822 Ok(Self::Cpu(CpuLlmBackend))
823 }
824 }
825 LlmBackendKind::Auto => Ok(Self::Cpu(CpuLlmBackend)),
826 LlmBackendKind::Metal if MetalLlmBackend::is_available() => {
827 Ok(Self::Metal(MetalLlmBackend::default()))
828 }
829 LlmBackendKind::Metal => {
830 let support = mac_gpu_support();
831 anyhow::bail!(
832 "Metal/MLX generation backend is unavailable: {}",
833 support.reason
834 )
835 }
836 }
837 }
838}
839
840fn metal_memory_fits(model_bytes: u64) -> bool {
844 if model_bytes == 0 {
845 return true;
846 }
847 let total_ram = total_system_ram_bytes();
848 let usable = total_ram.saturating_sub(2 * 1024 * 1024 * 1024);
850 model_bytes as f64 * 1.5 <= usable as f64
851}
852
/// Total physical RAM in bytes, or `0` when it cannot be determined.
///
/// On macOS this runs `sysctl -n hw.memsize` and parses its output. A failed spawn,
/// non-UTF-8 output, or an unparsable number all yield `0`. On every other platform it
/// returns `0`.
pub fn total_system_ram_bytes() -> u64 {
    #[cfg(target_os = "macos")]
    {
        let probed = std::process::Command::new("sysctl")
            .args(["-n", "hw.memsize"])
            .output()
            .ok()
            .and_then(|out| String::from_utf8(out.stdout).ok())
            .and_then(|text| text.trim().parse::<u64>().ok());
        if let Some(bytes) = probed {
            return bytes;
        }
    }
    0
}
872
873impl LlmBackend for LlmBackendDispatch {
874 fn name(&self) -> &'static str {
875 match self {
876 Self::Cpu(b) => b.name(),
877 Self::Metal(b) => b.name(),
878 }
879 }
880
881 fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor> {
882 match self {
883 Self::Cpu(b) => b.linear_3d(x, weight),
884 Self::Metal(b) => b.linear_3d(x, weight),
885 }
886 }
887
888 fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
889 match self {
890 Self::Cpu(b) => b.rms_norm(x, weight, eps),
891 Self::Metal(b) => b.rms_norm(x, weight, eps),
892 }
893 }
894
895 fn layer_norm(
896 &self,
897 x: &Tensor,
898 weight: &Tensor,
899 bias: Option<&Tensor>,
900 eps: f32,
901 ) -> Result<Tensor> {
902 match self {
903 Self::Cpu(b) => b.layer_norm(x, weight, bias, eps),
904 Self::Metal(b) => b.layer_norm(x, weight, bias, eps),
905 }
906 }
907
908 fn silu(&self, x: &Tensor) -> Result<Tensor> {
909 match self {
910 Self::Cpu(b) => b.silu(x),
911 Self::Metal(b) => b.silu(x),
912 }
913 }
914
915 fn gelu(&self, x: &Tensor) -> Result<Tensor> {
916 match self {
917 Self::Cpu(b) => b.gelu(x),
918 Self::Metal(b) => b.gelu(x),
919 }
920 }
921
922 fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
923 match self {
924 Self::Cpu(backend) => backend.add(a, b),
925 Self::Metal(backend) => backend.add(a, b),
926 }
927 }
928
929 fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
930 match self {
931 Self::Cpu(backend) => backend.mul(a, b),
932 Self::Metal(backend) => backend.mul(a, b),
933 }
934 }
935
936 fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
937 match self {
938 Self::Cpu(b) => b.apply_rope_positions(x, positions, base),
939 Self::Metal(b) => b.apply_rope_positions(x, positions, base),
940 }
941 }
942
943 fn gqa_attention(
944 &self,
945 q: &Tensor,
946 k: &Tensor,
947 v: &Tensor,
948 n_kv_heads: usize,
949 causal: bool,
950 ) -> Result<Tensor> {
951 match self {
952 Self::Cpu(b) => b.gqa_attention(q, k, v, n_kv_heads, causal),
953 Self::Metal(b) => b.gqa_attention(q, k, v, n_kv_heads, causal),
954 }
955 }
956
957 fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
958 match self {
959 Self::Cpu(b) => b.logits_from_hidden(hidden, lm_head),
960 Self::Metal(b) => b.logits_from_hidden(hidden, lm_head),
961 }
962 }
963
964 fn all_logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<Vec<f32>>> {
965 match self {
966 Self::Cpu(b) => b.all_logits_from_hidden(hidden, lm_head),
967 Self::Metal(b) => b.all_logits_from_hidden(hidden, lm_head),
968 }
969 }
970}