1use std::{
2 collections::BTreeSet,
3 fs,
4 path::{Path, PathBuf},
5};
6
7use candle_transformers::models::llama::{Config, LlamaConfig};
8use serde::Deserialize;
9
10use crate::{Result, WaxError};
11
/// Where a model's weights were found on disk.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelSource {
    /// One or more safetensors weight files.
    Safetensors { files: Vec<PathBuf> },
    /// A single GGUF file.
    Gguf { file: PathBuf },
    /// A path identified as an MLX model.
    Mlx { path: PathBuf },
}
18
/// Model configuration produced by [`ModelConfig::load`] from a model's `config.json`.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Values of the `architectures` array in `config.json`; empty when the key is absent.
    pub architectures: Vec<String>,
    /// Llama configuration obtained from the parsed `LlamaConfig` via `into_config`.
    pub llama: Config,
}
24
/// Deserialization target for `config.json`.
#[derive(Debug, Deserialize)]
struct ConfigFile {
    /// Optional `architectures` array; `#[serde(default)]` yields an empty list when missing.
    #[serde(default)]
    architectures: Vec<String>,
    /// Llama fields, read from the same top-level JSON object (`#[serde(flatten)]`).
    #[serde(flatten)]
    llama: LlamaConfig,
}
32
33impl ModelConfig {
34 pub fn load(model_dir: &Path) -> Result<Self> {
35 let config_path = model_dir.join("config.json");
36 if !config_path.is_file() {
37 return Err(WaxError::MissingModelFile(config_path));
38 }
39
40 let bytes = fs::read(&config_path)?;
41 let parsed: ConfigFile = serde_json::from_slice(&bytes)?;
42 validate_architectures(&parsed.architectures)?;
43
44 Ok(Self {
45 architectures: parsed.architectures,
46 llama: parsed.llama.into_config(false),
47 })
48 }
49}
50
51fn validate_architectures(architectures: &[String]) -> Result<()> {
52 if architectures.is_empty() {
53 return Ok(());
54 }
55
56 if architectures.iter().any(|arch| {
57 arch == "LlamaForCausalLM"
58 || arch == "MistralForCausalLM"
59 || arch.contains("Llama")
60 || arch.contains("TinyLlama")
61 }) {
62 return Ok(());
63 }
64
65 Err(WaxError::UnsupportedArchitecture {
66 architecture: architectures.join(", "),
67 })
68}
69
70pub fn resolve_safetensors_files(model_dir: &Path) -> Result<Vec<PathBuf>> {
71 let index_path = model_dir.join("model.safetensors.index.json");
72 if index_path.is_file() {
73 return resolve_indexed_safetensors(model_dir, &index_path);
74 }
75
76 let single_file = model_dir.join("model.safetensors");
77 if single_file.is_file() {
78 return Ok(vec![single_file]);
79 }
80
81 Err(WaxError::InvalidModelFolder {
82 path: model_dir.to_path_buf(),
83 reason: "expected model.safetensors or model.safetensors.index.json".to_string(),
84 })
85}
86
87pub fn resolve_model_source(model_dir: &Path) -> Result<ModelSource> {
88 if model_dir.is_file() && model_dir.extension().is_some_and(|ext| ext == "gguf") {
89 return Ok(ModelSource::Gguf {
90 file: model_dir.to_path_buf(),
91 });
92 }
93
94 if looks_like_mlx_model(model_dir) {
95 return Ok(ModelSource::Mlx {
96 path: model_dir.to_path_buf(),
97 });
98 }
99
100 if let Some(file) = resolve_gguf_file(model_dir)? {
101 return Ok(ModelSource::Gguf { file });
102 }
103
104 resolve_safetensors_files(model_dir).map(|files| ModelSource::Safetensors { files })
105}
106
107fn resolve_gguf_file(model_dir: &Path) -> Result<Option<PathBuf>> {
108 let direct = model_dir.join("model.gguf");
109 if direct.is_file() {
110 return Ok(Some(direct));
111 }
112
113 let mut files = fs::read_dir(model_dir)?
114 .filter_map(|entry| entry.ok().map(|entry| entry.path()))
115 .filter(|path| path.extension().is_some_and(|ext| ext == "gguf"))
116 .collect::<Vec<_>>();
117 files.sort();
118
119 match files.len() {
120 0 => Ok(None),
121 1 => Ok(files.pop()),
122 _ => Err(WaxError::InvalidModelFolder {
123 path: model_dir.to_path_buf(),
124 reason: "multiple .gguf files found; keep exactly one GGUF file or name it model.gguf"
125 .to_string(),
126 }),
127 }
128}
129
130fn looks_like_mlx_model(model_dir: &Path) -> bool {
131 model_dir.join("weights.npz").is_file()
132 || has_mlx_weight_shards(model_dir)
133 || model_dir.join("model.safetensors.index.json").is_file()
134 && fs::read_to_string(model_dir.join("config.json"))
135 .map(|config| config.contains("\"model_type\"") && config.contains("mlx"))
136 .unwrap_or(false)
137}
138
/// Reports whether `model_dir` has an entry named `weights.*` whose extension is
/// `safetensors` or `npz`.
///
/// Returns `false` when the directory cannot be read. Entries that fail to read or
/// whose names are not valid UTF-8 are skipped.
fn has_mlx_weight_shards(model_dir: &Path) -> bool {
    let Ok(entries) = fs::read_dir(model_dir) else {
        return false;
    };

    for entry in entries.flatten() {
        let candidate = entry.path();

        let has_weights_prefix = candidate
            .file_name()
            .and_then(|os_name| os_name.to_str())
            .is_some_and(|text| text.starts_with("weights."));
        if !has_weights_prefix {
            continue;
        }

        if matches!(candidate.extension(), Some(ext) if ext == "safetensors" || ext == "npz") {
            return true;
        }
    }

    false
}
155
156fn resolve_indexed_safetensors(model_dir: &Path, index_path: &Path) -> Result<Vec<PathBuf>> {
157 let file = fs::File::open(index_path)?;
158 let json: serde_json::Value = serde_json::from_reader(file)?;
159 let weight_map = json
160 .get("weight_map")
161 .and_then(serde_json::Value::as_object)
162 .ok_or_else(|| WaxError::InvalidModelFolder {
163 path: model_dir.to_path_buf(),
164 reason: format!(
165 "{} does not contain a weight_map object",
166 index_path.display()
167 ),
168 })?;
169
170 let mut files = BTreeSet::new();
171 for value in weight_map.values() {
172 let Some(filename) = value.as_str() else {
173 return Err(WaxError::InvalidModelFolder {
174 path: model_dir.to_path_buf(),
175 reason: "weight_map values must be safetensors filenames".to_string(),
176 });
177 };
178 files.insert(filename.to_string());
179 }
180
181 let files = files
182 .into_iter()
183 .map(|filename| model_dir.join(filename))
184 .collect::<Vec<_>>();
185 for file in &files {
186 if !file.is_file() {
187 return Err(WaxError::MissingModelFile(file.clone()));
188 }
189 }
190
191 Ok(files)
192}
193
// Unit tests for config loading and model-source resolution. Each test works in
// its own temporary directory.
#[cfg(test)]
mod tests {
    use std::{fs, path::Path};

    use super::{resolve_model_source, resolve_safetensors_files, ModelConfig, ModelSource};
    use crate::WaxError;

    /// Writes a small Llama-style `config.json` to `path` with the given
    /// `architectures` list (which may be empty).
    fn write_min_llama_config(path: &Path, architectures: &[&str]) {
        let architectures = serde_json::to_string(architectures).unwrap();
        fs::write(
            path,
            format!(
                r#"{{
                    "architectures": {architectures},
                    "hidden_size": 16,
                    "intermediate_size": 32,
                    "vocab_size": 128,
                    "num_hidden_layers": 1,
                    "num_attention_heads": 2,
                    "num_key_value_heads": 2,
                    "rms_norm_eps": 0.000001,
                    "rope_theta": 10000.0,
                    "max_position_embeddings": 64
                }}"#
            ),
        )
        .unwrap();
    }

    // A config declaring `LlamaForCausalLM` loads and keeps the declared name.
    #[test]
    fn accepts_llama_architecture() {
        let dir = tempfile::tempdir().unwrap();
        write_min_llama_config(&dir.path().join("config.json"), &["LlamaForCausalLM"]);

        let config = ModelConfig::load(dir.path()).unwrap();

        assert_eq!(config.architectures, vec!["LlamaForCausalLM"]);
    }

    // An empty `architectures` list is accepted rather than rejected.
    #[test]
    fn accepts_missing_architecture_for_hf_compatible_llama_configs() {
        let dir = tempfile::tempdir().unwrap();
        write_min_llama_config(&dir.path().join("config.json"), &[]);

        let config = ModelConfig::load(dir.path()).unwrap();

        assert!(config.architectures.is_empty());
    }

    // Without a `config.json`, the error names the missing path.
    #[test]
    fn reports_missing_config_path() {
        let dir = tempfile::tempdir().unwrap();

        let err = ModelConfig::load(dir.path()).unwrap_err();

        assert!(matches!(err, WaxError::MissingModelFile(path) if path.ends_with("config.json")));
    }

    // An unsupported architecture is rejected and appears in the error message.
    #[test]
    fn rejects_unsupported_architecture() {
        let dir = tempfile::tempdir().unwrap();
        write_min_llama_config(&dir.path().join("config.json"), &["Qwen2ForCausalLM"]);

        let err = ModelConfig::load(dir.path()).unwrap_err();

        assert!(matches!(err, WaxError::UnsupportedArchitecture { .. }));
        assert!(err.to_string().contains("Qwen2ForCausalLM"));
    }

    // A lone `model.safetensors` resolves to a one-element file list.
    #[test]
    fn resolves_single_safetensors_file() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(dir.path().join("model.safetensors"), b"").unwrap();

        let files = resolve_safetensors_files(dir.path()).unwrap();

        assert_eq!(files, vec![dir.path().join("model.safetensors")]);
    }

    // A folder with exactly one `.gguf` file (any name) resolves to that file.
    #[test]
    fn resolves_single_gguf_file() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(dir.path().join("tiny.gguf"), b"GGUF").unwrap();

        let source = resolve_model_source(dir.path()).unwrap();

        assert_eq!(
            source,
            ModelSource::Gguf {
                file: dir.path().join("tiny.gguf")
            }
        );
    }

    // Passing the `.gguf` file itself, rather than its folder, also works.
    #[test]
    fn resolves_direct_gguf_file_path() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("tiny.gguf");
        fs::write(&file, b"GGUF").unwrap();

        let source = resolve_model_source(&file).unwrap();

        assert_eq!(source, ModelSource::Gguf { file });
    }

    // Two `.gguf` files and no `model.gguf` is ambiguous and must fail.
    #[test]
    fn rejects_ambiguous_gguf_files() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(dir.path().join("a.gguf"), b"GGUF").unwrap();
        fs::write(dir.path().join("b.gguf"), b"GGUF").unwrap();

        let err = resolve_model_source(dir.path()).unwrap_err();

        assert!(err.to_string().contains("multiple .gguf"));
    }

    // MLX detection runs before the safetensors lookup, so an index file plus an
    // MLX-looking config resolves to `Mlx`.
    #[test]
    fn detects_mlx_model_folder_before_safetensors() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(dir.path().join("config.json"), r#"{"model_type":"mlx"}"#).unwrap();
        fs::write(dir.path().join("model.safetensors.index.json"), "{}").unwrap();

        let source = resolve_model_source(dir.path()).unwrap();

        assert_eq!(
            source,
            ModelSource::Mlx {
                path: dir.path().to_path_buf()
            }
        );
    }

    // A `weights.npz` file alone marks the folder as MLX.
    #[test]
    fn detects_mlx_weight_npz_folder() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(dir.path().join("weights.npz"), b"").unwrap();

        let source = resolve_model_source(dir.path()).unwrap();

        assert_eq!(
            source,
            ModelSource::Mlx {
                path: dir.path().to_path_buf()
            }
        );
    }

    // A `weights.*.safetensors` shard marks the folder as MLX.
    #[test]
    fn detects_mlx_safetensors_weight_shards() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(dir.path().join("weights.00.safetensors"), b"").unwrap();

        let source = resolve_model_source(dir.path()).unwrap();

        assert_eq!(
            source,
            ModelSource::Mlx {
                path: dir.path().to_path_buf()
            }
        );
    }

    // Shards named in the index are deduplicated and returned sorted by file name,
    // regardless of the order of the tensor keys.
    #[test]
    fn resolves_indexed_safetensors_files_in_stable_order() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(dir.path().join("a.safetensors"), b"").unwrap();
        fs::write(dir.path().join("b.safetensors"), b"").unwrap();
        fs::write(
            dir.path().join("model.safetensors.index.json"),
            r#"{"weight_map":{"z":"b.safetensors","a":"a.safetensors","b":"b.safetensors"}}"#,
        )
        .unwrap();

        let files = resolve_safetensors_files(dir.path()).unwrap();

        assert_eq!(
            files,
            vec![
                dir.path().join("a.safetensors"),
                dir.path().join("b.safetensors")
            ]
        );
    }

    // A shard named in the index but absent on disk is reported by path.
    #[test]
    fn indexed_safetensors_reports_missing_shard() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(
            dir.path().join("model.safetensors.index.json"),
            r#"{"weight_map":{"model.embed_tokens.weight":"missing.safetensors"}}"#,
        )
        .unwrap();

        let err = resolve_safetensors_files(dir.path()).unwrap_err();

        assert!(
            matches!(err, WaxError::MissingModelFile(path) if path.ends_with("missing.safetensors"))
        );
    }

    // An index without a `weight_map` object is rejected with a message naming it.
    #[test]
    fn indexed_safetensors_requires_weight_map_object() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(
            dir.path().join("model.safetensors.index.json"),
            r#"{"metadata":{}}"#,
        )
        .unwrap();

        let err = resolve_safetensors_files(dir.path()).unwrap_err();

        assert!(err.to_string().contains("weight_map"));
    }
}