1use crate::ast::{ArithExpr, CompOp, CondExpr, FuncBody, FuncDef, Program};
4use std::collections::{HashMap, HashSet};
5use xlog_core::ScalarType;
6
/// Errors raised while registering or validating user-defined functions.
///
/// Each variant renders (via [`std::fmt::Display`]) as a rustc-style diagnostic
/// carrying a stable error code in the `E0501`–`E0505` range. `PartialEq`/`Eq`
/// are derived so callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum FunctionError {
    /// A function with this name has already been registered.
    DuplicateDefinition {
        /// Name of the duplicated function.
        name: String,
    },
    /// A recursive function has no recognizable base case.
    RecursionWithoutBaseCase {
        /// Name of the offending function.
        name: String,
    },
    /// A call targets a name that is neither a registered function nor a builtin.
    UndefinedFunction {
        /// Name of the missing callee.
        name: String,
    },
    /// The maximum recursion depth was exceeded while evaluating a function.
    MaxRecursionDepth {
        /// Name of the function being evaluated.
        name: String,
        /// The depth limit that was exceeded.
        depth: u32,
    },
    /// A function has the same name as a predicate.
    NameConflict {
        /// The conflicting name.
        name: String,
    },
}

impl std::fmt::Display for FunctionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FunctionError::DuplicateDefinition { name } => {
                write!(f, "error[E0501]: duplicate function definition `{name}`")
            }
            FunctionError::RecursionWithoutBaseCase { name } => {
                writeln!(
                    f,
                    "error[E0502]: recursive function `{name}` without base case"
                )?;
                // Second line is a rustc-style help note suggesting the accepted shape.
                write!(
                    f,
                    " = help: use conditional form: `if <condition> then <base> else <recursive>`"
                )
            }
            FunctionError::UndefinedFunction { name } => {
                write!(f, "error[E0503]: undefined function `{name}`")
            }
            FunctionError::MaxRecursionDepth { name, depth } => {
                write!(
                    f,
                    "error[E0504]: maximum recursion depth ({depth}) exceeded in function `{name}`"
                )
            }
            FunctionError::NameConflict { name } => {
                write!(f, "error[E0505]: `{name}` is already defined as a predicate")
            }
        }
    }
}

impl std::error::Error for FunctionError {}
79
80#[derive(Debug, Clone)]
82#[allow(dead_code)] pub(crate) enum TypeError {
84 Mismatch {
86 expected: ScalarType,
87 found: ScalarType,
88 location: String,
89 },
90 CannotInfer { name: String },
92}
93
94impl std::fmt::Display for TypeError {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 match self {
97 TypeError::Mismatch {
98 expected,
99 found,
100 location,
101 } => {
102 writeln!(f, "error[E0506]: type mismatch in {}", location)?;
103 write!(f, " expected {:?}, found {:?}", expected, found)
104 }
105 TypeError::CannotInfer { name } => {
106 write!(f, "error[E0507]: cannot infer type for `{}`", name)
107 }
108 }
109 }
110}
111
112impl std::error::Error for TypeError {}
113
114impl From<FunctionError> for xlog_core::XlogError {
115 fn from(e: FunctionError) -> Self {
116 xlog_core::XlogError::Compilation(e.to_string())
117 }
118}
119
120impl From<TypeError> for xlog_core::XlogError {
121 fn from(e: TypeError) -> Self {
122 xlog_core::XlogError::Type(e.to_string())
123 }
124}
125
/// Non-fatal diagnostic for a recursive function whose recursive call appears to
/// move away from its base case.
#[derive(Debug, Clone)]
pub struct RecursionWarning {
    /// Name of the function the warning refers to.
    pub func_name: String,
    /// Human-readable explanation of the suspected divergence.
    pub message: String,
}

impl std::fmt::Display for RecursionWarning {
    /// Renders a three-line warning: header (`W0502`), the message, and a note.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { func_name, message } = self;
        writeln!(
            f,
            "warning[W0502]: potentially infinite recursion in `{func_name}`"
        )?;
        writeln!(f, " {message}")?;
        f.write_str(" = note: base case may be unreachable with given recursive call")
    }
}
149
150#[derive(Debug, Default)]
152pub struct FunctionRegistry {
153 functions: HashMap<String, FuncDef>,
154 call_graph: HashMap<String, HashSet<String>>,
155}
156
157impl FunctionRegistry {
158 pub fn new() -> Self {
160 Self::default()
161 }
162
163 pub fn register(&mut self, func: FuncDef) -> Result<(), FunctionError> {
165 if self.functions.contains_key(&func.name) {
166 return Err(FunctionError::DuplicateDefinition {
167 name: func.name.clone(),
168 });
169 }
170
171 let calls = Self::extract_calls(&func.body);
173 self.call_graph.insert(func.name.clone(), calls);
174 self.functions.insert(func.name.clone(), func);
175
176 Ok(())
177 }
178
179 pub fn get(&self, name: &str) -> Option<&FuncDef> {
181 self.functions.get(name)
182 }
183
184 pub fn contains(&self, name: &str) -> bool {
186 self.functions.contains_key(name)
187 }
188
189 fn extract_calls(body: &FuncBody) -> HashSet<String> {
191 let mut calls = HashSet::new();
192 Self::extract_calls_from_body(body, &mut calls);
193 calls
194 }
195
196 fn extract_calls_from_body(body: &FuncBody, calls: &mut HashSet<String>) {
197 match body {
198 FuncBody::Arithmetic(expr) => Self::extract_calls_from_expr(expr, calls),
199 FuncBody::Conditional(cond) => {
200 Self::extract_calls_from_expr(&cond.cond_left, calls);
201 Self::extract_calls_from_expr(&cond.cond_right, calls);
202 Self::extract_calls_from_body(&cond.then_branch, calls);
203 Self::extract_calls_from_body(&cond.else_branch, calls);
204 }
205 FuncBody::Predicate { .. } => {
206 }
208 }
209 }
210
211 fn extract_calls_from_expr(expr: &ArithExpr, calls: &mut HashSet<String>) {
212 match expr {
213 ArithExpr::FuncCall { name, args } => {
214 calls.insert(name.clone());
215 for arg in args {
216 Self::extract_calls_from_expr(arg, calls);
217 }
218 }
219 ArithExpr::Add(l, r)
220 | ArithExpr::Sub(l, r)
221 | ArithExpr::Mul(l, r)
222 | ArithExpr::Div(l, r)
223 | ArithExpr::Mod(l, r)
224 | ArithExpr::Min(l, r)
225 | ArithExpr::Max(l, r)
226 | ArithExpr::Pow(l, r) => {
227 Self::extract_calls_from_expr(l, calls);
228 Self::extract_calls_from_expr(r, calls);
229 }
230 ArithExpr::Abs(e) | ArithExpr::Cast(e, _) => {
231 Self::extract_calls_from_expr(e, calls);
232 }
233 ArithExpr::Variable(_) | ArithExpr::Integer(_) | ArithExpr::Float(_) => {}
234 ArithExpr::Conditional {
235 cond_left,
236 cond_right,
237 then_expr,
238 else_expr,
239 ..
240 } => {
241 Self::extract_calls_from_expr(cond_left, calls);
242 Self::extract_calls_from_expr(cond_right, calls);
243 Self::extract_calls_from_expr(then_expr, calls);
244 Self::extract_calls_from_expr(else_expr, calls);
245 }
246 }
247 }
248
249 pub fn is_recursive(&self, name: &str) -> bool {
251 self.reaches(name, name, &mut HashSet::new())
252 }
253
254 fn reaches(&self, from: &str, target: &str, visited: &mut HashSet<String>) -> bool {
255 if visited.contains(from) {
256 return false;
257 }
258 visited.insert(from.to_string());
259
260 if let Some(calls) = self.call_graph.get(from) {
261 if calls.contains(target) {
262 return true;
263 }
264 for call in calls {
265 if self.reaches(call, target, visited) {
266 return true;
267 }
268 }
269 }
270 false
271 }
272
273 pub fn validate(&self) -> Result<(), FunctionError> {
275 for (name, func) in &self.functions {
276 if let Some(calls) = self.call_graph.get(name) {
278 for call in calls {
279 if !self.functions.contains_key(call) && !is_builtin(call) {
280 return Err(FunctionError::UndefinedFunction { name: call.clone() });
281 }
282 }
283 }
284
285 if self.is_recursive(name) && !Self::has_base_case(&func.body) {
287 return Err(FunctionError::RecursionWithoutBaseCase { name: name.clone() });
288 }
289 }
290 Ok(())
291 }
292
293 fn has_base_case(body: &FuncBody) -> bool {
294 matches!(body, FuncBody::Conditional(_))
295 }
296
297 pub fn from_program(program: &Program) -> Result<Self, FunctionError> {
299 let mut registry = Self::new();
300
301 let pred_names: HashSet<_> = program.predicates.iter().map(|p| p.name.clone()).collect();
303
304 for func in &program.functions {
305 if pred_names.contains(&func.name) {
306 return Err(FunctionError::NameConflict {
307 name: func.name.clone(),
308 });
309 }
310 registry.register(func.clone())?;
311 }
312
313 registry.validate()?;
314 Ok(registry)
315 }
316
317 pub fn functions(&self) -> impl Iterator<Item = &FuncDef> {
319 self.functions.values()
320 }
321
322 pub fn analyze_recursion(&self, func: &FuncDef) -> Option<RecursionWarning> {
324 if !self.is_recursive(&func.name) {
325 return None;
326 }
327
328 match &func.body {
329 FuncBody::Conditional(cond) => self.check_convergence(func, cond),
330 _ => None,
331 }
332 }
333
334 fn check_convergence(&self, func: &FuncDef, cond: &CondExpr) -> Option<RecursionWarning> {
335 let recursive_calls = Self::find_recursive_calls_in_body(&func.name, &cond.else_branch);
337
338 for call_args in recursive_calls {
339 if call_args.is_empty() {
340 continue;
341 }
342
343 if let (ArithExpr::Variable(var), CompOp::Le | CompOp::Lt) =
346 (&cond.cond_left, cond.cond_op)
347 {
348 if let ArithExpr::Add(left, right) = &call_args[0] {
349 if let (ArithExpr::Variable(arg_var), ArithExpr::Integer(n)) =
350 (left.as_ref(), right.as_ref())
351 {
352 if arg_var == var && *n > 0 {
353 return Some(RecursionWarning {
354 func_name: func.name.clone(),
355 message: format!(
356 "recursive call increases `{}`, but base case requires it to decrease",
357 var
358 ),
359 });
360 }
361 }
362 }
363 }
364 }
365
366 None
367 }
368
369 fn find_recursive_calls_in_body(name: &str, body: &FuncBody) -> Vec<Vec<ArithExpr>> {
370 let mut calls = Vec::new();
371 match body {
372 FuncBody::Arithmetic(expr) => {
373 Self::find_recursive_calls_in_expr(name, expr, &mut calls);
374 }
375 FuncBody::Conditional(cond) => {
376 Self::find_recursive_calls_in_expr(name, &cond.cond_left, &mut calls);
377 Self::find_recursive_calls_in_expr(name, &cond.cond_right, &mut calls);
378 calls.extend(Self::find_recursive_calls_in_body(name, &cond.then_branch));
379 calls.extend(Self::find_recursive_calls_in_body(name, &cond.else_branch));
380 }
381 FuncBody::Predicate { .. } => {}
382 }
383 calls
384 }
385
386 fn find_recursive_calls_in_expr(name: &str, expr: &ArithExpr, calls: &mut Vec<Vec<ArithExpr>>) {
387 match expr {
388 ArithExpr::FuncCall {
389 name: fn_name,
390 args,
391 } if fn_name == name => {
392 calls.push(args.clone());
393 }
394 ArithExpr::Add(l, r)
395 | ArithExpr::Sub(l, r)
396 | ArithExpr::Mul(l, r)
397 | ArithExpr::Div(l, r)
398 | ArithExpr::Mod(l, r)
399 | ArithExpr::Min(l, r)
400 | ArithExpr::Max(l, r)
401 | ArithExpr::Pow(l, r) => {
402 Self::find_recursive_calls_in_expr(name, l, calls);
403 Self::find_recursive_calls_in_expr(name, r, calls);
404 }
405 ArithExpr::Abs(e) | ArithExpr::Cast(e, _) => {
406 Self::find_recursive_calls_in_expr(name, e, calls);
407 }
408 ArithExpr::FuncCall { args, .. } => {
409 for arg in args {
410 Self::find_recursive_calls_in_expr(name, arg, calls);
411 }
412 }
413 ArithExpr::Conditional {
414 cond_left,
415 cond_right,
416 then_expr,
417 else_expr,
418 ..
419 } => {
420 Self::find_recursive_calls_in_expr(name, cond_left, calls);
421 Self::find_recursive_calls_in_expr(name, cond_right, calls);
422 Self::find_recursive_calls_in_expr(name, then_expr, calls);
423 Self::find_recursive_calls_in_expr(name, else_expr, calls);
424 }
425 _ => {}
426 }
427 }
428
429 pub fn validate_with_warnings(&self) -> (Result<(), FunctionError>, Vec<RecursionWarning>) {
431 let mut warnings = Vec::new();
432
433 for func in self.functions.values() {
434 if let Some(warning) = self.analyze_recursion(func) {
435 warnings.push(warning);
436 }
437 }
438
439 (self.validate(), warnings)
440 }
441}
442
/// Returns `true` if `name` is a built-in arithmetic function, which may be called
/// without a user definition.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: [&str; 5] = ["abs", "min", "max", "pow", "cast"];
    BUILTINS.contains(&name)
}
447
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::FuncParam;
    use xlog_core::XlogError;

    // ---- AST construction helpers ---------------------------------------------

    /// `ArithExpr::Variable(name)`.
    fn var(name: &str) -> ArithExpr {
        ArithExpr::Variable(name.to_string())
    }

    /// A call expression `name(args...)`.
    fn call(name: &str, args: Vec<ArithExpr>) -> ArithExpr {
        ArithExpr::FuncCall {
            name: name.to_string(),
            args,
        }
    }

    /// `l + r`.
    fn add(l: ArithExpr, r: ArithExpr) -> ArithExpr {
        ArithExpr::Add(Box::new(l), Box::new(r))
    }

    /// `l - r`.
    fn sub(l: ArithExpr, r: ArithExpr) -> ArithExpr {
        ArithExpr::Sub(Box::new(l), Box::new(r))
    }

    /// `l * r`.
    fn mul(l: ArithExpr, r: ArithExpr) -> ArithExpr {
        ArithExpr::Mul(Box::new(l), Box::new(r))
    }

    /// An untyped parameter.
    fn param(name: &str) -> FuncParam {
        FuncParam {
            name: name.to_string(),
            typ: None,
        }
    }

    /// A public function with no declared return type.
    fn make_func(name: &str, params: Vec<FuncParam>, body: FuncBody) -> FuncDef {
        FuncDef {
            name: name.to_string(),
            params,
            return_type: None,
            body,
            is_private: false,
        }
    }

    /// A function with a single parameter `X` and an arithmetic body.
    fn make_arith_func(name: &str, body: ArithExpr) -> FuncDef {
        make_func(name, vec![param("X")], FuncBody::Arithmetic(body))
    }

    /// `if cond_left <cond_op> cond_right then then_expr else else_expr`.
    fn guarded(
        cond_left: ArithExpr,
        cond_op: CompOp,
        cond_right: ArithExpr,
        then_expr: ArithExpr,
        else_expr: ArithExpr,
    ) -> FuncBody {
        FuncBody::Conditional(CondExpr {
            cond_left,
            cond_op,
            cond_right,
            then_branch: Box::new(FuncBody::Arithmetic(then_expr)),
            else_branch: Box::new(FuncBody::Arithmetic(else_expr)),
        })
    }

    // ---- error conversions and rendering ----------------------------------------

    #[test]
    fn test_function_error_into_xlog() {
        let xlog_err: XlogError = FunctionError::UndefinedFunction {
            name: "foo".to_string(),
        }
        .into();
        let msg = xlog_err.to_string();
        assert!(msg.contains("foo"), "Expected 'foo' in: {msg}");
    }

    #[test]
    fn test_type_error_into_xlog() {
        let xlog_err: XlogError = TypeError::CannotInfer {
            name: "X".to_string(),
        }
        .into();
        let msg = xlog_err.to_string();
        assert!(msg.contains("X"), "Expected 'X' in: {msg}");
    }

    #[test]
    fn test_type_error_display() {
        let mismatch = TypeError::Mismatch {
            expected: ScalarType::I64,
            found: ScalarType::F64,
            location: "function f".to_string(),
        }
        .to_string();
        assert!(mismatch.contains("E0506"));
        assert!(mismatch.contains("type mismatch"));

        let cannot_infer = TypeError::CannotInfer {
            name: "X".to_string(),
        }
        .to_string();
        assert!(cannot_infer.contains("E0507"));
        assert!(cannot_infer.contains("cannot infer"));
    }

    #[test]
    fn test_recursion_warning_display() {
        let rendered = RecursionWarning {
            func_name: "fib".to_string(),
            message: "recursive call increases `N`".to_string(),
        }
        .to_string();
        assert!(rendered.contains("W0502"));
        assert!(rendered.contains("infinite recursion"));
        assert!(rendered.contains("fib"));
    }

    // ---- registration and lookup ---------------------------------------------------

    #[test]
    fn test_register_function() {
        let mut reg = FunctionRegistry::new();
        assert!(reg.register(make_arith_func("square", var("X"))).is_ok());
    }

    #[test]
    fn test_duplicate_error() {
        let mut reg = FunctionRegistry::new();
        let func = make_arith_func("f", var("X"));
        reg.register(func.clone()).unwrap();
        assert!(matches!(
            reg.register(func),
            Err(FunctionError::DuplicateDefinition { .. })
        ));
    }

    #[test]
    fn test_get_function() {
        let mut reg = FunctionRegistry::new();
        reg.register(make_arith_func("square", var("X"))).unwrap();

        assert!(reg.get("square").is_some());
        assert!(reg.get("nonexistent").is_none());
    }

    #[test]
    fn test_contains_function() {
        let mut reg = FunctionRegistry::new();
        reg.register(make_arith_func("square", var("X"))).unwrap();

        assert!(reg.contains("square"));
        assert!(!reg.contains("nonexistent"));
    }

    #[test]
    fn test_functions_iterator() {
        let mut reg = FunctionRegistry::new();
        for name in ["f1", "f2"] {
            reg.register(make_arith_func(name, var("X"))).unwrap();
        }

        let names: HashSet<&str> = reg.functions().map(|f| f.name.as_str()).collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains("f1"));
        assert!(names.contains("f2"));
    }

    // ---- call graph and validation -------------------------------------------------

    #[test]
    fn test_recursive_detection() {
        let mut reg = FunctionRegistry::new();
        reg.register(make_func("f", vec![], FuncBody::Arithmetic(call("f", vec![]))))
            .unwrap();

        assert!(reg.is_recursive("f"));
    }

    #[test]
    fn test_indirect_recursion() {
        let mut reg = FunctionRegistry::new();
        for (name, callee) in [("f", "g"), ("g", "f")] {
            reg.register(make_func(name, vec![], FuncBody::Arithmetic(call(callee, vec![]))))
                .unwrap();
        }

        assert!(reg.is_recursive("f"));
        assert!(reg.is_recursive("g"));
    }

    #[test]
    fn test_undefined_function_error() {
        let mut reg = FunctionRegistry::new();
        reg.register(make_func(
            "f",
            vec![],
            FuncBody::Arithmetic(call("undefined_func", vec![])),
        ))
        .unwrap();

        assert!(matches!(
            reg.validate(),
            Err(FunctionError::UndefinedFunction { .. })
        ));
    }

    #[test]
    fn test_builtin_function_allowed() {
        let mut reg = FunctionRegistry::new();
        reg.register(make_func(
            "f",
            vec![param("X")],
            FuncBody::Arithmetic(call("abs", vec![var("X")])),
        ))
        .unwrap();

        assert!(reg.validate().is_ok());
    }

    // ---- recursion analysis --------------------------------------------------------

    #[test]
    fn test_analyze_non_recursive() {
        let mut reg = FunctionRegistry::new();
        let func = make_arith_func("square", var("X"));
        reg.register(func.clone()).unwrap();

        assert!(reg.analyze_recursion(&func).is_none());
    }

    #[test]
    fn test_analyze_recursive_with_proper_convergence() {
        let mut reg = FunctionRegistry::new();

        // fact(N) = if N <= 1 then 1 else N * fact(N - 1)
        let factorial = make_func(
            "fact",
            vec![param("N")],
            guarded(
                var("N"),
                CompOp::Le,
                ArithExpr::Integer(1),
                ArithExpr::Integer(1),
                mul(var("N"), call("fact", vec![sub(var("N"), ArithExpr::Integer(1))])),
            ),
        );
        reg.register(factorial.clone()).unwrap();

        assert!(reg.analyze_recursion(&factorial).is_none());
    }

    #[test]
    fn test_analyze_recursive_with_divergence() {
        let mut reg = FunctionRegistry::new();

        // badfunc(N) = if N <= 1 then 1 else badfunc(N + 1)
        let bad_func = make_func(
            "badfunc",
            vec![param("N")],
            guarded(
                var("N"),
                CompOp::Le,
                ArithExpr::Integer(1),
                ArithExpr::Integer(1),
                call("badfunc", vec![add(var("N"), ArithExpr::Integer(1))]),
            ),
        );
        reg.register(bad_func.clone()).unwrap();

        let warning = reg
            .analyze_recursion(&bad_func)
            .expect("expected a divergence warning");
        assert!(warning.message.contains("increases"));
    }

    #[test]
    fn test_validate_with_warnings() {
        let mut reg = FunctionRegistry::new();

        // diverging(X) = if X < 0 then 0 else diverging(X + 1)
        let bad_func = make_func(
            "diverging",
            vec![param("X")],
            guarded(
                var("X"),
                CompOp::Lt,
                ArithExpr::Integer(0),
                ArithExpr::Integer(0),
                call("diverging", vec![add(var("X"), ArithExpr::Integer(1))]),
            ),
        );
        reg.register(bad_func).unwrap();

        let (result, warnings) = reg.validate_with_warnings();
        assert!(result.is_ok());
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].func_name, "diverging");
    }
}