ruvector_sparse_inference/model/
loader.rs

1//! Universal model loader trait and metadata
2
3use crate::error::{SparseInferenceError, ModelError};
4use crate::model::gguf::{GgufModel, GgufParser, GgufValue};
5
6type Result<T> = std::result::Result<T, SparseInferenceError>;
7use std::collections::HashMap;
8use std::path::Path;
9
/// Universal model loader trait.
///
/// Implementors provide format-specific parsing in [`ModelLoader::load`];
/// `load_file` is a provided convenience that reads the whole file into
/// memory and delegates to `load`.
///
/// NOTE(review): the `Error` associated type is declared but unused — both
/// `load` and `load_file` return the crate-wide `SparseInferenceError` via
/// the local `Result` alias. Removing it would break implementors, so it is
/// left in place; confirm whether any implementor relies on it.
pub trait ModelLoader {
    /// Concrete model type produced by this loader.
    type Model;
    /// Loader-specific error type (currently unused by the trait methods).
    type Error: std::error::Error;

    /// Load model from bytes
    fn load(data: &[u8]) -> Result<Self::Model>;

    /// Load model from file path (native only)
    ///
    /// Reads the entire file into memory, so very large models pay the full
    /// allocation up front. I/O failures are wrapped in
    /// `ModelError::LoadFailed`.
    #[cfg(not(target_arch = "wasm32"))]
    fn load_file(path: &Path) -> Result<Self::Model> {
        let data = std::fs::read(path).map_err(|e| {
            SparseInferenceError::Model(ModelError::LoadFailed(format!("Failed to read file: {}", e)))
        })?;
        Self::load(&data)
    }

    /// Get model metadata
    fn metadata(&self) -> &ModelMetadata;
}
30
/// Model metadata extracted from GGUF or other formats
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Model family (Llama, Bert, ...), parsed from `general.architecture`.
    pub architecture: ModelArchitecture,
    /// Embedding / residual-stream width (GGUF `<arch>.embedding_length`).
    pub hidden_size: usize,
    /// Feed-forward inner width (GGUF `<arch>.feed_forward_length`);
    /// 0 when the key is absent.
    pub intermediate_size: usize,
    /// Number of transformer blocks (GGUF `<arch>.block_count`).
    pub num_layers: usize,
    /// Attention head count (GGUF `<arch>.attention.head_count`).
    pub num_heads: usize,
    /// KV head count for grouped-query attention; `None` when the model
    /// does not publish `<arch>.attention.head_count_kv`.
    pub num_key_value_heads: Option<usize>,
    /// Tokenizer vocabulary size; falls back to 32000 when unavailable.
    pub vocab_size: usize,
    /// Maximum context length (GGUF `<arch>.context_length`), default 2048.
    pub max_position_embeddings: usize,
    /// Weight quantization scheme; `None` until determined from tensor types.
    pub quantization: Option<QuantizationType>,
    /// RoPE frequency base (GGUF `<arch>.rope.freq_base`), if present.
    pub rope_theta: Option<f32>,
    /// RoPE scaling configuration; not currently populated from GGUF.
    pub rope_scaling: Option<RopeScaling>,
}
46
47impl ModelMetadata {
48    /// Extract metadata from GGUF model
49    pub fn from_gguf(model: &GgufModel) -> Result<Self> {
50        let arch_name = Self::get_string(&model.metadata, "general.architecture")
51            .map_err(|e| SparseInferenceError::Model(ModelError::InvalidConfig(e)))?;
52        let architecture = ModelArchitecture::from_str(&arch_name)
53            .map_err(|e| SparseInferenceError::Model(ModelError::InvalidConfig(e)))?;
54
55        let prefix = format!("{}", arch_name);
56
57        Ok(Self {
58            architecture,
59            hidden_size: Self::get_u32(&model.metadata, &format!("{}.embedding_length", prefix))? as usize,
60            intermediate_size: Self::get_u32(&model.metadata, &format!("{}.feed_forward_length", prefix))
61                .unwrap_or(0) as usize,
62            num_layers: Self::get_u32(&model.metadata, &format!("{}.block_count", prefix))? as usize,
63            num_heads: Self::get_u32(&model.metadata, &format!("{}.attention.head_count", prefix))? as usize,
64            num_key_value_heads: Self::get_u32(&model.metadata, &format!("{}.attention.head_count_kv", prefix))
65                .ok()
66                .map(|v| v as usize),
67            vocab_size: Self::get_u32(&model.metadata, "tokenizer.ggml.tokens")
68                .or_else(|_| Self::get_array_len(&model.metadata, "tokenizer.ggml.tokens"))
69                .unwrap_or(32000) as usize,
70            max_position_embeddings: Self::get_u32(&model.metadata, &format!("{}.context_length", prefix))
71                .unwrap_or(2048) as usize,
72            quantization: None, // Determined from tensor types
73            rope_theta: Self::get_f32(&model.metadata, &format!("{}.rope.freq_base", prefix)).ok(),
74            rope_scaling: None,
75        })
76    }
77
78    fn get_string(metadata: &HashMap<String, GgufValue>, key: &str) -> std::result::Result<String, String> {
79        match metadata.get(key) {
80            Some(GgufValue::String(s)) => Ok(s.clone()),
81            _ => Err(format!("Missing metadata: {}", key)),
82        }
83    }
84
85    fn get_u32(metadata: &HashMap<String, GgufValue>, key: &str) -> std::result::Result<u32, String> {
86        match metadata.get(key) {
87            Some(GgufValue::Uint32(v)) => Ok(*v),
88            Some(GgufValue::Uint64(v)) => Ok(*v as u32),
89            Some(GgufValue::Int32(v)) => Ok(*v as u32),
90            _ => Err(format!("Missing metadata: {}", key)),
91        }
92    }
93
94    fn get_f32(metadata: &HashMap<String, GgufValue>, key: &str) -> std::result::Result<f32, String> {
95        match metadata.get(key) {
96            Some(GgufValue::Float32(v)) => Ok(*v),
97            Some(GgufValue::Float64(v)) => Ok(*v as f32),
98            _ => Err(format!("Missing metadata: {}", key)),
99        }
100    }
101
102    fn get_array_len(metadata: &HashMap<String, GgufValue>, key: &str) -> std::result::Result<u32, String> {
103        match metadata.get(key) {
104            Some(GgufValue::Array(arr)) => Ok(arr.len() as u32),
105            _ => Err(format!("Missing metadata: {}", key)),
106        }
107    }
108}
109
/// Model architecture type
///
/// One variant per supported model family; parsed (case-insensitively) from
/// the GGUF `general.architecture` string by [`ModelArchitecture::from_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelArchitecture {
    /// Llama-family models (GGUF name "llama").
    Llama,
    /// Liquid LFM2 models ("lfm" or "lfm2").
    LFM2,
    /// BERT-style encoder models ("bert").
    Bert,
    /// Mistral models ("mistral").
    Mistral,
    /// Qwen models ("qwen" or "qwen2").
    Qwen,
    /// Microsoft Phi models ("phi", "phi2", or "phi3").
    Phi,
    /// Google Gemma models ("gemma" or "gemma2").
    Gemma,
}
121
122impl ModelArchitecture {
123    pub fn from_str(s: &str) -> std::result::Result<Self, String> {
124        match s.to_lowercase().as_str() {
125            "llama" => Ok(Self::Llama),
126            "lfm" | "lfm2" => Ok(Self::LFM2),
127            "bert" => Ok(Self::Bert),
128            "mistral" => Ok(Self::Mistral),
129            "qwen" | "qwen2" => Ok(Self::Qwen),
130            "phi" | "phi2" | "phi3" => Ok(Self::Phi),
131            "gemma" | "gemma2" => Ok(Self::Gemma),
132            _ => Err(format!("Unsupported architecture: {}", s)),
133        }
134    }
135}
136
/// Quantization type
///
/// Mirrors the GGML/GGUF tensor dtype families; variant names follow the
/// upstream GGML naming (hence the non-camel-case underscores).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationType {
    /// Unquantized 32-bit floats.
    F32,
    /// 16-bit floats.
    F16,
    /// 4-bit legacy quantization, variant 0.
    Q4_0,
    /// 4-bit legacy quantization, variant 1.
    Q4_1,
    /// 5-bit legacy quantization, variant 0.
    Q5_0,
    /// 5-bit legacy quantization, variant 1.
    Q5_1,
    /// 8-bit legacy quantization, variant 0.
    Q8_0,
    /// 8-bit legacy quantization, variant 1.
    Q8_1,
    /// 4-bit K-quant.
    Q4_K,
    /// 5-bit K-quant.
    Q5_K,
    /// 6-bit K-quant.
    Q6_K,
}
152
/// RoPE scaling configuration
///
/// NOTE(review): currently never populated by `ModelMetadata::from_gguf`
/// (`rope_scaling` is always `None`); kept for forward compatibility.
#[derive(Debug, Clone)]
pub struct RopeScaling {
    // Scaling strategy name (e.g. "linear" — presumably; confirm against
    // the formats this crate intends to support).
    pub scaling_type: String,
    // Scale factor applied to rotary position frequencies.
    pub factor: f32,
}
159
/// Fallback metadata used when no model file is available.
///
/// NOTE(review): these values match a Llama-7B-style configuration
/// (4096 hidden, 32 layers/heads, 11008 FFN, 32000 vocab) — confirm that
/// is the intended default before relying on it.
impl Default for ModelMetadata {
    fn default() -> Self {
        Self {
            architecture: ModelArchitecture::Llama,
            hidden_size: 4096,
            intermediate_size: 11008,
            num_layers: 32,
            num_heads: 32,
            num_key_value_heads: None, // no grouped-query attention
            vocab_size: 32000,
            max_position_embeddings: 2048,
            quantization: None,
            rope_theta: Some(10000.0), // standard RoPE base frequency
            rope_scaling: None,
        }
    }
}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Architecture names parse case-insensitively into the right variant.
    #[test]
    fn test_architecture_parsing() {
        let lower = ModelArchitecture::from_str("llama").unwrap();
        assert_eq!(lower, ModelArchitecture::Llama);

        let upper = ModelArchitecture::from_str("BERT").unwrap();
        assert_eq!(upper, ModelArchitecture::Bert);
    }

    /// Default metadata reports the Llama-style fallback configuration.
    #[test]
    fn test_default_metadata() {
        let metadata = ModelMetadata::default();
        assert_eq!(metadata.architecture, ModelArchitecture::Llama);
        assert_eq!(metadata.hidden_size, 4096);
    }
}