vespertide_loader/
models.rs

1use std::fs;
2use std::path::Path;
3
4use anyhow::{Context, Result};
5use vespertide_config::VespertideConfig;
6use vespertide_core::TableDef;
7use vespertide_planner::validate_schema;
8
9/// Load all model definitions from the models directory (recursively).
10pub fn load_models(config: &VespertideConfig) -> Result<Vec<TableDef>> {
11    let models_dir = config.models_dir();
12    if !models_dir.exists() {
13        return Ok(Vec::new());
14    }
15
16    let mut tables = Vec::new();
17    load_models_recursive(models_dir, &mut tables)?;
18
19    // Validate schema integrity using normalized version
20    // But return the original tables to preserve inline constraints
21    if !tables.is_empty() {
22        let normalized_tables: Vec<TableDef> = tables
23            .iter()
24            .map(|t| {
25                t.normalize()
26                    .map_err(|e| anyhow::anyhow!("Failed to normalize table '{}': {}", t.name, e))
27            })
28            .collect::<Result<Vec<_>, _>>()?;
29
30        validate_schema(&normalized_tables)
31            .map_err(|e| anyhow::anyhow!("schema validation failed: {}", e))?;
32    }
33
34    Ok(tables)
35}
36
37/// Recursively walk directory and load model files.
38fn load_models_recursive(dir: &Path, tables: &mut Vec<TableDef>) -> Result<()> {
39    let entries =
40        fs::read_dir(dir).with_context(|| format!("read models directory: {}", dir.display()))?;
41
42    for entry in entries {
43        let entry = entry.context("read directory entry")?;
44        let path = entry.path();
45
46        if path.is_dir() {
47            // Recursively process subdirectories
48            load_models_recursive(&path, tables)?;
49            continue;
50        }
51
52        if path.is_file() {
53            let ext = path.extension().and_then(|s| s.to_str());
54            if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
55                let content = fs::read_to_string(&path)
56                    .with_context(|| format!("read model file: {}", path.display()))?;
57
58                let table: TableDef = if ext == Some("json") {
59                    serde_json::from_str(&content)
60                        .with_context(|| format!("parse JSON model: {}", path.display()))?
61                } else {
62                    serde_yaml::from_str(&content)
63                        .with_context(|| format!("parse YAML model: {}", path.display()))?
64                };
65
66                tables.push(table);
67            }
68        }
69    }
70
71    Ok(())
72}
73
74/// Load models from a specific directory (for compile-time use in macros).
75pub fn load_models_from_dir(
76    project_root: Option<std::path::PathBuf>,
77) -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
78    use std::env;
79
80    // Locate project root from CARGO_MANIFEST_DIR or use provided path
81    let project_root = if let Some(root) = project_root {
82        root
83    } else {
84        std::path::PathBuf::from(
85            env::var("CARGO_MANIFEST_DIR")
86                .context("CARGO_MANIFEST_DIR environment variable not set")?,
87        )
88    };
89
90    // Read vespertide.json or use defaults
91    let config = crate::config::load_config_or_default(Some(project_root.clone()))
92        .map_err(|e| format!("Failed to load config: {}", e))?;
93
94    // Read models directory
95    let models_dir = project_root.join(config.models_dir());
96    if !models_dir.exists() {
97        return Ok(Vec::new());
98    }
99
100    let mut tables = Vec::new();
101    load_models_recursive_internal(&models_dir, &mut tables)
102        .map_err(|e| format!("Failed to load models: {}", e))?;
103
104    // Normalize tables
105    let normalized_tables: Vec<TableDef> = tables
106        .into_iter()
107        .map(|t| {
108            t.normalize()
109                .map_err(|e| format!("Failed to normalize table '{}': {}", t.name, e))
110        })
111        .collect::<Result<Vec<_>, _>>()
112        .map_err(|e| e.to_string())?;
113
114    Ok(normalized_tables)
115}
116
117/// Internal recursive function for loading models (used by both runtime and compile-time).
118fn load_models_recursive_internal(
119    dir: &Path,
120    tables: &mut Vec<TableDef>,
121) -> Result<(), Box<dyn std::error::Error>> {
122    use std::fs;
123
124    let entries = fs::read_dir(dir)
125        .map_err(|e| format!("Failed to read models directory {}: {}", dir.display(), e))?;
126
127    for entry in entries {
128        let entry = entry.map_err(|e| format!("Failed to read directory entry: {}", e))?;
129        let path = entry.path();
130
131        if path.is_dir() {
132            // Recursively process subdirectories
133            load_models_recursive_internal(&path, tables)?;
134            continue;
135        }
136
137        if path.is_file() {
138            let ext = path.extension().and_then(|s| s.to_str());
139            if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
140                let content = fs::read_to_string(&path)
141                    .map_err(|e| format!("Failed to read model file {}: {}", path.display(), e))?;
142
143                let table: TableDef = if ext == Some("json") {
144                    serde_json::from_str(&content).map_err(|e| {
145                        format!("Failed to parse JSON model {}: {}", path.display(), e)
146                    })?
147                } else {
148                    serde_yaml::from_str(&content).map_err(|e| {
149                        format!("Failed to parse YAML model {}: {}", path.display(), e)
150                    })?
151                };
152
153                tables.push(table);
154            }
155        }
156    }
157
158    Ok(())
159}
160
161/// Load models at compile time (for macro use).
162pub fn load_models_at_compile_time() -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
163    load_models_from_dir(None)
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::env;
    use std::ffi::{OsStr, OsString};
    use std::fs;
    use std::path::{Path, PathBuf};
    use tempfile::tempdir;
    use vespertide_core::{
        ColumnDef, ColumnType, SimpleColumnType, TableConstraint,
        schema::foreign_key::ForeignKeySyntax,
    };

    /// Switches the process working directory for the lifetime of the guard and
    /// switches back on drop (best effort: a failed restore is ignored).
    struct CwdGuard {
        original: PathBuf,
    }

    impl CwdGuard {
        fn new(dir: &Path) -> Self {
            let original = env::current_dir().unwrap();
            env::set_current_dir(dir).unwrap();
            Self { original }
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            let _ = env::set_current_dir(&self.original);
        }
    }

    /// Overrides `CARGO_MANIFEST_DIR` (or removes it when given `None`) for the
    /// lifetime of the guard. The previous value is restored on drop, including
    /// when the test panics, so later tests see an unmodified environment.
    struct ManifestDirGuard {
        original: Option<OsString>,
    }

    impl ManifestDirGuard {
        fn set(value: Option<&Path>) -> Self {
            let original = env::var_os("CARGO_MANIFEST_DIR");
            set_manifest_dir(value.map(Path::as_os_str));
            Self { original }
        }
    }

    impl Drop for ManifestDirGuard {
        fn drop(&mut self) {
            set_manifest_dir(self.original.as_deref());
        }
    }

    /// Sets `CARGO_MANIFEST_DIR` to `value`, or removes it when `value` is `None`.
    fn set_manifest_dir(value: Option<&OsStr>) {
        // SAFETY: mutating the environment is only sound while no other thread
        // reads or writes it. Every test in this module is `#[serial]`; this
        // assumes no non-serial test elsewhere in the crate touches the
        // environment concurrently.
        unsafe {
            match value {
                Some(v) => env::set_var("CARGO_MANIFEST_DIR", v),
                None => env::remove_var("CARGO_MANIFEST_DIR"),
            }
        }
    }

    /// Writes a default `vespertide.json` into the current working directory.
    fn write_config() {
        let cfg = VespertideConfig::default();
        let text = serde_json::to_string_pretty(&cfg).unwrap();
        fs::write("vespertide.json", text).unwrap();
    }

    /// Non-nullable integer column with the given name and optional foreign key.
    fn int_column(name: &'static str, foreign_key: Option<ForeignKeySyntax>) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key,
        }
    }

    /// Table with a single `id` column and no table-level constraints.
    fn id_table(name: &'static str) -> TableDef {
        TableDef {
            name: name.into(),
            columns: vec![int_column("id", None)],
            constraints: vec![],
        }
    }

    /// Table with a single `id` column declared as the primary key.
    fn id_table_with_pk(name: &'static str) -> TableDef {
        TableDef {
            constraints: vec![TableConstraint::PrimaryKey {
                auto_increment: false,
                columns: vec!["id".into()],
            }],
            ..id_table(name)
        }
    }

    /// `orders` table whose FK string lacks the `table.column` dot separator.
    fn orders_with_invalid_fk() -> TableDef {
        TableDef {
            name: "orders".into(),
            columns: vec![int_column(
                "user_id",
                // Invalid FK format: should be "table.column" but missing the dot
                Some(ForeignKeySyntax::String("invalid_format".into())),
            )],
            constraints: vec![],
        }
    }

    fn write_json(path: impl AsRef<Path>, table: &TableDef) {
        fs::write(path, serde_json::to_string_pretty(table).unwrap()).unwrap();
    }

    fn write_yaml(path: impl AsRef<Path>, table: &TableDef) {
        fs::write(path, serde_yaml::to_string(table).unwrap()).unwrap();
    }

    #[test]
    #[serial]
    fn load_models_returns_empty_when_no_models_dir() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        // Don't create models directory
        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 0);
    }

    #[test]
    #[serial]
    fn load_models_reads_yaml_and_validates() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();
        write_yaml("models/users.yaml", &id_table_with_pk("users"));

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn load_models_recursive_processes_subdirectories() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models/subdir").unwrap();
        write_json("models/subdir/subtable.json", &id_table_with_pk("subtable"));

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn load_models_fails_on_invalid_fk_format() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();
        write_json("models/orders.json", &orders_with_invalid_fk());

        let err_msg = load_models(&VespertideConfig::default())
            .unwrap_err()
            .to_string();
        assert!(
            err_msg.contains("Failed to normalize table 'orders'"),
            "unexpected error: {err_msg}"
        );
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_root() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        write_json(models_dir.join("users.json"), &id_table("users"));

        let models = load_models_from_dir(Some(temp_dir.path().to_path_buf())).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_without_root() {
        // Remove CARGO_MANIFEST_DIR to test the error path; the guard restores
        // the original value when the test ends, even on failure.
        let _env = ManifestDirGuard::set(None);

        let err_msg = load_models_from_dir(None).unwrap_err().to_string();
        assert!(
            err_msg.contains("CARGO_MANIFEST_DIR environment variable not set"),
            "unexpected error: {err_msg}"
        );
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_no_models_dir() {
        let temp_dir = tempdir().unwrap();
        // Don't create models directory

        let models = load_models_from_dir(Some(temp_dir.path().to_path_buf())).unwrap();
        assert_eq!(models.len(), 0);
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        write_yaml(models_dir.join("users.yaml"), &id_table("users"));

        let models = load_models_from_dir(Some(temp_dir.path().to_path_buf())).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        write_yaml(models_dir.join("users.yml"), &id_table("users"));

        let models = load_models_from_dir(Some(temp_dir.path().to_path_buf())).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_recursive() {
        let temp_dir = tempdir().unwrap();
        let subdir = temp_dir.path().join("models").join("subdir");
        fs::create_dir_all(&subdir).unwrap();
        write_json(subdir.join("subtable.json"), &id_table("subtable"));

        let models = load_models_from_dir(Some(temp_dir.path().to_path_buf())).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_json() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        fs::write(models_dir.join("invalid.json"), r#"{"invalid": json}"#).unwrap();

        let err_msg = load_models_from_dir(Some(temp_dir.path().to_path_buf()))
            .unwrap_err()
            .to_string();
        assert!(
            err_msg.contains("Failed to parse JSON model"),
            "unexpected error: {err_msg}"
        );
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        fs::write(models_dir.join("invalid.yaml"), r#"invalid: [yaml"#).unwrap();

        let err_msg = load_models_from_dir(Some(temp_dir.path().to_path_buf()))
            .unwrap_err()
            .to_string();
        assert!(
            err_msg.contains("Failed to parse YAML model"),
            "unexpected error: {err_msg}"
        );
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_normalization_error() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        write_json(models_dir.join("orders.json"), &orders_with_invalid_fk());

        let err_msg = load_models_from_dir(Some(temp_dir.path().to_path_buf()))
            .unwrap_err()
            .to_string();
        assert!(
            err_msg.contains("Failed to normalize table 'orders'"),
            "unexpected error: {err_msg}"
        );
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_cargo_manifest_dir() {
        // With no explicit root, the project root comes from CARGO_MANIFEST_DIR.
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        write_json(models_dir.join("users.json"), &id_table("users"));
        let _env = ManifestDirGuard::set(Some(temp_dir.path()));

        let models = load_models_from_dir(None).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_at_compile_time() {
        // load_models_at_compile_time resolves the root from CARGO_MANIFEST_DIR.
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        write_yaml(models_dir.join("users.yaml"), &id_table("users"));
        let _env = ManifestDirGuard::set(Some(temp_dir.path()));

        let models = load_models_at_compile_time().unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }
}