Skip to main content

sapient_models/forward/
mod.rs

1//! Real transformer forward passes for text generation.
2
3pub mod backend;
4pub mod common;
5mod llama;
6mod phi;
7
8use std::path::{Path, PathBuf};
9
10use anyhow::{bail, Context, Result};
11use sapient_hub::model_info::{ArchType, ModelInfo};
12use sapient_hub::resolver::WeightFormat;
13
14use crate::gguf_weights::{load_gguf_hf_weights, load_gguf_hf_weights_mmap};
15
16pub use backend::{mac_gpu_support, total_system_ram_bytes, LlmBackendKind, MacGpuSupport};
17pub use llama::LlamaForward;
18pub use phi::PhiForward;
19
20/// Architecture-specific inference engine with KV-cache support.
21pub enum ForwardEngine {
22    Llama(LlamaForward),
23    Phi(PhiForward),
24}
25
26fn weight_format_from_paths(weight_paths: &[PathBuf]) -> WeightFormat {
27    match weight_paths
28        .first()
29        .and_then(|p| p.extension())
30        .and_then(|e| e.to_str())
31    {
32        Some("gguf") => WeightFormat::Gguf,
33        Some("safetensors") => WeightFormat::Safetensors,
34        Some("bin") => WeightFormat::PyTorchBin,
35        _ => WeightFormat::Unknown,
36    }
37}
38
39impl ForwardEngine {
40    pub fn from_pretrained(info: ModelInfo, weight_paths: &[PathBuf]) -> Result<Self> {
41        Self::from_weight_paths(info, weight_paths)
42    }
43
44    pub fn from_weight_paths(info: ModelInfo, weight_paths: &[PathBuf]) -> Result<Self> {
45        Self::from_weight_paths_with_backend(info, weight_paths, LlmBackendKind::Auto)
46    }
47
48    pub fn from_weight_paths_with_backend(
49        info: ModelInfo,
50        weight_paths: &[PathBuf],
51        backend: LlmBackendKind,
52    ) -> Result<Self> {
53        match weight_format_from_paths(weight_paths) {
54            WeightFormat::Gguf => {
55                let path = weight_paths
56                    .first()
57                    .context("GGUF model has no weight path")?;
58                Self::from_gguf_with_backend(info, path, backend)
59            }
60            WeightFormat::Safetensors | WeightFormat::PyTorchBin => match info.arch {
61                ArchType::Llama | ArchType::Qwen | ArchType::Gemma | ArchType::Mixtral => {
62                    Ok(Self::Llama(LlamaForward::from_files_with_backend(
63                        info,
64                        weight_paths,
65                        backend,
66                    )?))
67                }
68                ArchType::Phi => Ok(Self::Phi(PhiForward::from_files_with_backend(
69                    info,
70                    weight_paths,
71                    backend,
72                )?)),
73                other => bail!(
74                    "architecture {other:?} does not yet have a native forward engine — \
75                     use safetensors weights for Llama, Phi, or Qwen models"
76                ),
77            },
78            WeightFormat::Unknown => bail!("unknown or missing weight file format"),
79        }
80    }
81
82    pub fn from_gguf(info: ModelInfo, path: &Path) -> Result<Self> {
83        Self::from_gguf_with_backend(info, path, LlmBackendKind::Auto)
84    }
85
86    pub fn from_gguf_with_backend(
87        info: ModelInfo,
88        path: &Path,
89        backend: LlmBackendKind,
90    ) -> Result<Self> {
91        let weights = load_gguf_hf_weights(path)?;
92        Self::from_gguf_weights(info, weights, backend)
93    }
94
95    /// Load via memory-mapping — Q4_0/Q8_0 tensors are zero-copy from disk.
96    /// The OS pages in weight blocks on demand; only active layers are resident.
97    pub fn from_gguf_mmap_with_backend(
98        info: ModelInfo,
99        path: &Path,
100        backend: LlmBackendKind,
101    ) -> Result<Self> {
102        let weights = load_gguf_hf_weights_mmap(path)?;
103        Self::from_gguf_weights(info, weights, backend)
104    }
105
106    fn from_gguf_weights(
107        info: ModelInfo,
108        weights: std::collections::HashMap<String, sapient_core::Tensor>,
109        backend: LlmBackendKind,
110    ) -> Result<Self> {
111        match info.arch {
112            ArchType::Llama | ArchType::Qwen | ArchType::Gemma | ArchType::Mixtral => {
113                Ok(Self::Llama(LlamaForward::from_weights_with_backend(
114                    info, weights, backend,
115                )?))
116            }
117            ArchType::Phi => {
118                bail!("GGUF Phi models are not yet supported — use safetensors weights")
119            }
120            other => bail!(
121                "architecture {other:?} does not yet support GGUF loading — \
122                 try a Llama-family GGUF model or use safetensors weights"
123            ),
124        }
125    }
126
127    pub fn reset_cache(&mut self) {
128        match self {
129            Self::Llama(f) => f.reset_cache(),
130            Self::Phi(f) => f.reset_cache(),
131        }
132    }
133
134    pub fn forward_logits(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Vec<f32>> {
135        match self {
136            Self::Llama(f) => f.forward_logits(input_ids, use_cache),
137            Self::Phi(f) => f.forward_logits(input_ids, use_cache),
138        }
139    }
140
141    /// Run the model on `input_ids` WITHOUT updating the KV cache and return
142    /// logits for ALL positions. Used by speculative decoding to verify K draft
143    /// tokens in a single target-model forward pass.
144    pub fn forward_all_logits(&mut self, input_ids: &[u32]) -> Result<Vec<Vec<f32>>> {
145        match self {
146            Self::Llama(f) => f.forward_all_logits(input_ids),
147            Self::Phi(f) => f.forward_all_logits(input_ids),
148        }
149    }
150
151    pub fn embed(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
152        match self {
153            Self::Llama(f) => f.embed(input_ids),
154            Self::Phi(f) => f.embed(input_ids),
155        }
156    }
157}