ruvector_sparse_inference/model/
use crate::error::{ModelError, SparseInferenceError};
use crate::model::gguf::{GgufModel, GgufParser, GgufValue};
use std::collections::HashMap;
use std::path::Path;

type Result<T> = std::result::Result<T, SparseInferenceError>;
9
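/// Loads a model from raw bytes and exposes its parsed [`ModelMetadata`];
/// off-wasm, `load_file` reads the file and delegates to `load`.
///
/// A minimal sketch of an implementor, assuming a hypothetical
/// `GgufParser::new(..).parse()` entry point (the real parser API in
/// `model::gguf` may differ):
///
/// ```ignore
/// struct GgufLoader {
///     model: GgufModel,
///     metadata: ModelMetadata,
/// }
///
/// impl ModelLoader for GgufLoader {
///     type Model = GgufLoader;
///     type Error = SparseInferenceError;
///
///     fn load(data: &[u8]) -> Result<Self::Model> {
///         let model = GgufParser::new(data).parse()?; // assumed API
///         let metadata = ModelMetadata::from_gguf(&model)?;
///         Ok(GgufLoader { model, metadata })
///     }
///
///     fn metadata(&self) -> &ModelMetadata {
///         &self.metadata
///     }
/// }
/// ```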
pub trait ModelLoader {
    type Model;
    type Error: std::error::Error;

    fn load(data: &[u8]) -> Result<Self::Model>;

    #[cfg(not(target_arch = "wasm32"))]
    fn load_file(path: &Path) -> Result<Self::Model> {
        let data = std::fs::read(path).map_err(|e| {
            SparseInferenceError::Model(ModelError::LoadFailed(format!("Failed to read file: {}", e)))
        })?;
        Self::load(&data)
    }

    fn metadata(&self) -> &ModelMetadata;
}
30
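/// Architecture-level hyperparameters read from a model's GGUF metadata.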
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    pub architecture: ModelArchitecture,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_layers: usize,
    pub num_heads: usize,
    pub num_key_value_heads: Option<usize>,
    pub vocab_size: usize,
    pub max_position_embeddings: usize,
    pub quantization: Option<QuantizationType>,
    pub rope_theta: Option<f32>,
    pub rope_scaling: Option<RopeScaling>,
}
46
impl ModelMetadata {
    /// Extracts model hyperparameters from parsed GGUF metadata.
    pub fn from_gguf(model: &GgufModel) -> Result<Self> {
        let invalid = |e: String| SparseInferenceError::Model(ModelError::InvalidConfig(e));

        let arch_name = Self::get_string(&model.metadata, "general.architecture").map_err(invalid)?;
        let architecture = ModelArchitecture::from_str(&arch_name).map_err(invalid)?;

        // GGUF keys are namespaced by architecture, e.g. "llama.embedding_length".
        let prefix = arch_name.as_str();

        Ok(Self {
            architecture,
            hidden_size: Self::get_u32(&model.metadata, &format!("{}.embedding_length", prefix))
                .map_err(invalid)? as usize,
            intermediate_size: Self::get_u32(&model.metadata, &format!("{}.feed_forward_length", prefix))
                .unwrap_or(0) as usize,
            num_layers: Self::get_u32(&model.metadata, &format!("{}.block_count", prefix))
                .map_err(invalid)? as usize,
            num_heads: Self::get_u32(&model.metadata, &format!("{}.attention.head_count", prefix))
                .map_err(invalid)? as usize,
            num_key_value_heads: Self::get_u32(&model.metadata, &format!("{}.attention.head_count_kv", prefix))
                .ok()
                .map(|v| v as usize),
            // The vocabulary is usually stored as a token array; fall back to its
            // length, then to 32000 (the Llama default) if neither form is present.
            vocab_size: Self::get_u32(&model.metadata, "tokenizer.ggml.tokens")
                .or_else(|_| Self::get_array_len(&model.metadata, "tokenizer.ggml.tokens"))
                .unwrap_or(32000) as usize,
            max_position_embeddings: Self::get_u32(&model.metadata, &format!("{}.context_length", prefix))
                .unwrap_or(2048) as usize,
            quantization: None,
            rope_theta: Self::get_f32(&model.metadata, &format!("{}.rope.freq_base", prefix)).ok(),
            rope_scaling: None,
        })
    }

    fn get_string(metadata: &HashMap<String, GgufValue>, key: &str) -> std::result::Result<String, String> {
        match metadata.get(key) {
            Some(GgufValue::String(s)) => Ok(s.clone()),
            _ => Err(format!("Missing or non-string metadata: {}", key)),
        }
    }

    fn get_u32(metadata: &HashMap<String, GgufValue>, key: &str) -> std::result::Result<u32, String> {
        match metadata.get(key) {
            Some(GgufValue::Uint32(v)) => Ok(*v),
            Some(GgufValue::Uint64(v)) => Ok(*v as u32),
            Some(GgufValue::Int32(v)) => Ok(*v as u32),
            _ => Err(format!("Missing or non-integer metadata: {}", key)),
        }
    }

    fn get_f32(metadata: &HashMap<String, GgufValue>, key: &str) -> std::result::Result<f32, String> {
        match metadata.get(key) {
            Some(GgufValue::Float32(v)) => Ok(*v),
            Some(GgufValue::Float64(v)) => Ok(*v as f32),
            _ => Err(format!("Missing or non-float metadata: {}", key)),
        }
    }

    fn get_array_len(metadata: &HashMap<String, GgufValue>, key: &str) -> std::result::Result<u32, String> {
        match metadata.get(key) {
            Some(GgufValue::Array(arr)) => Ok(arr.len() as u32),
            _ => Err(format!("Missing or non-array metadata: {}", key)),
        }
    }
}
109
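/// Model families this loader recognizes from GGUF's "general.architecture".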
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelArchitecture {
    Llama,
    LFM2,
    Bert,
    Mistral,
    Qwen,
    Phi,
    Gemma,
}
121
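// Matches on the lowercased architecture string, so "BERT" and "bert" both parse.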
impl ModelArchitecture {
    pub fn from_str(s: &str) -> std::result::Result<Self, String> {
        match s.to_lowercase().as_str() {
            "llama" => Ok(Self::Llama),
            "lfm" | "lfm2" => Ok(Self::LFM2),
            "bert" => Ok(Self::Bert),
            "mistral" => Ok(Self::Mistral),
            "qwen" | "qwen2" => Ok(Self::Qwen),
            "phi" | "phi2" | "phi3" => Ok(Self::Phi),
            "gemma" | "gemma2" => Ok(Self::Gemma),
            _ => Err(format!("Unsupported architecture: {}", s)),
        }
    }
}
136
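/// GGUF/GGML tensor quantization formats.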
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationType {
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q4_K,
    Q5_K,
    Q6_K,
}
152
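/// RoPE scaling parameters (e.g. a "linear" scaling type with its factor).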
#[derive(Debug, Clone)]
pub struct RopeScaling {
    pub scaling_type: String,
    pub factor: f32,
}
159
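// Defaults match a Llama-7B-shaped model: 4096 hidden, 32 layers/heads, 2048-token context.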
impl Default for ModelMetadata {
    fn default() -> Self {
        Self {
            architecture: ModelArchitecture::Llama,
            hidden_size: 4096,
            intermediate_size: 11008,
            num_layers: 32,
            num_heads: 32,
            num_key_value_heads: None,
            vocab_size: 32000,
            max_position_embeddings: 2048,
            quantization: None,
            rope_theta: Some(10000.0),
            rope_scaling: None,
        }
    }
}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_architecture_parsing() {
        assert_eq!(
            ModelArchitecture::from_str("llama").unwrap(),
            ModelArchitecture::Llama
        );
        assert_eq!(
            ModelArchitecture::from_str("BERT").unwrap(),
            ModelArchitecture::Bert
        );
    }

    #[test]
    fn test_default_metadata() {
        let metadata = ModelMetadata::default();
        assert_eq!(metadata.architecture, ModelArchitecture::Llama);
        assert_eq!(metadata.hidden_size, 4096);
    }
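
    // Additional checks for the private GGUF accessors. These build the metadata
    // map directly, so they only assume the GgufValue variants matched above.
    #[test]
    fn test_get_u32_integer_variants() {
        let mut metadata = HashMap::new();
        metadata.insert("llama.block_count".to_string(), GgufValue::Uint32(32));
        metadata.insert("llama.context_length".to_string(), GgufValue::Uint64(4096));
        assert_eq!(ModelMetadata::get_u32(&metadata, "llama.block_count").unwrap(), 32);
        assert_eq!(ModelMetadata::get_u32(&metadata, "llama.context_length").unwrap(), 4096);
        assert!(ModelMetadata::get_u32(&metadata, "missing.key").is_err());
    }

    #[test]
    fn test_unsupported_architecture() {
        assert!(ModelArchitecture::from_str("gpt-neox").is_err());
    }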
}