1use crate::ast::Program;
4use crate::module::{module_path_to_string, LoadedModule, ModuleError, ModulePath};
5use crate::parser::parse_program;
6use std::collections::{HashMap, HashSet};
7use std::fs;
8use std::path::{Path, PathBuf};
9
10pub struct ModuleResolver {
12 search_paths: Vec<PathBuf>,
14 loaded: HashMap<String, LoadedModule>,
16 loading: Vec<ModulePath>,
18}
19
20impl ModuleResolver {
21 pub fn new(search_paths: Vec<PathBuf>) -> Self {
23 Self {
24 search_paths,
25 loaded: HashMap::new(),
26 loading: Vec::new(),
27 }
28 }
29
30 pub fn find_module_file(&self, base_dir: &Path, module_path: &[String]) -> Option<PathBuf> {
32 let relative_path = format!("{}.xlog", module_path.join("/"));
33
34 let candidate = base_dir.join(&relative_path);
36 if candidate.exists() {
37 return Some(candidate);
38 }
39
40 for search_path in &self.search_paths {
42 let candidate = search_path.join(&relative_path);
43 if candidate.exists() {
44 return Some(candidate);
45 }
46 }
47
48 None
49 }
50
51 fn searched_paths(&self, base_dir: &Path, module_path: &[String]) -> Vec<PathBuf> {
53 let relative_path = format!("{}.xlog", module_path.join("/"));
54 let mut searched = vec![base_dir.join(&relative_path)];
55 for sp in &self.search_paths {
56 searched.push(sp.join(&relative_path));
57 }
58 searched
59 }
60
61 fn check_cycle(&self, module_path: &[String]) -> Option<Vec<ModulePath>> {
63 let path_str = module_path_to_string(module_path);
64 for (i, loading_path) in self.loading.iter().enumerate() {
65 if module_path_to_string(loading_path) == path_str {
66 let mut cycle: Vec<ModulePath> = self.loading[i..].to_vec();
68 cycle.push(module_path.to_vec());
69 return Some(cycle);
70 }
71 }
72 None
73 }
74
75 pub fn extract_exports(program: &Program) -> (HashSet<String>, HashSet<String>) {
78 let mut pred_exports = HashSet::new();
79 let mut func_exports = HashSet::new();
80
81 for pred in &program.predicates {
83 if !pred.is_private {
84 pred_exports.insert(pred.name.clone());
85 }
86 }
87
88 for rule in &program.rules {
90 let is_private = program
92 .predicates
93 .iter()
94 .any(|p| p.name == rule.head.predicate && p.is_private);
95 if !is_private {
96 pred_exports.insert(rule.head.predicate.clone());
97 }
98 }
99
100 for func in &program.functions {
102 if !func.is_private {
103 func_exports.insert(func.name.clone());
104 }
105 }
106
107 (pred_exports, func_exports)
108 }
109
110 pub fn load_module(
112 &mut self,
113 base_dir: &Path,
114 module_path: &[String],
115 ) -> Result<&LoadedModule, ModuleError> {
116 let path_key = module_path_to_string(module_path);
117
118 if self.loaded.contains_key(&path_key) {
120 return Ok(self.loaded.get(&path_key).unwrap());
121 }
122
123 if let Some(cycle) = self.check_cycle(module_path) {
125 return Err(ModuleError::CircularImport { cycle });
126 }
127
128 let source_file = self
130 .find_module_file(base_dir, module_path)
131 .ok_or_else(|| ModuleError::NotFound {
132 path: module_path.to_vec(),
133 searched: self.searched_paths(base_dir, module_path),
134 })?;
135
136 self.loading.push(module_path.to_vec());
138
139 let source = fs::read_to_string(&source_file).map_err(|e| ModuleError::ParseError {
141 path: source_file.clone(),
142 message: e.to_string(),
143 })?;
144
145 let program = parse_program(&source).map_err(|e| ModuleError::ParseError {
146 path: source_file.clone(),
147 message: e.to_string(),
148 })?;
149
150 let (exports, function_exports) = Self::extract_exports(&program);
152
153 let module_dir = source_file.parent().unwrap_or(base_dir);
155 for import in &program.imports {
156 self.load_module(module_dir, &import.module_path)?;
157 }
158
159 self.loading.pop();
161
162 let module = LoadedModule {
164 path: module_path.to_vec(),
165 source_file,
166 exports,
167 function_exports,
168 program,
169 };
170
171 self.loaded.insert(path_key.clone(), module);
172 Ok(self.loaded.get(&path_key).unwrap())
173 }
174
175 pub fn check_import(&self, module_path: &[String], predicate: &str) -> Result<(), ModuleError> {
177 let path_key = module_path_to_string(module_path);
178 let module = self
179 .loaded
180 .get(&path_key)
181 .ok_or_else(|| ModuleError::NotFound {
182 path: module_path.to_vec(),
183 searched: vec![],
184 })?;
185
186 if !module.exports.contains(predicate) {
187 return Err(ModuleError::PredicateNotFound {
188 name: predicate.to_string(),
189 module: module_path.to_vec(),
190 });
191 }
192
193 Ok(())
194 }
195
196 #[allow(clippy::type_complexity)]
199 pub fn validate_imports(
200 &self,
201 program: &Program,
202 ) -> Result<(HashMap<String, ModulePath>, HashMap<String, ModulePath>), ModuleError> {
203 let mut imported_predicates: HashMap<String, ModulePath> = HashMap::new();
204 let mut imported_functions: HashMap<String, ModulePath> = HashMap::new();
205
206 for use_decl in &program.imports {
207 let module = self
208 .loaded
209 .get(&module_path_to_string(&use_decl.module_path))
210 .expect("module should be loaded");
211
212 let all_exports: HashSet<String> = module
214 .exports
215 .iter()
216 .chain(module.function_exports.iter())
217 .cloned()
218 .collect();
219
220 let names_to_import: Vec<String> = match &use_decl.imports {
221 Some(specific) => specific.clone(),
222 None => all_exports.iter().cloned().collect(),
223 };
224
225 for name in names_to_import {
226 let is_predicate = module.exports.contains(&name);
228 let is_function = module.function_exports.contains(&name);
229
230 if !is_predicate && !is_function {
231 return Err(ModuleError::PredicateNotFound {
232 name: name.clone(),
233 module: use_decl.module_path.clone(),
234 });
235 }
236
237 if is_predicate {
239 if let Some(prev_module) = imported_predicates.get(&name) {
240 if prev_module != &use_decl.module_path {
241 return Err(ModuleError::ImportConflict {
242 name,
243 module1: prev_module.clone(),
244 module2: use_decl.module_path.clone(),
245 });
246 }
247 }
248 imported_predicates.insert(name.clone(), use_decl.module_path.clone());
249 }
250
251 if is_function {
253 if let Some(prev_module) = imported_functions.get(&name) {
254 if prev_module != &use_decl.module_path {
255 return Err(ModuleError::ImportConflict {
256 name,
257 module1: prev_module.clone(),
258 module2: use_decl.module_path.clone(),
259 });
260 }
261 }
262 imported_functions.insert(name.clone(), use_decl.module_path.clone());
263 }
264 }
265 }
266
267 Ok((imported_predicates, imported_functions))
268 }
269
270 pub fn get_module(&self, module_path: &[String]) -> Option<&LoadedModule> {
272 self.loaded.get(&module_path_to_string(module_path))
273 }
274
275 pub fn is_loaded(&self, module_path: &str) -> bool {
277 self.loaded.contains_key(module_path)
278 }
279
280 pub fn loaded_modules(&self) -> Vec<&str> {
282 self.loaded.keys().map(|s| s.as_str()).collect()
283 }
284
285 pub fn merge_imports(&self, mut program: Program) -> Result<Program, ModuleError> {
294 for use_decl in &program.imports.clone() {
295 let path_key = module_path_to_string(&use_decl.module_path);
296 let loaded_module =
297 self.loaded
298 .get(&path_key)
299 .ok_or_else(|| ModuleError::NotFound {
300 path: use_decl.module_path.clone(),
301 searched: vec![],
302 })?;
303
304 let imported_items = match &use_decl.imports {
306 Some(items) if !items.is_empty() => {
307 Some(items.iter().cloned().collect())
309 }
310 _ => {
311 None
313 }
314 };
315
316 program.merge_from(&loaded_module.program, imported_items.as_ref());
318 }
319
320 Ok(program)
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::TempDir;

    /// Writes `content` to `<dir>/<name>.xlog` and returns the file's path.
    fn create_test_module(dir: &Path, name: &str, content: &str) -> PathBuf {
        let path = dir.join(format!("{}.xlog", name));
        let mut file = fs::File::create(&path).unwrap();
        file.write_all(content.as_bytes()).unwrap();
        path
    }

    #[test]
    fn test_find_module_file() {
        let tmp = TempDir::new().unwrap();
        let expected = create_test_module(tmp.path(), "graph", "edge(1, 2).");

        let resolver = ModuleResolver::new(vec![]);
        let found = resolver.find_module_file(tmp.path(), &["graph".into()]);
        // Pin the exact location, not just that something was found.
        assert_eq!(found, Some(expected));
    }

    #[test]
    fn test_module_not_found() {
        let tmp = TempDir::new().unwrap();
        let mut resolver = ModuleResolver::new(vec![]);

        let result = resolver.load_module(tmp.path(), &["nonexistent".into()]);
        match result {
            Err(ModuleError::NotFound { path, searched }) => {
                assert_eq!(path, vec!["nonexistent".to_string()]);
                // With no search paths, only the base directory is consulted.
                assert_eq!(searched, vec![tmp.path().join("nonexistent.xlog")]);
            }
            _ => panic!("expected ModuleError::NotFound"),
        }
    }

    #[test]
    fn test_circular_import() {
        let tmp = TempDir::new().unwrap();
        create_test_module(tmp.path(), "a", "use b.");
        create_test_module(tmp.path(), "b", "use a.");

        let mut resolver = ModuleResolver::new(vec![]);
        let result = resolver.load_module(tmp.path(), &["a".into()]);
        match result {
            Err(ModuleError::CircularImport { cycle }) => {
                let expected: Vec<ModulePath> = vec![
                    vec!["a".to_string()],
                    vec!["b".to_string()],
                    vec!["a".to_string()],
                ];
                assert_eq!(cycle, expected);
            }
            _ => panic!("expected ModuleError::CircularImport"),
        }
    }

    #[test]
    fn test_load_simple_module() {
        let tmp = TempDir::new().unwrap();
        create_test_module(
            tmp.path(),
            "math",
            r#"
            pred add(u32, u32, u32).
            add(1, 2, 3).
            "#,
        );

        let mut resolver = ModuleResolver::new(vec![]);
        let result = resolver.load_module(tmp.path(), &["math".into()]);
        assert!(result.is_ok());
        let module = result.unwrap();
        assert!(module.exports.contains("add"));
    }

    #[test]
    fn test_module_is_cached() {
        let tmp = TempDir::new().unwrap();
        create_test_module(
            tmp.path(),
            "math",
            r#"
            pred add(u32, u32, u32).
            add(1, 2, 3).
            "#,
        );

        let mut resolver = ModuleResolver::new(vec![]);
        let path: Vec<String> = vec!["math".into()];
        assert!(resolver.load_module(tmp.path(), &path).is_ok());
        // A second load is served from the cache rather than creating a new entry.
        assert!(resolver.load_module(tmp.path(), &path).is_ok());
        assert_eq!(resolver.loaded_modules().len(), 1);
        assert!(resolver.is_loaded(&module_path_to_string(&path)));
        assert!(resolver.get_module(&path).is_some());
    }

    #[test]
    fn test_private_not_exported() {
        let tmp = TempDir::new().unwrap();
        create_test_module(
            tmp.path(),
            "graph",
            r#"
            pred edge(u32, u32).
            private pred helper(u32).
            edge(1, 2).
            helper(1).
            "#,
        );

        let mut resolver = ModuleResolver::new(vec![]);
        let result = resolver.load_module(tmp.path(), &["graph".into()]);
        assert!(result.is_ok());
        let module = result.unwrap();
        assert!(module.exports.contains("edge"));
        assert!(!module.exports.contains("helper"));
    }

    #[test]
    fn test_search_paths() {
        let tmp = TempDir::new().unwrap();
        let lib_dir = tmp.path().join("lib");
        fs::create_dir(&lib_dir).unwrap();
        create_test_module(&lib_dir, "stdlib", "helper(1).");

        let resolver = ModuleResolver::new(vec![lib_dir.clone()]);
        let found = resolver.find_module_file(tmp.path(), &["stdlib".into()]);
        assert!(found.is_some());
        assert!(found.unwrap().starts_with(&lib_dir));
    }

    #[test]
    fn test_function_exports() {
        let tmp = TempDir::new().unwrap();
        create_test_module(
            tmp.path(),
            "mathfuncs",
            r#"
            func square(X) = X * X.
            func cube(X) = X * X * X.
            private func helper(X) = X.
            "#,
        );

        let mut resolver = ModuleResolver::new(vec![]);
        let result = resolver.load_module(tmp.path(), &["mathfuncs".into()]);
        assert!(result.is_ok());
        let module = result.unwrap();

        assert!(module.function_exports.contains("square"));
        assert!(module.function_exports.contains("cube"));
        assert!(!module.function_exports.contains("helper"));
    }

    #[test]
    fn test_mixed_exports() {
        let tmp = TempDir::new().unwrap();
        create_test_module(
            tmp.path(),
            "mixed",
            r#"
            pred value(i64).
            value(42).
            func double(X) = X * 2.
            "#,
        );

        let mut resolver = ModuleResolver::new(vec![]);
        let result = resolver.load_module(tmp.path(), &["mixed".into()]);
        assert!(result.is_ok());
        let module = result.unwrap();

        assert!(module.exports.contains("value"));
        assert!(module.function_exports.contains("double"));
    }
}