vespertide_loader/
models.rs

1use std::fs;
2use std::path::Path;
3
4use anyhow::{Context, Result};
5use vespertide_config::VespertideConfig;
6use vespertide_core::TableDef;
7use vespertide_planner::validate_schema;
8
9/// Load all model definitions from the models directory (recursively).
10pub fn load_models(config: &VespertideConfig) -> Result<Vec<TableDef>> {
11    let models_dir = config.models_dir();
12    if !models_dir.exists() {
13        return Ok(Vec::new());
14    }
15
16    let mut tables = Vec::new();
17    load_models_recursive(models_dir, &mut tables)?;
18
19    // Normalize tables to convert inline constraints (primary_key, foreign_key, etc.) to table-level constraints
20    // This must happen before validation so that foreign key references can be checked
21    let normalized_tables: Vec<TableDef> = tables
22        .into_iter()
23        .map(|t| {
24            t.normalize()
25                .map_err(|e| anyhow::anyhow!("Failed to normalize table '{}': {}", t.name, e))
26        })
27        .collect::<Result<Vec<_>, _>>()?;
28
29    // Validate schema integrity before returning
30    if !normalized_tables.is_empty() {
31        validate_schema(&normalized_tables)
32            .map_err(|e| anyhow::anyhow!("schema validation failed: {}", e))?;
33    }
34
35    Ok(normalized_tables)
36}
37
38/// Recursively walk directory and load model files.
39fn load_models_recursive(dir: &Path, tables: &mut Vec<TableDef>) -> Result<()> {
40    let entries =
41        fs::read_dir(dir).with_context(|| format!("read models directory: {}", dir.display()))?;
42
43    for entry in entries {
44        let entry = entry.context("read directory entry")?;
45        let path = entry.path();
46
47        if path.is_dir() {
48            // Recursively process subdirectories
49            load_models_recursive(&path, tables)?;
50            continue;
51        }
52
53        if path.is_file() {
54            let ext = path.extension().and_then(|s| s.to_str());
55            if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
56                let content = fs::read_to_string(&path)
57                    .with_context(|| format!("read model file: {}", path.display()))?;
58
59                let table: TableDef = if ext == Some("json") {
60                    serde_json::from_str(&content)
61                        .with_context(|| format!("parse JSON model: {}", path.display()))?
62                } else {
63                    serde_yaml::from_str(&content)
64                        .with_context(|| format!("parse YAML model: {}", path.display()))?
65                };
66
67                tables.push(table);
68            }
69        }
70    }
71
72    Ok(())
73}
74
75/// Load models from a specific directory (for compile-time use in macros).
76pub fn load_models_from_dir(
77    project_root: Option<std::path::PathBuf>,
78) -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
79    use std::env;
80
81    // Locate project root from CARGO_MANIFEST_DIR or use provided path
82    let project_root = if let Some(root) = project_root {
83        root
84    } else {
85        std::path::PathBuf::from(
86            env::var("CARGO_MANIFEST_DIR")
87                .context("CARGO_MANIFEST_DIR environment variable not set")?,
88        )
89    };
90
91    // Read vespertide.json or use defaults
92    let config = crate::config::load_config_or_default(Some(project_root.clone()))
93        .map_err(|e| format!("Failed to load config: {}", e))?;
94
95    // Read models directory
96    let models_dir = project_root.join(config.models_dir());
97    if !models_dir.exists() {
98        return Ok(Vec::new());
99    }
100
101    let mut tables = Vec::new();
102    load_models_recursive_internal(&models_dir, &mut tables)
103        .map_err(|e| format!("Failed to load models: {}", e))?;
104
105    // Normalize tables
106    let normalized_tables: Vec<TableDef> = tables
107        .into_iter()
108        .map(|t| {
109            t.normalize()
110                .map_err(|e| format!("Failed to normalize table '{}': {}", t.name, e))
111        })
112        .collect::<Result<Vec<_>, _>>()
113        .map_err(|e| e.to_string())?;
114
115    Ok(normalized_tables)
116}
117
118/// Internal recursive function for loading models (used by both runtime and compile-time).
119fn load_models_recursive_internal(
120    dir: &Path,
121    tables: &mut Vec<TableDef>,
122) -> Result<(), Box<dyn std::error::Error>> {
123    use std::fs;
124
125    let entries = fs::read_dir(dir)
126        .map_err(|e| format!("Failed to read models directory {}: {}", dir.display(), e))?;
127
128    for entry in entries {
129        let entry = entry.map_err(|e| format!("Failed to read directory entry: {}", e))?;
130        let path = entry.path();
131
132        if path.is_dir() {
133            // Recursively process subdirectories
134            load_models_recursive_internal(&path, tables)?;
135            continue;
136        }
137
138        if path.is_file() {
139            let ext = path.extension().and_then(|s| s.to_str());
140            if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
141                let content = fs::read_to_string(&path)
142                    .map_err(|e| format!("Failed to read model file {}: {}", path.display(), e))?;
143
144                let table: TableDef = if ext == Some("json") {
145                    serde_json::from_str(&content).map_err(|e| {
146                        format!("Failed to parse JSON model {}: {}", path.display(), e)
147                    })?
148                } else {
149                    serde_yaml::from_str(&content).map_err(|e| {
150                        format!("Failed to parse YAML model {}: {}", path.display(), e)
151                    })?
152                };
153
154                tables.push(table);
155            }
156        }
157    }
158
159    Ok(())
160}
161
162/// Load models at compile time (for macro use).
163pub fn load_models_at_compile_time() -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
164    load_models_from_dir(None)
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::ffi::{OsStr, OsString};
    use std::fs;
    use std::path::Path;
    use tempfile::tempdir;
    use vespertide_core::{
        ColumnDef, ColumnType, SimpleColumnType, TableConstraint,
        schema::foreign_key::ForeignKeySyntax,
    };

    /// Switches the process working directory for the guard's lifetime and
    /// restores the previous one on drop.
    struct CwdGuard {
        original: std::path::PathBuf,
    }

    impl CwdGuard {
        fn new(dir: &Path) -> Self {
            let original = std::env::current_dir().unwrap();
            std::env::set_current_dir(dir).unwrap();
            Self { original }
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            let _ = std::env::set_current_dir(&self.original);
        }
    }

    /// Overrides or removes an environment variable for the guard's lifetime
    /// and restores its previous state on drop, even if the test panics.
    struct EnvVarGuard {
        key: &'static str,
        original: Option<OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: every environment-touching test in this module is
            // `#[serial]`, so none of them race on the environment. This does
            // not guard against non-serial tests elsewhere in the binary.
            unsafe { std::env::set_var(key, value) };
            Self { key, original }
        }

        fn remove(key: &'static str) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: same reasoning as in `EnvVarGuard::set`.
            unsafe { std::env::remove_var(key) };
            Self { key, original }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `EnvVarGuard::set`.
            unsafe {
                match self.original.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    /// Writes a default `vespertide.json` into the current working directory.
    fn write_config() {
        let cfg = VespertideConfig::default();
        let text = serde_json::to_string_pretty(&cfg).unwrap();
        fs::write("vespertide.json", text).unwrap();
    }

    /// Non-nullable integer column with every optional attribute unset except
    /// `foreign_key`.
    fn int_column(name: &str, foreign_key: Option<ForeignKeySyntax>) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key,
        }
    }

    /// Single-column table without indexes.
    fn single_column_table(
        name: &str,
        column: ColumnDef,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.into(),
            columns: vec![column],
            constraints,
            indexes: vec![],
        }
    }

    /// Table-level primary key over the `id` column.
    fn id_primary_key() -> Vec<TableConstraint> {
        vec![TableConstraint::PrimaryKey {
            auto_increment: false,
            columns: vec!["id".into()],
        }]
    }

    fn write_json(path: impl AsRef<Path>, table: &TableDef) {
        fs::write(path, serde_json::to_string_pretty(table).unwrap()).unwrap();
    }

    fn write_yaml(path: impl AsRef<Path>, table: &TableDef) {
        fs::write(path, serde_yaml::to_string(table).unwrap()).unwrap();
    }

    #[test]
    #[serial]
    fn load_models_returns_empty_when_no_models_dir() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        // Don't create models directory
        let models = load_models(&VespertideConfig::default()).unwrap();
        assert!(models.is_empty());
    }

    #[test]
    #[serial]
    fn load_models_reads_yaml_and_validates() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();
        let table = single_column_table("users", int_column("id", None), id_primary_key());
        write_yaml("models/users.yaml", &table);

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn load_models_recursive_processes_subdirectories() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models/subdir").unwrap();
        let table = single_column_table("subtable", int_column("id", None), id_primary_key());
        write_json("models/subdir/subtable.json", &table);

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn load_models_fails_on_invalid_fk_format() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();
        // Invalid FK format: should be "table.column" but the dot is missing.
        let fk = ForeignKeySyntax::String("invalid_format".into());
        let table = single_column_table("orders", int_column("user_id", Some(fk)), vec![]);
        write_json("models/orders.json", &table);

        let result = load_models(&VespertideConfig::default());
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to normalize table 'orders'"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_root() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        let table = single_column_table("users", int_column("id", None), vec![]);
        write_json(models_dir.join("users.json"), &table);

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_without_root() {
        // Removed for the duration of this test only; the guard restores the
        // original value afterwards so later tests are unaffected.
        let _env = EnvVarGuard::remove("CARGO_MANIFEST_DIR");

        let result = load_models_from_dir(None);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("CARGO_MANIFEST_DIR environment variable not set"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_no_models_dir() {
        let temp_dir = tempdir().unwrap();
        // Don't create models directory

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        let table = single_column_table("users", int_column("id", None), vec![]);
        write_yaml(models_dir.join("users.yaml"), &table);

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        let table = single_column_table("users", int_column("id", None), vec![]);
        write_yaml(models_dir.join("users.yml"), &table);

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_recursive() {
        let temp_dir = tempdir().unwrap();
        let subdir = temp_dir.path().join("models").join("subdir");
        fs::create_dir_all(&subdir).unwrap();
        let table = single_column_table("subtable", int_column("id", None), vec![]);
        write_json(subdir.join("subtable.json"), &table);

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_json() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        fs::write(models_dir.join("invalid.json"), r#"{"invalid": json}"#).unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to parse JSON model"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        fs::write(models_dir.join("invalid.yaml"), r#"invalid: [yaml"#).unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to parse YAML model"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_normalization_error() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        // Model with an invalid FK format (missing "table.column" dot).
        let fk = ForeignKeySyntax::String("invalid_format".into());
        let table = single_column_table("orders", int_column("user_id", Some(fk)), vec![]);
        write_json(models_dir.join("orders.json"), &table);

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to normalize table 'orders'"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_cargo_manifest_dir() {
        // Exercise the `None` branch that resolves the root from
        // CARGO_MANIFEST_DIR, pointing it at a controlled temp project so the
        // outcome is deterministic.
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        let table = single_column_table("users", int_column("id", None), vec![]);
        write_json(models_dir.join("users.json"), &table);
        let _env = EnvVarGuard::set("CARGO_MANIFEST_DIR", temp_dir.path());

        let models = load_models_from_dir(None).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_at_compile_time() {
        // `load_models_at_compile_time` resolves the root from
        // CARGO_MANIFEST_DIR; a project without a models directory yields no
        // tables.
        let temp_dir = tempdir().unwrap();
        let _env = EnvVarGuard::set("CARGO_MANIFEST_DIR", temp_dir.path());

        let models = load_models_at_compile_time().unwrap();
        assert!(models.is_empty());
    }
}