sapient_models/forward/
mod.rs1pub mod backend;
4pub mod common;
5mod llama;
6mod phi;
7
8use std::path::{Path, PathBuf};
9
10use anyhow::{bail, Context, Result};
11use sapient_hub::model_info::{ArchType, ModelInfo};
12use sapient_hub::resolver::WeightFormat;
13
14use crate::gguf_weights::load_gguf_hf_weights;
15
16pub use backend::{mac_gpu_support, LlmBackendKind, MacGpuSupport};
17pub use llama::LlamaForward;
18pub use phi::PhiForward;
19
20pub enum ForwardEngine {
22 Llama(LlamaForward),
23 Phi(PhiForward),
24}
25
26fn weight_format_from_paths(weight_paths: &[PathBuf]) -> WeightFormat {
27 match weight_paths
28 .first()
29 .and_then(|p| p.extension())
30 .and_then(|e| e.to_str())
31 {
32 Some("gguf") => WeightFormat::Gguf,
33 Some("safetensors") => WeightFormat::Safetensors,
34 Some("bin") => WeightFormat::PyTorchBin,
35 _ => WeightFormat::Unknown,
36 }
37}
38
39impl ForwardEngine {
40 pub fn from_pretrained(info: ModelInfo, weight_paths: &[PathBuf]) -> Result<Self> {
41 Self::from_weight_paths(info, weight_paths)
42 }
43
44 pub fn from_weight_paths(info: ModelInfo, weight_paths: &[PathBuf]) -> Result<Self> {
45 Self::from_weight_paths_with_backend(info, weight_paths, LlmBackendKind::Auto)
46 }
47
48 pub fn from_weight_paths_with_backend(
49 info: ModelInfo,
50 weight_paths: &[PathBuf],
51 backend: LlmBackendKind,
52 ) -> Result<Self> {
53 match weight_format_from_paths(weight_paths) {
54 WeightFormat::Gguf => {
55 let path = weight_paths
56 .first()
57 .context("GGUF model has no weight path")?;
58 Self::from_gguf_with_backend(info, path, backend)
59 }
60 WeightFormat::Safetensors | WeightFormat::PyTorchBin => match info.arch {
61 ArchType::Llama | ArchType::Qwen | ArchType::Gemma | ArchType::Mixtral => {
62 Ok(Self::Llama(LlamaForward::from_files_with_backend(
63 info,
64 weight_paths,
65 backend,
66 )?))
67 }
68 ArchType::Phi => Ok(Self::Phi(PhiForward::from_files_with_backend(
69 info,
70 weight_paths,
71 backend,
72 )?)),
73 other => bail!(
74 "architecture {other:?} does not yet have a native forward engine — \
75 use safetensors weights for Llama, Phi, or Qwen models"
76 ),
77 },
78 WeightFormat::Unknown => bail!("unknown or missing weight file format"),
79 }
80 }
81
82 pub fn from_gguf(info: ModelInfo, path: &Path) -> Result<Self> {
83 Self::from_gguf_with_backend(info, path, LlmBackendKind::Auto)
84 }
85
86 pub fn from_gguf_with_backend(
87 info: ModelInfo,
88 path: &Path,
89 backend: LlmBackendKind,
90 ) -> Result<Self> {
91 let weights = load_gguf_hf_weights(path)?;
92 match info.arch {
93 ArchType::Llama | ArchType::Qwen | ArchType::Gemma | ArchType::Mixtral => {
94 Ok(Self::Llama(LlamaForward::from_weights_with_backend(
95 info, weights, backend,
96 )?))
97 }
98 ArchType::Phi => {
99 bail!("GGUF Phi models are not yet supported — use safetensors weights")
100 }
101 other => bail!(
102 "architecture {other:?} does not yet support GGUF loading — \
103 try a Llama-family GGUF model or use safetensors weights"
104 ),
105 }
106 }
107
108 pub fn reset_cache(&mut self) {
109 match self {
110 Self::Llama(f) => f.reset_cache(),
111 Self::Phi(f) => f.reset_cache(),
112 }
113 }
114
115 pub fn forward_logits(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Vec<f32>> {
116 match self {
117 Self::Llama(f) => f.forward_logits(input_ids, use_cache),
118 Self::Phi(f) => f.forward_logits(input_ids, use_cache),
119 }
120 }
121
122 pub fn embed(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
123 match self {
124 Self::Llama(f) => f.embed(input_ids),
125 Self::Phi(f) => f.embed(input_ids),
126 }
127 }
128}