Skip to main content

wax_core/
loader.rs

1use std::{
2    collections::BTreeSet,
3    fs,
4    path::{Path, PathBuf},
5};
6
7use candle_transformers::models::llama::{Config, LlamaConfig};
8use serde::Deserialize;
9
10use crate::{Result, WaxError};
11
/// The on-disk form in which a model's weights were located.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ModelSource {
    /// One or more safetensors weight files.
    Safetensors {
        /// Paths of the weight files.
        files: Vec<PathBuf>,
    },
    /// A single GGUF file.
    Gguf {
        /// Path of the `.gguf` file.
        file: PathBuf,
    },
    /// A location recognised as an MLX model.
    Mlx {
        /// Path of the MLX model.
        path: PathBuf,
    },
}
18
19#[derive(Debug, Clone)]
20pub struct ModelConfig {
21    pub architectures: Vec<String>,
22    pub llama: Config,
23}
24
25#[derive(Debug, Deserialize)]
26struct ConfigFile {
27    #[serde(default)]
28    architectures: Vec<String>,
29    #[serde(flatten)]
30    llama: LlamaConfig,
31}
32
33impl ModelConfig {
34    pub fn load(model_dir: &Path) -> Result<Self> {
35        let config_path = model_dir.join("config.json");
36        if !config_path.is_file() {
37            return Err(WaxError::MissingModelFile(config_path));
38        }
39
40        let bytes = fs::read(&config_path)?;
41        let parsed: ConfigFile = serde_json::from_slice(&bytes)?;
42        validate_architectures(&parsed.architectures)?;
43
44        Ok(Self {
45            architectures: parsed.architectures,
46            llama: parsed.llama.into_config(false),
47        })
48    }
49}
50
51fn validate_architectures(architectures: &[String]) -> Result<()> {
52    if architectures.is_empty() {
53        return Ok(());
54    }
55
56    if architectures.iter().any(|arch| {
57        arch == "LlamaForCausalLM"
58            || arch == "MistralForCausalLM"
59            || arch.contains("Llama")
60            || arch.contains("TinyLlama")
61    }) {
62        return Ok(());
63    }
64
65    Err(WaxError::UnsupportedArchitecture {
66        architecture: architectures.join(", "),
67    })
68}
69
70pub fn resolve_safetensors_files(model_dir: &Path) -> Result<Vec<PathBuf>> {
71    let index_path = model_dir.join("model.safetensors.index.json");
72    if index_path.is_file() {
73        return resolve_indexed_safetensors(model_dir, &index_path);
74    }
75
76    let single_file = model_dir.join("model.safetensors");
77    if single_file.is_file() {
78        return Ok(vec![single_file]);
79    }
80
81    Err(WaxError::InvalidModelFolder {
82        path: model_dir.to_path_buf(),
83        reason: "expected model.safetensors or model.safetensors.index.json".to_string(),
84    })
85}
86
87pub fn resolve_model_source(model_dir: &Path) -> Result<ModelSource> {
88    if model_dir.is_file() && model_dir.extension().is_some_and(|ext| ext == "gguf") {
89        return Ok(ModelSource::Gguf {
90            file: model_dir.to_path_buf(),
91        });
92    }
93
94    if looks_like_mlx_model(model_dir) {
95        return Ok(ModelSource::Mlx {
96            path: model_dir.to_path_buf(),
97        });
98    }
99
100    if let Some(file) = resolve_gguf_file(model_dir)? {
101        return Ok(ModelSource::Gguf { file });
102    }
103
104    resolve_safetensors_files(model_dir).map(|files| ModelSource::Safetensors { files })
105}
106
107fn resolve_gguf_file(model_dir: &Path) -> Result<Option<PathBuf>> {
108    let direct = model_dir.join("model.gguf");
109    if direct.is_file() {
110        return Ok(Some(direct));
111    }
112
113    let mut files = fs::read_dir(model_dir)?
114        .filter_map(|entry| entry.ok().map(|entry| entry.path()))
115        .filter(|path| path.extension().is_some_and(|ext| ext == "gguf"))
116        .collect::<Vec<_>>();
117    files.sort();
118
119    match files.len() {
120        0 => Ok(None),
121        1 => Ok(files.pop()),
122        _ => Err(WaxError::InvalidModelFolder {
123            path: model_dir.to_path_buf(),
124            reason: "multiple .gguf files found; keep exactly one GGUF file or name it model.gguf"
125                .to_string(),
126        }),
127    }
128}
129
130fn looks_like_mlx_model(model_dir: &Path) -> bool {
131    model_dir.join("weights.npz").is_file()
132        || has_mlx_weight_shards(model_dir)
133        || model_dir.join("model.safetensors.index.json").is_file()
134            && fs::read_to_string(model_dir.join("config.json"))
135                .map(|config| config.contains("\"model_type\"") && config.contains("mlx"))
136                .unwrap_or(false)
137}
138
/// Reports whether `model_dir` contains an entry named `weights.*` whose extension
/// is `safetensors` or `npz`.
///
/// Returns `false` when the directory cannot be listed; entries that fail to read
/// or whose names are not valid UTF-8 are ignored.
fn has_mlx_weight_shards(model_dir: &Path) -> bool {
    let Ok(entries) = fs::read_dir(model_dir) else {
        return false;
    };

    for entry in entries.flatten() {
        let path = entry.path();

        let named_like_weights = path
            .file_name()
            .and_then(|name| name.to_str())
            .is_some_and(|name| name.starts_with("weights."));
        let has_weight_extension = path
            .extension()
            .is_some_and(|ext| ext == "safetensors" || ext == "npz");

        if named_like_weights && has_weight_extension {
            return true;
        }
    }

    false
}
155
156fn resolve_indexed_safetensors(model_dir: &Path, index_path: &Path) -> Result<Vec<PathBuf>> {
157    let file = fs::File::open(index_path)?;
158    let json: serde_json::Value = serde_json::from_reader(file)?;
159    let weight_map = json
160        .get("weight_map")
161        .and_then(serde_json::Value::as_object)
162        .ok_or_else(|| WaxError::InvalidModelFolder {
163            path: model_dir.to_path_buf(),
164            reason: format!(
165                "{} does not contain a weight_map object",
166                index_path.display()
167            ),
168        })?;
169
170    let mut files = BTreeSet::new();
171    for value in weight_map.values() {
172        let Some(filename) = value.as_str() else {
173            return Err(WaxError::InvalidModelFolder {
174                path: model_dir.to_path_buf(),
175                reason: "weight_map values must be safetensors filenames".to_string(),
176            });
177        };
178        files.insert(filename.to_string());
179    }
180
181    let files = files
182        .into_iter()
183        .map(|filename| model_dir.join(filename))
184        .collect::<Vec<_>>();
185    for file in &files {
186        if !file.is_file() {
187            return Err(WaxError::MissingModelFile(file.clone()));
188        }
189    }
190
191    Ok(files)
192}
193
#[cfg(test)]
mod tests {
    use std::{fs, path::Path};

    use super::{resolve_model_source, resolve_safetensors_files, ModelConfig, ModelSource};
    use crate::WaxError;

    /// Creates a fresh temporary directory standing in for a model folder.
    fn scratch_dir() -> tempfile::TempDir {
        tempfile::tempdir().unwrap()
    }

    /// Writes `contents` to the file `name` inside `dir`.
    fn put(dir: &Path, name: &str, contents: impl AsRef<[u8]>) {
        fs::write(dir.join(name), contents).unwrap();
    }

    /// Writes a minimal Llama `config.json` to `path` with the given architecture list.
    fn write_min_llama_config(path: &Path, architectures: &[&str]) {
        let architectures = serde_json::to_string(architectures).unwrap();
        let body = format!(
            r#"{{
                    "architectures": {architectures},
                    "hidden_size": 16,
                    "intermediate_size": 32,
                    "vocab_size": 128,
                    "num_hidden_layers": 1,
                    "num_attention_heads": 2,
                    "num_key_value_heads": 2,
                    "rms_norm_eps": 0.000001,
                    "rope_theta": 10000.0,
                    "max_position_embeddings": 64
                }}"#
        );
        fs::write(path, body).unwrap();
    }

    // ---- ModelConfig::load ----

    #[test]
    fn accepts_llama_architecture() {
        let folder = scratch_dir();
        write_min_llama_config(&folder.path().join("config.json"), &["LlamaForCausalLM"]);

        let loaded = ModelConfig::load(folder.path()).unwrap();

        assert_eq!(loaded.architectures, vec!["LlamaForCausalLM"]);
    }

    #[test]
    fn accepts_missing_architecture_for_hf_compatible_llama_configs() {
        let folder = scratch_dir();
        write_min_llama_config(&folder.path().join("config.json"), &[]);

        let loaded = ModelConfig::load(folder.path()).unwrap();

        assert!(loaded.architectures.is_empty());
    }

    #[test]
    fn reports_missing_config_path() {
        let folder = scratch_dir();

        let failure = ModelConfig::load(folder.path()).unwrap_err();

        assert!(
            matches!(failure, WaxError::MissingModelFile(path) if path.ends_with("config.json"))
        );
    }

    #[test]
    fn rejects_unsupported_architecture() {
        let folder = scratch_dir();
        write_min_llama_config(&folder.path().join("config.json"), &["Qwen2ForCausalLM"]);

        let failure = ModelConfig::load(folder.path()).unwrap_err();

        assert!(matches!(failure, WaxError::UnsupportedArchitecture { .. }));
        assert!(failure.to_string().contains("Qwen2ForCausalLM"));
    }

    // ---- model source detection ----

    #[test]
    fn resolves_single_safetensors_file() {
        let folder = scratch_dir();
        put(folder.path(), "model.safetensors", b"");

        let resolved = resolve_safetensors_files(folder.path()).unwrap();

        assert_eq!(resolved, vec![folder.path().join("model.safetensors")]);
    }

    #[test]
    fn resolves_single_gguf_file() {
        let folder = scratch_dir();
        put(folder.path(), "tiny.gguf", b"GGUF");

        let detected = resolve_model_source(folder.path()).unwrap();

        let expected = ModelSource::Gguf {
            file: folder.path().join("tiny.gguf"),
        };
        assert_eq!(detected, expected);
    }

    #[test]
    fn resolves_direct_gguf_file_path() {
        let folder = scratch_dir();
        let file = folder.path().join("tiny.gguf");
        fs::write(&file, b"GGUF").unwrap();

        let detected = resolve_model_source(&file).unwrap();

        assert_eq!(detected, ModelSource::Gguf { file });
    }

    #[test]
    fn rejects_ambiguous_gguf_files() {
        let folder = scratch_dir();
        put(folder.path(), "a.gguf", b"GGUF");
        put(folder.path(), "b.gguf", b"GGUF");

        let failure = resolve_model_source(folder.path()).unwrap_err();

        assert!(failure.to_string().contains("multiple .gguf"));
    }

    #[test]
    fn detects_mlx_model_folder_before_safetensors() {
        let folder = scratch_dir();
        put(folder.path(), "config.json", r#"{"model_type":"mlx"}"#);
        put(folder.path(), "model.safetensors.index.json", "{}");

        let detected = resolve_model_source(folder.path()).unwrap();

        let expected = ModelSource::Mlx {
            path: folder.path().to_path_buf(),
        };
        assert_eq!(detected, expected);
    }

    #[test]
    fn detects_mlx_weight_npz_folder() {
        let folder = scratch_dir();
        put(folder.path(), "weights.npz", b"");

        let detected = resolve_model_source(folder.path()).unwrap();

        let expected = ModelSource::Mlx {
            path: folder.path().to_path_buf(),
        };
        assert_eq!(detected, expected);
    }

    #[test]
    fn detects_mlx_safetensors_weight_shards() {
        let folder = scratch_dir();
        put(folder.path(), "weights.00.safetensors", b"");

        let detected = resolve_model_source(folder.path()).unwrap();

        let expected = ModelSource::Mlx {
            path: folder.path().to_path_buf(),
        };
        assert_eq!(detected, expected);
    }

    // ---- indexed safetensors ----

    #[test]
    fn resolves_indexed_safetensors_files_in_stable_order() {
        let folder = scratch_dir();
        put(folder.path(), "a.safetensors", b"");
        put(folder.path(), "b.safetensors", b"");
        put(
            folder.path(),
            "model.safetensors.index.json",
            r#"{"weight_map":{"z":"b.safetensors","a":"a.safetensors","b":"b.safetensors"}}"#,
        );

        let resolved = resolve_safetensors_files(folder.path()).unwrap();

        let expected = vec![
            folder.path().join("a.safetensors"),
            folder.path().join("b.safetensors"),
        ];
        assert_eq!(resolved, expected);
    }

    #[test]
    fn indexed_safetensors_reports_missing_shard() {
        let folder = scratch_dir();
        put(
            folder.path(),
            "model.safetensors.index.json",
            r#"{"weight_map":{"model.embed_tokens.weight":"missing.safetensors"}}"#,
        );

        let failure = resolve_safetensors_files(folder.path()).unwrap_err();

        assert!(
            matches!(failure, WaxError::MissingModelFile(path) if path.ends_with("missing.safetensors"))
        );
    }

    #[test]
    fn indexed_safetensors_requires_weight_map_object() {
        let folder = scratch_dir();
        put(
            folder.path(),
            "model.safetensors.index.json",
            r#"{"metadata":{}}"#,
        );

        let failure = resolve_safetensors_files(folder.path()).unwrap_err();

        assert!(failure.to_string().contains("weight_map"));
    }
}