Skip to main content

sapient_models/
registry.rs

1//! Architecture registry — dispatch from ArchType → graph builder.

use anyhow::{bail, Context as _, Result};
use sapient_hub::model_info::{ArchType, ModelInfo};
use sapient_ir::graph::Graph;

use crate::architectures::{bert, gemma, gpt2, llama, mixtral, phi, qwen};

10/// A fully-built model graph ready for execution.
11pub struct ModelGraph {
12    pub graph: Graph,
13    pub info: ModelInfo,
14}
15
16/// Build a SAPIENT `Graph` from a parsed `ModelInfo`.
17///
18/// The returned graph has named inputs:
19/// - `"input_ids"` — (batch, seq_len) i32 token IDs
20/// - `"attention_mask"` — (batch, seq_len) i32 (1=attend, 0=mask) [optional]
21/// - `"position_ids"` — (batch, seq_len) i32 [optional, for RoPE offset]
22///
23/// And named outputs:
24/// - `"logits"` — (batch, seq_len, vocab_size) f32
25pub fn build_graph(info: &ModelInfo) -> Result<ModelGraph> {
26    let graph = match &info.arch {
27        ArchType::Llama => llama::build(info)?,
28        ArchType::Phi => phi::build(info)?,
29        ArchType::Gemma => gemma::build(info)?,
30        ArchType::Gpt2 => gpt2::build(info)?,
31        ArchType::Bert => bert::build(info)?,
32        ArchType::Qwen => qwen::build(info)?,
33        ArchType::Mixtral => mixtral::build(info)?,
34        ArchType::Falcon => {
35            // Falcon uses a Llama-like architecture with ALiBi — use Llama builder.
36            llama::build(info)?
37        }
38        ArchType::Unknown(name) => bail!(
39            "Unsupported architecture: '{name}'. \
40             If this is a GGUF model, load it directly via `Pipeline::from_gguf(path)`."
41        ),
42        _ => bail!("Architecture {:?} not yet implemented", info.arch),
43    };
44
45    Ok(ModelGraph {
46        graph,
47        info: info.clone(),
48    })
49}