Skip to main content

sapient_models/forward/
mod.rs

1//! Real transformer forward passes for text generation.
2
3pub mod backend;
4pub mod common;
5mod llama;
6mod phi;
7
8use std::path::{Path, PathBuf};
9
10use anyhow::{bail, Context, Result};
11use sapient_hub::model_info::{ArchType, ModelInfo};
12use sapient_hub::resolver::WeightFormat;
13
14use crate::gguf_weights::load_gguf_hf_weights;
15
16pub use backend::{mac_gpu_support, LlmBackendKind, MacGpuSupport};
17pub use llama::LlamaForward;
18pub use phi::PhiForward;
19
20/// Architecture-specific inference engine with KV-cache support.
21pub enum ForwardEngine {
22    Llama(LlamaForward),
23    Phi(PhiForward),
24}
25
26fn weight_format_from_paths(weight_paths: &[PathBuf]) -> WeightFormat {
27    match weight_paths
28        .first()
29        .and_then(|p| p.extension())
30        .and_then(|e| e.to_str())
31    {
32        Some("gguf") => WeightFormat::Gguf,
33        Some("safetensors") => WeightFormat::Safetensors,
34        Some("bin") => WeightFormat::PyTorchBin,
35        _ => WeightFormat::Unknown,
36    }
37}
38
39impl ForwardEngine {
40    pub fn from_pretrained(info: ModelInfo, weight_paths: &[PathBuf]) -> Result<Self> {
41        Self::from_weight_paths(info, weight_paths)
42    }
43
44    pub fn from_weight_paths(info: ModelInfo, weight_paths: &[PathBuf]) -> Result<Self> {
45        Self::from_weight_paths_with_backend(info, weight_paths, LlmBackendKind::Auto)
46    }
47
48    pub fn from_weight_paths_with_backend(
49        info: ModelInfo,
50        weight_paths: &[PathBuf],
51        backend: LlmBackendKind,
52    ) -> Result<Self> {
53        match weight_format_from_paths(weight_paths) {
54            WeightFormat::Gguf => {
55                let path = weight_paths
56                    .first()
57                    .context("GGUF model has no weight path")?;
58                Self::from_gguf_with_backend(info, path, backend)
59            }
60            WeightFormat::Safetensors | WeightFormat::PyTorchBin => match info.arch {
61                ArchType::Llama | ArchType::Qwen | ArchType::Gemma | ArchType::Mixtral => {
62                    Ok(Self::Llama(LlamaForward::from_files_with_backend(
63                        info,
64                        weight_paths,
65                        backend,
66                    )?))
67                }
68                ArchType::Phi => Ok(Self::Phi(PhiForward::from_files_with_backend(
69                    info,
70                    weight_paths,
71                    backend,
72                )?)),
73                other => bail!(
74                    "architecture {other:?} does not yet have a native forward engine — \
75                     use safetensors weights for Llama, Phi, or Qwen models"
76                ),
77            },
78            WeightFormat::Unknown => bail!("unknown or missing weight file format"),
79        }
80    }
81
82    pub fn from_gguf(info: ModelInfo, path: &Path) -> Result<Self> {
83        Self::from_gguf_with_backend(info, path, LlmBackendKind::Auto)
84    }
85
86    pub fn from_gguf_with_backend(
87        info: ModelInfo,
88        path: &Path,
89        backend: LlmBackendKind,
90    ) -> Result<Self> {
91        let weights = load_gguf_hf_weights(path)?;
92        match info.arch {
93            ArchType::Llama | ArchType::Qwen | ArchType::Gemma | ArchType::Mixtral => {
94                Ok(Self::Llama(LlamaForward::from_weights_with_backend(
95                    info, weights, backend,
96                )?))
97            }
98            ArchType::Phi => {
99                bail!("GGUF Phi models are not yet supported — use safetensors weights")
100            }
101            other => bail!(
102                "architecture {other:?} does not yet support GGUF loading — \
103                 try a Llama-family GGUF model or use safetensors weights"
104            ),
105        }
106    }
107
108    pub fn reset_cache(&mut self) {
109        match self {
110            Self::Llama(f) => f.reset_cache(),
111            Self::Phi(f) => f.reset_cache(),
112        }
113    }
114
115    pub fn forward_logits(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Vec<f32>> {
116        match self {
117            Self::Llama(f) => f.forward_logits(input_ids, use_cache),
118            Self::Phi(f) => f.forward_logits(input_ids, use_cache),
119        }
120    }
121
122    pub fn embed(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
123        match self {
124            Self::Llama(f) => f.embed(input_ids),
125            Self::Phi(f) => f.embed(input_ids),
126        }
127    }
128}