sapient_models/forward/
mod.rs1pub mod backend;
4pub mod common;
5mod llama;
6mod phi;
7
8use std::path::{Path, PathBuf};
9
10use anyhow::{bail, Context, Result};
11use sapient_hub::model_info::{ArchType, ModelInfo};
12use sapient_hub::resolver::WeightFormat;
13
14use crate::gguf_weights::{load_gguf_hf_weights, load_gguf_hf_weights_mmap};
15
16pub use backend::{mac_gpu_support, total_system_ram_bytes, LlmBackendKind, MacGpuSupport};
17pub use llama::LlamaForward;
18pub use phi::PhiForward;
19
20pub enum ForwardEngine {
22 Llama(LlamaForward),
23 Phi(PhiForward),
24}
25
26fn weight_format_from_paths(weight_paths: &[PathBuf]) -> WeightFormat {
27 match weight_paths
28 .first()
29 .and_then(|p| p.extension())
30 .and_then(|e| e.to_str())
31 {
32 Some("gguf") => WeightFormat::Gguf,
33 Some("safetensors") => WeightFormat::Safetensors,
34 Some("bin") => WeightFormat::PyTorchBin,
35 _ => WeightFormat::Unknown,
36 }
37}
38
39impl ForwardEngine {
40 pub fn from_pretrained(info: ModelInfo, weight_paths: &[PathBuf]) -> Result<Self> {
41 Self::from_weight_paths(info, weight_paths)
42 }
43
44 pub fn from_weight_paths(info: ModelInfo, weight_paths: &[PathBuf]) -> Result<Self> {
45 Self::from_weight_paths_with_backend(info, weight_paths, LlmBackendKind::Auto)
46 }
47
48 pub fn from_weight_paths_with_backend(
49 info: ModelInfo,
50 weight_paths: &[PathBuf],
51 backend: LlmBackendKind,
52 ) -> Result<Self> {
53 match weight_format_from_paths(weight_paths) {
54 WeightFormat::Gguf => {
55 let path = weight_paths
56 .first()
57 .context("GGUF model has no weight path")?;
58 Self::from_gguf_with_backend(info, path, backend)
59 }
60 WeightFormat::Safetensors | WeightFormat::PyTorchBin => match info.arch {
61 ArchType::Llama | ArchType::Qwen | ArchType::Gemma | ArchType::Mixtral => {
62 Ok(Self::Llama(LlamaForward::from_files_with_backend(
63 info,
64 weight_paths,
65 backend,
66 )?))
67 }
68 ArchType::Phi => Ok(Self::Phi(PhiForward::from_files_with_backend(
69 info,
70 weight_paths,
71 backend,
72 )?)),
73 other => bail!(
74 "architecture {other:?} does not yet have a native forward engine — \
75 use safetensors weights for Llama, Phi, or Qwen models"
76 ),
77 },
78 WeightFormat::Unknown => bail!("unknown or missing weight file format"),
79 }
80 }
81
82 pub fn from_gguf(info: ModelInfo, path: &Path) -> Result<Self> {
83 Self::from_gguf_with_backend(info, path, LlmBackendKind::Auto)
84 }
85
86 pub fn from_gguf_with_backend(
87 info: ModelInfo,
88 path: &Path,
89 backend: LlmBackendKind,
90 ) -> Result<Self> {
91 let weights = load_gguf_hf_weights(path)?;
92 Self::from_gguf_weights(info, weights, backend)
93 }
94
95 pub fn from_gguf_mmap_with_backend(
98 info: ModelInfo,
99 path: &Path,
100 backend: LlmBackendKind,
101 ) -> Result<Self> {
102 let weights = load_gguf_hf_weights_mmap(path)?;
103 Self::from_gguf_weights(info, weights, backend)
104 }
105
106 fn from_gguf_weights(
107 info: ModelInfo,
108 weights: std::collections::HashMap<String, sapient_core::Tensor>,
109 backend: LlmBackendKind,
110 ) -> Result<Self> {
111 match info.arch {
112 ArchType::Llama | ArchType::Qwen | ArchType::Gemma | ArchType::Mixtral => {
113 Ok(Self::Llama(LlamaForward::from_weights_with_backend(
114 info, weights, backend,
115 )?))
116 }
117 ArchType::Phi => {
118 bail!("GGUF Phi models are not yet supported — use safetensors weights")
119 }
120 other => bail!(
121 "architecture {other:?} does not yet support GGUF loading — \
122 try a Llama-family GGUF model or use safetensors weights"
123 ),
124 }
125 }
126
127 pub fn reset_cache(&mut self) {
128 match self {
129 Self::Llama(f) => f.reset_cache(),
130 Self::Phi(f) => f.reset_cache(),
131 }
132 }
133
134 pub fn forward_logits(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Vec<f32>> {
135 match self {
136 Self::Llama(f) => f.forward_logits(input_ids, use_cache),
137 Self::Phi(f) => f.forward_logits(input_ids, use_cache),
138 }
139 }
140
141 pub fn forward_all_logits(&mut self, input_ids: &[u32]) -> Result<Vec<Vec<f32>>> {
145 match self {
146 Self::Llama(f) => f.forward_all_logits(input_ids),
147 Self::Phi(f) => f.forward_all_logits(input_ids),
148 }
149 }
150
151 pub fn embed(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
152 match self {
153 Self::Llama(f) => f.embed(input_ids),
154 Self::Phi(f) => f.embed(input_ids),
155 }
156 }
157}