Skip to main content

sapient_models/forward/
backend.rs

1//! Backend dispatch for native LLM forward passes.
2
3use anyhow::Result;
4use sapient_core::Tensor;
5
6use super::common;
7
/// Which kernel family the native LLM forward pass should run on.
///
/// Parsed case-insensitively from `"cpu"`, `"metal"` or `"auto"` and displayed
/// back in lowercase.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LlmBackendKind {
    /// Always use the CPU reference kernels.
    Cpu,
    /// Require the Metal/MLX backend; selecting it fails when MLX is unavailable.
    Metal,
    /// Prefer Metal/MLX when available, otherwise use the CPU. This is the default.
    #[default]
    Auto,
}
15
/// Capability report describing whether MLX GPU execution can be used.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MacGpuSupport {
    /// `true` when the MLX-backed Metal backend can be selected.
    pub available: bool,
    /// Backend that will actually run (`"mlx"` or `"cpu"` in this module).
    pub backend: &'static str,
    /// Human-readable explanation of the decision, suitable for error messages.
    pub reason: &'static str,
}
22
23pub fn mac_gpu_support() -> MacGpuSupport {
24    MlxLlmOps::support()
25}
26
27impl std::str::FromStr for LlmBackendKind {
28    type Err = anyhow::Error;
29
30    fn from_str(value: &str) -> Result<Self> {
31        match value.to_ascii_lowercase().as_str() {
32            "cpu" => Ok(Self::Cpu),
33            "metal" => Ok(Self::Metal),
34            "auto" => Ok(Self::Auto),
35            other => anyhow::bail!("unsupported generation backend '{other}'"),
36        }
37    }
38}
39
40impl std::fmt::Display for LlmBackendKind {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        match self {
43            Self::Cpu => write!(f, "cpu"),
44            Self::Metal => write!(f, "metal"),
45            Self::Auto => write!(f, "auto"),
46        }
47    }
48}
49
50pub trait LlmBackend: Send + Sync {
51    fn name(&self) -> &'static str;
52    fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor>;
53
54    /// Linear projection with an optional bias added over the last dimension.
55    /// Backend-agnostic: computes `linear_3d` then folds in the bias on the host,
56    /// so every backend gets correct bias handling for free.
57    fn linear_3d_bias(&self, x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Result<Tensor> {
58        let y = self.linear_3d(x, weight)?;
59        match bias {
60            None => Ok(y),
61            Some(b) => common::add_bias_last_dim(&y, b),
62        }
63    }
64
65    fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor>;
66    fn layer_norm(
67        &self,
68        x: &Tensor,
69        weight: &Tensor,
70        bias: Option<&Tensor>,
71        eps: f32,
72    ) -> Result<Tensor>;
73    fn silu(&self, x: &Tensor) -> Result<Tensor>;
74    fn gelu(&self, x: &Tensor) -> Result<Tensor>;
75    fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor>;
76    fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor>;
77    fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor>;
78
79    /// RoPE over only the first `rotary_dim` channels (Phi partial rotary).
80    /// Computed on the CPU reference kernel for all backends — it is cheap and
81    /// avoids backend-specific partial-rotary support.
82    fn apply_rope_partial(
83        &self,
84        x: &Tensor,
85        positions: &[usize],
86        base: f32,
87        rotary_dim: usize,
88    ) -> Result<Tensor> {
89        common::apply_rope_partial(x, positions, base, rotary_dim)
90    }
91
92    fn gqa_attention(
93        &self,
94        q: &Tensor,
95        k: &Tensor,
96        v: &Tensor,
97        n_kv_heads: usize,
98        causal: bool,
99    ) -> Result<Tensor>;
100    fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>>;
101}
102
/// Backend that runs every op on the portable CPU reference kernels.
#[derive(Clone, Debug, Default)]
pub struct CpuLlmBackend;
105
106impl LlmBackend for CpuLlmBackend {
107    fn name(&self) -> &'static str {
108        "cpu"
109    }
110
111    fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor> {
112        common::linear_3d(x, weight)
113    }
114
115    fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
116        common::rms_norm(x, weight, eps)
117    }
118
119    fn layer_norm(
120        &self,
121        x: &Tensor,
122        weight: &Tensor,
123        bias: Option<&Tensor>,
124        eps: f32,
125    ) -> Result<Tensor> {
126        common::layer_norm(x, weight, bias, eps)
127    }
128
129    fn silu(&self, x: &Tensor) -> Result<Tensor> {
130        common::silu(x)
131    }
132
133    fn gelu(&self, x: &Tensor) -> Result<Tensor> {
134        common::gelu(x)
135    }
136
137    fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
138        common::add(a, b)
139    }
140
141    fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
142        common::mul(a, b)
143    }
144
145    fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
146        common::apply_rope_positions(x, positions, base)
147    }
148
149    fn gqa_attention(
150        &self,
151        q: &Tensor,
152        k: &Tensor,
153        v: &Tensor,
154        n_kv_heads: usize,
155        causal: bool,
156    ) -> Result<Tensor> {
157        common::gqa_attention(q, k, v, n_kv_heads, causal)
158    }
159
160    fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
161        common::logits_from_hidden(hidden, lm_head)
162    }
163}
164
165#[derive(Debug, Default, Clone)]
166pub struct MetalLlmBackend {
167    cpu: CpuLlmBackend,
168    mlx: MlxLlmOps,
169}
170
171impl MetalLlmBackend {
172    pub fn is_available() -> bool {
173        MlxLlmOps::is_available()
174    }
175
176    fn fallback(&self, op: &str) {
177        tracing::warn!(
178            op = op,
179            "native Metal LLM kernel is not implemented yet; using CPU reference kernel"
180        );
181    }
182}
183
184impl LlmBackend for MetalLlmBackend {
185    fn name(&self) -> &'static str {
186        "metal"
187    }
188
189    fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor> {
190        self.mlx
191            .linear_3d(x, weight)
192            .or_else(|e| {
193                tracing::warn!(op = "linear_3d", error = %e, "MLX op failed; using CPU reference kernel");
194                self.cpu.linear_3d(x, weight)
195            })
196    }
197
198    fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
199        self.mlx.rms_norm(x, weight, eps).or_else(|e| {
200            tracing::warn!(op = "rms_norm", error = %e, "MLX op failed; using CPU reference kernel");
201            self.cpu.rms_norm(x, weight, eps)
202        })
203    }
204
205    fn layer_norm(
206        &self,
207        x: &Tensor,
208        weight: &Tensor,
209        bias: Option<&Tensor>,
210        eps: f32,
211    ) -> Result<Tensor> {
212        self.mlx.layer_norm(x, weight, bias, eps).or_else(|e| {
213            tracing::warn!(op = "layer_norm", error = %e, "MLX op failed; using CPU reference kernel");
214            self.cpu.layer_norm(x, weight, bias, eps)
215        })
216    }
217
218    fn silu(&self, x: &Tensor) -> Result<Tensor> {
219        self.mlx.silu(x).or_else(|e| {
220            tracing::warn!(op = "silu", error = %e, "MLX op failed; using CPU reference kernel");
221            self.cpu.silu(x)
222        })
223    }
224
225    fn gelu(&self, x: &Tensor) -> Result<Tensor> {
226        self.mlx.gelu(x).or_else(|e| {
227            tracing::warn!(op = "gelu", error = %e, "MLX op failed; using CPU reference kernel");
228            self.cpu.gelu(x)
229        })
230    }
231
232    fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
233        self.mlx.add(a, b).or_else(|e| {
234            tracing::warn!(op = "add", error = %e, "MLX op failed; using CPU reference kernel");
235            self.cpu.add(a, b)
236        })
237    }
238
239    fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
240        self.mlx.mul(a, b).or_else(|e| {
241            tracing::warn!(op = "mul", error = %e, "MLX op failed; using CPU reference kernel");
242            self.cpu.mul(a, b)
243        })
244    }
245
246    fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
247        self.mlx
248            .apply_rope_positions(x, positions, base)
249            .or_else(|e| {
250                tracing::warn!(op = "rope", error = %e, "MLX op failed; using CPU reference kernel");
251                self.cpu.apply_rope_positions(x, positions, base)
252            })
253    }
254
255    fn gqa_attention(
256        &self,
257        q: &Tensor,
258        k: &Tensor,
259        v: &Tensor,
260        n_kv_heads: usize,
261        causal: bool,
262    ) -> Result<Tensor> {
263        // Keep this on the CPU reference path for now. Sapient's CPU attention
264        // mask handles cached decoding with q_len < kv_len; using MLX's
265        // causal shortcut without an offset would corrupt generation.
266        self.fallback("gqa_attention");
267        self.cpu.gqa_attention(q, k, v, n_kv_heads, causal)
268    }
269
270    fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
271        self.mlx.logits_from_hidden(hidden, lm_head).or_else(|e| {
272            tracing::warn!(op = "logits", error = %e, "MLX op failed; using CPU reference kernel");
273            self.cpu.logits_from_hidden(hidden, lm_head)
274        })
275    }
276}
277
/// Stateless handle for the MLX kernel wrappers. Its methods are
/// platform- and feature-gated in the impl blocks.
#[derive(Clone, Debug, Default)]
struct MlxLlmOps;
280
#[cfg(target_os = "macos")]
impl MlxLlmOps {
    /// Reports whether MLX GPU execution is usable in this build, and why.
    ///
    /// The answer is fixed at compile time. MLX needs Apple Silicon
    /// (`aarch64`) and the `mlx` cargo feature. This impl only exists on
    /// macOS, so only the architecture/feature combinations need to be told apart.
    fn support() -> MacGpuSupport {
        #[cfg(all(target_arch = "aarch64", feature = "mlx"))]
        {
            MacGpuSupport {
                available: true,
                backend: "mlx",
                reason: "Apple Silicon with MLX feature enabled",
            }
        }

        #[cfg(all(target_arch = "aarch64", not(feature = "mlx")))]
        {
            MacGpuSupport {
                available: false,
                backend: "cpu",
                reason: "Apple Silicon detected, but Sapient was built without the sapient-models/mlx feature",
            }
        }

        #[cfg(not(target_arch = "aarch64"))]
        {
            MacGpuSupport {
                available: false,
                backend: "cpu",
                reason: "MLX GPU execution requires Apple Silicon; Intel Macs use CPU",
            }
        }
    }

    /// `true` when `support()` reports MLX as usable.
    fn is_available() -> bool {
        Self::support().available
    }

    /// Converts tensor dimensions into MLX's `i32` shape vector. Dimensions
    /// that do not fit are rejected rather than truncated.
    #[cfg(feature = "mlx")]
    fn to_shape(dims: &[usize]) -> Result<Vec<i32>> {
        dims.iter()
            .map(|&d| {
                i32::try_from(d)
                    .map_err(|_| anyhow::anyhow!("shape dimension too large for MLX: {d}"))
            })
            .collect()
    }

    /// Copies a tensor into a new MLX `f32` array with the same shape.
    #[cfg(feature = "mlx")]
    fn to_array(tensor: &Tensor) -> Result<mlx_rs::Array> {
        let shape = Self::to_shape(tensor.shape().dims())?;
        // Weights are commonly F16/BF16; convert to F32 (MLX array dtype here)
        // instead of asserting F32, which would panic on half-precision tensors.
        let data = tensor.to_f32_cow();
        Ok(mlx_rs::Array::from_slice(data.as_ref(), &shape))
    }

    /// Copies an MLX `f32` array back into a host `Tensor`.
    #[cfg(feature = "mlx")]
    fn to_tensor(array: mlx_rs::Array) -> Result<Tensor> {
        let shape: Vec<usize> = array
            .shape()
            .iter()
            .map(|&d| {
                usize::try_from(d).map_err(|_| anyhow::anyhow!("negative MLX shape dimension: {d}"))
            })
            .collect::<Result<Vec<_>>>()?;
        let data = array.as_slice::<f32>().to_vec();
        Tensor::from_f32(&data, shape).map_err(|e| anyhow::anyhow!("{e}"))
    }

    /// `x @ weight^T` for `x: [batch, seq, in_dim]` and `weight: [out_dim, in_dim]`.
    fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (x, weight);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let dims = x.shape().dims();
            if dims.len() != 3 {
                anyhow::bail!("linear_3d expects [batch, seq, hidden]");
            }
            let (batch, seq, in_dim) = (dims[0], dims[1], dims[2]);
            let w_dims = weight.shape().dims();
            if w_dims.len() != 2 {
                anyhow::bail!("linear weight must be 2-D");
            }
            let out_dim = w_dims[0];
            if w_dims[1] != in_dim {
                anyhow::bail!("linear weight in_dim mismatch: {} vs {in_dim}", w_dims[1]);
            }

            let x_arr = Self::to_array(x)?.reshape(&Self::to_shape(&[batch * seq, in_dim])?)?;
            let w_arr = Self::to_array(weight)?.transpose()?;
            let y = x_arr.matmul(&w_arr)?;
            Self::to_tensor(y.reshape(&Self::to_shape(&[batch, seq, out_dim])?)?)
        }
    }

    /// RMS norm via `mlx_rs::fast::rms_norm`.
    fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (x, weight, eps);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let x = Self::to_array(x)?;
            let weight = Self::to_array(weight)?;
            Self::to_tensor(mlx_rs::fast::rms_norm(&x, &weight, eps)?)
        }
    }

    /// Layer norm via `mlx_rs::fast::layer_norm`, with an optional bias.
    fn layer_norm(
        &self,
        x: &Tensor,
        weight: &Tensor,
        bias: Option<&Tensor>,
        eps: f32,
    ) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (x, weight, bias, eps);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let x = Self::to_array(x)?;
            let weight = Self::to_array(weight)?;
            let bias = bias.map(Self::to_array).transpose()?;
            Self::to_tensor(mlx_rs::fast::layer_norm(
                &x,
                Some(&weight),
                bias.as_ref(),
                eps,
            )?)
        }
    }

    /// SiLU via `mlx_rs::nn::silu`.
    fn silu(&self, x: &Tensor) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = x;
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let x = Self::to_array(x)?;
            Self::to_tensor(mlx_rs::nn::silu(&x)?)
        }
    }

    /// GELU via `mlx_rs::nn::gelu`.
    fn gelu(&self, x: &Tensor) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = x;
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let x = Self::to_array(x)?;
            Self::to_tensor(mlx_rs::nn::gelu(&x)?)
        }
    }

    /// Elementwise `a + b` on MLX.
    fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (a, b);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let a = Self::to_array(a)?;
            let b = Self::to_array(b)?;
            Self::to_tensor(a.add(&b)?)
        }
    }

    /// Elementwise `a * b` on MLX.
    fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (a, b);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let a = Self::to_array(a)?;
            let b = Self::to_array(b)?;
            Self::to_tensor(a.multiply(&b)?)
        }
    }

    /// Rotate-half RoPE on `x: [batch, heads, seq, head_dim]`.
    ///
    /// MLX derives every position from a single starting offset. Any input it
    /// cannot express exactly therefore returns `Err`, so the caller can fall
    /// back to the CPU reference kernel: non-contiguous positions, or a
    /// position count that differs from the sequence length.
    fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (x, positions, base);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let dims = x.shape().dims();
            if dims.len() != 4 {
                anyhow::bail!("RoPE expects [batch, heads, seq, head_dim]");
            }
            if positions.is_empty() {
                anyhow::bail!("RoPE positions cannot be empty");
            }
            // MLX would silently rotate by offset..offset+seq whatever the
            // slice length is, so a mismatch must not reach it.
            if positions.len() != dims[2] {
                anyhow::bail!(
                    "MLX RoPE expects one position per sequence step ({} positions for seq len {})",
                    positions.len(),
                    dims[2]
                );
            }
            let offset = i32::try_from(positions[0])
                .map_err(|_| anyhow::anyhow!("RoPE position too large for MLX"))?;
            let contiguous = positions
                .iter()
                .enumerate()
                .all(|(i, &p)| p == positions[0] + i);
            if !contiguous {
                anyhow::bail!("MLX RoPE requires contiguous positions");
            }
            let head_dim = i32::try_from(dims[3])
                .map_err(|_| anyhow::anyhow!("RoPE head_dim too large for MLX"))?;

            let x = Self::to_array(x)?;
            Self::to_tensor(mlx_rs::fast::rope(
                &x,
                head_dim,
                // `traditional = false` → rotate-half (NeoX/HF) convention, which
                // matches how Llama/Qwen/Phi weights are trained and what the CPU
                // kernel does. `true` (interleaved/GPT-J) produces garbage here.
                false,
                Some(base),
                1.0,
                offset,
                None::<&mlx_rs::Array>,
            )?)
        }
    }

    /// Logits for the final sequence position of `hidden` against `lm_head`.
    ///
    /// Takes the last `hidden_size` values of the first `seq` rows of the
    /// flattened buffer. Assuming a row-major `[batch, seq, hidden]` layout,
    /// that is batch 0's last position. Shapes are checked up front, so
    /// malformed input returns `Err` (which the Metal backend turns into a CPU
    /// fallback) instead of panicking on an out-of-range slice or an
    /// underflowing `seq - 1`.
    fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
        #[cfg(not(feature = "mlx"))]
        {
            let _ = (hidden, lm_head);
            anyhow::bail!("{}", Self::support().reason);
        }

        #[cfg(feature = "mlx")]
        {
            let dims = hidden.shape().dims();
            if dims.len() != 3 {
                anyhow::bail!("logits_from_hidden expects [batch, seq, hidden]");
            }
            let (batch, seq, hidden_size) = (dims[0], dims[1], dims[2]);
            if batch == 0 || seq == 0 {
                anyhow::bail!(
                    "logits_from_hidden needs at least one position (got batch={batch}, seq={seq})"
                );
            }
            let head_dims = lm_head.shape().dims();
            if head_dims.len() != 2 || head_dims[1] != hidden_size {
                anyhow::bail!("lm_head must be [vocab, {hidden_size}], got {head_dims:?}");
            }

            let h = hidden.to_f32_cow();
            let last = &h[(seq - 1) * hidden_size..seq * hidden_size];
            let h_last = mlx_rs::Array::from_slice(last, &Self::to_shape(&[1, hidden_size])?);
            let head = Self::to_array(lm_head)?.transpose()?;
            let logits = h_last.matmul(&head)?;
            Ok(logits.as_slice::<f32>().to_vec())
        }
    }
}
550
551#[cfg(not(target_os = "macos"))]
552impl MlxLlmOps {
553    fn support() -> MacGpuSupport {
554        MacGpuSupport {
555            available: false,
556            backend: "cpu",
557            reason: "MLX GPU execution is only available on macOS",
558        }
559    }
560
561    fn is_available() -> bool {
562        false
563    }
564
565    fn linear_3d(&self, _x: &Tensor, _weight: &Tensor) -> Result<Tensor> {
566        anyhow::bail!("MLX is only available on macOS")
567    }
568
569    fn rms_norm(&self, _x: &Tensor, _weight: &Tensor, _eps: f32) -> Result<Tensor> {
570        anyhow::bail!("MLX is only available on macOS")
571    }
572
573    fn layer_norm(
574        &self,
575        _x: &Tensor,
576        _weight: &Tensor,
577        _bias: Option<&Tensor>,
578        _eps: f32,
579    ) -> Result<Tensor> {
580        anyhow::bail!("MLX is only available on macOS")
581    }
582
583    fn silu(&self, _x: &Tensor) -> Result<Tensor> {
584        anyhow::bail!("MLX is only available on macOS")
585    }
586
587    fn gelu(&self, _x: &Tensor) -> Result<Tensor> {
588        anyhow::bail!("MLX is only available on macOS")
589    }
590
591    fn add(&self, _a: &Tensor, _b: &Tensor) -> Result<Tensor> {
592        anyhow::bail!("MLX is only available on macOS")
593    }
594
595    fn mul(&self, _a: &Tensor, _b: &Tensor) -> Result<Tensor> {
596        anyhow::bail!("MLX is only available on macOS")
597    }
598
599    fn apply_rope_positions(
600        &self,
601        _x: &Tensor,
602        _positions: &[usize],
603        _base: f32,
604    ) -> Result<Tensor> {
605        anyhow::bail!("MLX is only available on macOS")
606    }
607
608    fn logits_from_hidden(&self, _hidden: &Tensor, _lm_head: &Tensor) -> Result<Vec<f32>> {
609        anyhow::bail!("MLX is only available on macOS")
610    }
611}
612
613#[derive(Debug, Clone)]
614pub enum LlmBackendDispatch {
615    Cpu(CpuLlmBackend),
616    Metal(MetalLlmBackend),
617}
618
619impl LlmBackendDispatch {
620    pub fn from_kind(kind: LlmBackendKind) -> Result<Self> {
621        match kind {
622            LlmBackendKind::Cpu => Ok(Self::Cpu(CpuLlmBackend)),
623            LlmBackendKind::Auto if MetalLlmBackend::is_available() => {
624                Ok(Self::Metal(MetalLlmBackend::default()))
625            }
626            LlmBackendKind::Auto => Ok(Self::Cpu(CpuLlmBackend)),
627            LlmBackendKind::Metal if MetalLlmBackend::is_available() => {
628                Ok(Self::Metal(MetalLlmBackend::default()))
629            }
630            LlmBackendKind::Metal => {
631                let support = mac_gpu_support();
632                anyhow::bail!(
633                    "Metal/MLX generation backend is unavailable: {}",
634                    support.reason
635                )
636            }
637        }
638    }
639}
640
641impl LlmBackend for LlmBackendDispatch {
642    fn name(&self) -> &'static str {
643        match self {
644            Self::Cpu(b) => b.name(),
645            Self::Metal(b) => b.name(),
646        }
647    }
648
649    fn linear_3d(&self, x: &Tensor, weight: &Tensor) -> Result<Tensor> {
650        match self {
651            Self::Cpu(b) => b.linear_3d(x, weight),
652            Self::Metal(b) => b.linear_3d(x, weight),
653        }
654    }
655
656    fn rms_norm(&self, x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
657        match self {
658            Self::Cpu(b) => b.rms_norm(x, weight, eps),
659            Self::Metal(b) => b.rms_norm(x, weight, eps),
660        }
661    }
662
663    fn layer_norm(
664        &self,
665        x: &Tensor,
666        weight: &Tensor,
667        bias: Option<&Tensor>,
668        eps: f32,
669    ) -> Result<Tensor> {
670        match self {
671            Self::Cpu(b) => b.layer_norm(x, weight, bias, eps),
672            Self::Metal(b) => b.layer_norm(x, weight, bias, eps),
673        }
674    }
675
676    fn silu(&self, x: &Tensor) -> Result<Tensor> {
677        match self {
678            Self::Cpu(b) => b.silu(x),
679            Self::Metal(b) => b.silu(x),
680        }
681    }
682
683    fn gelu(&self, x: &Tensor) -> Result<Tensor> {
684        match self {
685            Self::Cpu(b) => b.gelu(x),
686            Self::Metal(b) => b.gelu(x),
687        }
688    }
689
690    fn add(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
691        match self {
692            Self::Cpu(backend) => backend.add(a, b),
693            Self::Metal(backend) => backend.add(a, b),
694        }
695    }
696
697    fn mul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
698        match self {
699            Self::Cpu(backend) => backend.mul(a, b),
700            Self::Metal(backend) => backend.mul(a, b),
701        }
702    }
703
704    fn apply_rope_positions(&self, x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
705        match self {
706            Self::Cpu(b) => b.apply_rope_positions(x, positions, base),
707            Self::Metal(b) => b.apply_rope_positions(x, positions, base),
708        }
709    }
710
711    fn gqa_attention(
712        &self,
713        q: &Tensor,
714        k: &Tensor,
715        v: &Tensor,
716        n_kv_heads: usize,
717        causal: bool,
718    ) -> Result<Tensor> {
719        match self {
720            Self::Cpu(b) => b.gqa_attention(q, k, v, n_kv_heads, causal),
721            Self::Metal(b) => b.gqa_attention(q, k, v, n_kv_heads, causal),
722        }
723    }
724
725    fn logits_from_hidden(&self, hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
726        match self {
727            Self::Cpu(b) => b.logits_from_hidden(hidden, lm_head),
728            Self::Metal(b) => b.logits_from_hidden(hidden, lm_head),
729        }
730    }
731}