Skip to main content

tensorlogic_compiler/passes/
scope_analysis.rs

1//! Variable scope analysis pass.
2
3use std::collections::{HashMap, HashSet};
4
5use anyhow::{bail, Result};
6use tensorlogic_ir::{IrError, TLExpr, Term, TypeAnnotation};
7
/// Scope information for a variable, as recorded by [`analyze_scopes`].
///
/// Entries are keyed by variable *name* in [`ScopeAnalysisResult::variables`].
/// Distinct variables that share a name (e.g. bound by two separate quantifiers)
/// therefore share a single entry.
#[derive(Debug, Clone)]
pub struct VariableScope {
    /// The variable's name (same as its key in [`ScopeAnalysisResult::variables`]).
    pub name: String,
    /// How the name was classified when the entry was created: bound by a
    /// quantifier-like construct, or free. Existing entries are not overwritten.
    pub bound_in: ScopeType,
    /// The first type annotation seen for this name, if any. Later annotations that
    /// differ are reported as [`TypeConflict`]s instead of replacing this one.
    pub type_annotation: Option<TypeAnnotation>,
}
15
/// How a variable recorded by scope analysis was bound when first encountered.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScopeType {
    /// Bound by a quantifier-like construct. `quantifier_type` names the construct
    /// (e.g. `"exists"`, `"forall"`, `"aggregate"`, `"majority"`).
    Quantifier { quantifier_type: String },
    /// First encountered at an occurrence not enclosed by a binder for it.
    Free,
}
21
/// Result of scope analysis, as produced by [`analyze_scopes`].
#[derive(Debug, Clone, Default)]
pub struct ScopeAnalysisResult {
    /// One entry per distinct variable name, describing how it was first encountered.
    pub variables: HashMap<String, VariableScope>,
    /// Names of variables found outside any enclosing binder, in discovery order.
    /// Each name appears at most once.
    pub unbound_variables: Vec<String>,
    /// Disagreeing type annotations for the same variable name, in discovery order.
    pub type_conflicts: Vec<TypeConflict>,
}
29
/// Two different type annotations found for the same variable name.
#[derive(Debug, Clone)]
pub struct TypeConflict {
    /// Name of the variable whose annotations disagree.
    pub variable: String,
    /// `type_name` of the annotation recorded first for this variable.
    pub type1: String,
    /// `type_name` of the differing annotation encountered later.
    pub type2: String,
}
36
37/// Analyze variable scopes in an expression
38pub fn analyze_scopes(expr: &TLExpr) -> Result<ScopeAnalysisResult> {
39    let mut result = ScopeAnalysisResult::default();
40    let mut bound_vars = HashSet::new();
41
42    analyze_expr(expr, &mut bound_vars, &mut result)?;
43
44    Ok(result)
45}
46
47fn analyze_expr(
48    expr: &TLExpr,
49    bound_vars: &mut HashSet<String>,
50    result: &mut ScopeAnalysisResult,
51) -> Result<()> {
52    match expr {
53        #[allow(unreachable_patterns)]
54        TLExpr::Pred { name: _, args } => {
55            // Check all variables in predicate arguments
56            for term in args {
57                check_term(term, bound_vars, result);
58            }
59        }
60        TLExpr::And(left, right)
61        | TLExpr::Or(left, right)
62        | TLExpr::Imply(left, right)
63        | TLExpr::Add(left, right)
64        | TLExpr::Sub(left, right)
65        | TLExpr::Mul(left, right)
66        | TLExpr::Div(left, right)
67        | TLExpr::Pow(left, right)
68        | TLExpr::Mod(left, right)
69        | TLExpr::Min(left, right)
70        | TLExpr::Max(left, right)
71        | TLExpr::Eq(left, right)
72        | TLExpr::Lt(left, right)
73        | TLExpr::Gt(left, right)
74        | TLExpr::Lte(left, right)
75        | TLExpr::Gte(left, right)
76        | TLExpr::TNorm { left, right, .. }
77        | TLExpr::TCoNorm { left, right, .. }
78        | TLExpr::FuzzyImplication {
79            premise: left,
80            conclusion: right,
81            ..
82        } => {
83            analyze_expr(left, bound_vars, result)?;
84            analyze_expr(right, bound_vars, result)?;
85        }
86        TLExpr::Not(inner)
87        | TLExpr::Score(inner)
88        | TLExpr::Abs(inner)
89        | TLExpr::Floor(inner)
90        | TLExpr::Ceil(inner)
91        | TLExpr::Round(inner)
92        | TLExpr::Sqrt(inner)
93        | TLExpr::Exp(inner)
94        | TLExpr::Log(inner)
95        | TLExpr::Sin(inner)
96        | TLExpr::Cos(inner)
97        | TLExpr::Tan(inner)
98        | TLExpr::FuzzyNot { expr: inner, .. }
99        | TLExpr::WeightedRule { rule: inner, .. } => {
100            analyze_expr(inner, bound_vars, result)?;
101        }
102        TLExpr::IfThenElse {
103            condition,
104            then_branch,
105            else_branch,
106        } => {
107            analyze_expr(condition, bound_vars, result)?;
108            analyze_expr(then_branch, bound_vars, result)?;
109            analyze_expr(else_branch, bound_vars, result)?;
110        }
111        TLExpr::Constant(_) => {
112            // Constants have no variables to analyze
113        }
114        TLExpr::Exists {
115            var,
116            domain: _,
117            body,
118        }
119        | TLExpr::ForAll {
120            var,
121            domain: _,
122            body,
123        }
124        | TLExpr::SoftExists {
125            var,
126            domain: _,
127            body,
128            ..
129        }
130        | TLExpr::SoftForAll {
131            var,
132            domain: _,
133            body,
134            ..
135        }
136        | TLExpr::Aggregate {
137            var,
138            domain: _,
139            body,
140            ..
141        } => {
142            // Variable is bound in this scope
143            let was_bound = bound_vars.contains(var);
144            bound_vars.insert(var.clone());
145
146            // Record the binding
147            if !result.variables.contains_key(var) {
148                result.variables.insert(
149                    var.clone(),
150                    VariableScope {
151                        name: var.clone(),
152                        bound_in: ScopeType::Quantifier {
153                            quantifier_type: match expr {
154                                TLExpr::Exists { .. } => "exists".to_string(),
155                                TLExpr::ForAll { .. } => "forall".to_string(),
156                                TLExpr::SoftExists { .. } => "soft_exists".to_string(),
157                                TLExpr::SoftForAll { .. } => "soft_forall".to_string(),
158                                TLExpr::Aggregate { .. } => "aggregate".to_string(),
159                                _ => unreachable!(),
160                            },
161                        },
162                        type_annotation: None,
163                    },
164                );
165            }
166
167            // Analyze the body
168            analyze_expr(body, bound_vars, result)?;
169
170            // Unbind if it wasn't previously bound
171            if !was_bound {
172                bound_vars.remove(var);
173            }
174        }
175        TLExpr::Let { var, value, body } => {
176            // Analyze value expression first (without the new variable bound)
177            analyze_expr(value, bound_vars, result)?;
178            // Then analyze body with the variable bound
179            let was_bound = bound_vars.contains(var);
180            bound_vars.insert(var.clone());
181            analyze_expr(body, bound_vars, result)?;
182            if !was_bound {
183                bound_vars.remove(var);
184            }
185        }
186
187        // Modal/temporal logic operators - not yet implemented, pass through with recursion
188        TLExpr::Box(inner)
189        | TLExpr::Diamond(inner)
190        | TLExpr::Next(inner)
191        | TLExpr::Eventually(inner)
192        | TLExpr::Always(inner) => {
193            analyze_expr(inner, bound_vars, result)?;
194        }
195        TLExpr::Until { before, after }
196        | TLExpr::Release {
197            released: before,
198            releaser: after,
199        }
200        | TLExpr::WeakUntil { before, after }
201        | TLExpr::StrongRelease {
202            released: before,
203            releaser: after,
204        } => {
205            analyze_expr(before, bound_vars, result)?;
206            analyze_expr(after, bound_vars, result)?;
207        }
208        TLExpr::ProbabilisticChoice { alternatives } => {
209            for (_weight, alt_expr) in alternatives {
210                analyze_expr(alt_expr, bound_vars, result)?;
211            }
212        }
213        // Counting quantifiers
214        TLExpr::CountingExists {
215            var,
216            domain: _,
217            body,
218            ..
219        }
220        | TLExpr::CountingForAll {
221            var,
222            domain: _,
223            body,
224            ..
225        }
226        | TLExpr::ExactCount {
227            var,
228            domain: _,
229            body,
230            ..
231        }
232        | TLExpr::Majority {
233            var,
234            domain: _,
235            body,
236        } => {
237            // Variable is bound in this scope
238            let was_bound = bound_vars.contains(var);
239            bound_vars.insert(var.clone());
240
241            // Record the binding
242            if !result.variables.contains_key(var) {
243                result.variables.insert(
244                    var.clone(),
245                    VariableScope {
246                        name: var.clone(),
247                        bound_in: ScopeType::Quantifier {
248                            quantifier_type: match expr {
249                                TLExpr::CountingExists { .. } => "counting_exists".to_string(),
250                                TLExpr::CountingForAll { .. } => "counting_forall".to_string(),
251                                TLExpr::ExactCount { .. } => "exact_count".to_string(),
252                                TLExpr::Majority { .. } => "majority".to_string(),
253                                _ => unreachable!(),
254                            },
255                        },
256                        type_annotation: None,
257                    },
258                );
259            }
260
261            // Analyze the body
262            analyze_expr(body, bound_vars, result)?;
263
264            // Unbind if it wasn't previously bound
265            if !was_bound {
266                bound_vars.remove(var);
267            }
268        }
269        // All other expression types (enhancements) - skip for now
270        _ => {
271            // For unimplemented expression types, no scope analysis yet
272        }
273    }
274
275    Ok(())
276}
277
278fn check_term(term: &Term, bound_vars: &HashSet<String>, result: &mut ScopeAnalysisResult) {
279    match term {
280        Term::Var(var_name) => {
281            if !bound_vars.contains(var_name) && !result.variables.contains_key(var_name) {
282                // This is a free variable
283                result.variables.insert(
284                    var_name.clone(),
285                    VariableScope {
286                        name: var_name.clone(),
287                        bound_in: ScopeType::Free,
288                        type_annotation: None,
289                    },
290                );
291                result.unbound_variables.push(var_name.clone());
292            }
293
294            // Check for type annotation
295            if let Some(type_ann) = term.get_type() {
296                if let Some(existing_scope) = result.variables.get_mut(var_name) {
297                    if let Some(ref existing_type) = existing_scope.type_annotation {
298                        if existing_type != type_ann {
299                            result.type_conflicts.push(TypeConflict {
300                                variable: var_name.clone(),
301                                type1: existing_type.type_name.clone(),
302                                type2: type_ann.type_name.clone(),
303                            });
304                        }
305                    } else {
306                        existing_scope.type_annotation = Some(type_ann.clone());
307                    }
308                }
309            }
310        }
311        Term::Typed {
312            value,
313            type_annotation,
314        } => {
315            // Check the underlying term
316            check_term(value, bound_vars, result);
317
318            // Record type annotation
319            if let Term::Var(var_name) = value.untyped() {
320                if let Some(existing_scope) = result.variables.get_mut(var_name) {
321                    if let Some(ref existing_type) = existing_scope.type_annotation {
322                        if existing_type != type_annotation {
323                            result.type_conflicts.push(TypeConflict {
324                                variable: var_name.clone(),
325                                type1: existing_type.type_name.clone(),
326                                type2: type_annotation.type_name.clone(),
327                            });
328                        }
329                    } else {
330                        existing_scope.type_annotation = Some(type_annotation.clone());
331                    }
332                }
333            }
334        }
335        Term::Const(_) => {
336            // Constants don't need scope checking
337        }
338    }
339}
340
341/// Validate that all variables are properly bound
342pub fn validate_scopes(expr: &TLExpr) -> Result<()> {
343    let result = analyze_scopes(expr)?;
344
345    if !result.unbound_variables.is_empty() {
346        bail!(
347            "Unbound variables found: {}",
348            result.unbound_variables.join(", ")
349        );
350    }
351
352    if !result.type_conflicts.is_empty() {
353        let conflict = &result.type_conflicts[0];
354        return Err(IrError::InconsistentTypes {
355            var: conflict.variable.clone(),
356            type1: conflict.type1.clone(),
357            type2: conflict.type2.clone(),
358        }
359        .into());
360    }
361
362    Ok(())
363}
364
365/// Suggest quantifiers for unbound variables
366pub fn suggest_quantifiers(expr: &TLExpr) -> Result<Vec<String>> {
367    let result = analyze_scopes(expr)?;
368    let mut suggestions = Vec::new();
369
370    for unbound_var in &result.unbound_variables {
371        suggestions.push(format!(
372            "Consider adding a universal quantifier: ∀{}. <expr>",
373            unbound_var
374        ));
375    }
376
377    Ok(suggestions)
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bound_variable() {
        let expr = TLExpr::exists("x", "Domain", TLExpr::pred("p", vec![Term::var("x")]));

        let result = analyze_scopes(&expr).unwrap();
        assert!(result.unbound_variables.is_empty());
        assert_eq!(result.variables.len(), 1);
        assert_eq!(result.variables["x"].name, "x");
        assert_eq!(
            result.variables["x"].bound_in,
            ScopeType::Quantifier {
                quantifier_type: "exists".to_string()
            }
        );
    }

    #[test]
    fn test_unbound_variable() {
        let expr = TLExpr::pred("p", vec![Term::var("x")]);

        let result = analyze_scopes(&expr).unwrap();
        assert_eq!(result.unbound_variables.len(), 1);
        assert_eq!(result.unbound_variables[0], "x");
        assert_eq!(result.variables["x"].bound_in, ScopeType::Free);
    }

    #[test]
    fn test_repeated_free_variable_reported_once() {
        // p(x, x) - x is free, but must be listed only once
        let expr = TLExpr::pred("p", vec![Term::var("x"), Term::var("x")]);

        let result = analyze_scopes(&expr).unwrap();
        assert_eq!(result.unbound_variables.len(), 1);
        assert_eq!(result.unbound_variables[0], "x");
    }

    #[test]
    fn test_mixed_bound_unbound() {
        // ∃x. p(x, y) - x is bound, y is free
        let expr = TLExpr::exists(
            "x",
            "Domain",
            TLExpr::pred("p", vec![Term::var("x"), Term::var("y")]),
        );

        let result = analyze_scopes(&expr).unwrap();
        assert_eq!(result.unbound_variables.len(), 1);
        assert_eq!(result.unbound_variables[0], "y");
        assert_eq!(result.variables.len(), 2);
        assert_eq!(result.variables["y"].bound_in, ScopeType::Free);
    }

    #[test]
    fn test_nested_quantifiers() {
        // ∃x. ∀y. p(x, y, z) - x and y are bound, z is free
        let expr = TLExpr::exists(
            "x",
            "Domain",
            TLExpr::forall(
                "y",
                "Domain",
                TLExpr::pred("p", vec![Term::var("x"), Term::var("y"), Term::var("z")]),
            ),
        );

        let result = analyze_scopes(&expr).unwrap();
        assert_eq!(result.unbound_variables.len(), 1);
        assert_eq!(result.unbound_variables[0], "z");
        assert_eq!(
            result.variables["y"].bound_in,
            ScopeType::Quantifier {
                quantifier_type: "forall".to_string()
            }
        );
    }

    #[test]
    fn test_shadowed_binder() {
        // ∃x. ∀x. p(x) - the inner binder shadows the outer one; x is never free
        let expr = TLExpr::exists(
            "x",
            "Domain",
            TLExpr::forall("x", "Domain", TLExpr::pred("p", vec![Term::var("x")])),
        );

        let result = analyze_scopes(&expr).unwrap();
        assert!(result.unbound_variables.is_empty());
        // Only the first binder of a name is recorded.
        assert_eq!(
            result.variables["x"].bound_in,
            ScopeType::Quantifier {
                quantifier_type: "exists".to_string()
            }
        );
    }

    #[test]
    fn test_validate_scopes_success() {
        let expr = TLExpr::exists("x", "Domain", TLExpr::pred("p", vec![Term::var("x")]));

        assert!(validate_scopes(&expr).is_ok());
    }

    #[test]
    fn test_validate_scopes_failure() {
        let expr = TLExpr::pred("p", vec![Term::var("x")]);

        let err = validate_scopes(&expr).unwrap_err();
        assert!(err.to_string().contains("Unbound variables found: x"));
    }

    #[test]
    fn test_validate_scopes_type_conflict() {
        // ∃x. p(x: Person, x: Thing) - bound, but annotated inconsistently
        let expr = TLExpr::exists(
            "x",
            "Domain",
            TLExpr::pred(
                "p",
                vec![
                    Term::typed_var("x", "Person"),
                    Term::typed_var("x", "Thing"),
                ],
            ),
        );

        let err = validate_scopes(&expr).unwrap_err();
        assert!(matches!(
            err.downcast_ref::<IrError>(),
            Some(IrError::InconsistentTypes { .. })
        ));
    }

    #[test]
    fn test_type_annotations() {
        let expr = TLExpr::pred(
            "p",
            vec![
                Term::typed_var("x", "Person"),
                Term::typed_var("x", "Person"), // Same type, OK
            ],
        );

        let result = analyze_scopes(&expr).unwrap();
        assert!(result.type_conflicts.is_empty());
    }

    #[test]
    fn test_type_annotation_recorded() {
        let expr = TLExpr::pred("p", vec![Term::typed_var("x", "Person")]);

        let result = analyze_scopes(&expr).unwrap();
        let recorded = result.variables["x"]
            .type_annotation
            .as_ref()
            .map(|t| t.type_name.as_str());
        assert_eq!(recorded, Some("Person"));
    }

    #[test]
    fn test_type_conflicts() {
        let expr = TLExpr::pred(
            "p",
            vec![
                Term::typed_var("x", "Person"),
                Term::typed_var("x", "Thing"), // Different type, conflict!
            ],
        );

        let result = analyze_scopes(&expr).unwrap();
        assert_eq!(result.type_conflicts.len(), 1);
        assert_eq!(result.type_conflicts[0].variable, "x");
    }

    #[test]
    fn test_suggest_quantifiers() {
        let expr = TLExpr::pred("p", vec![Term::var("x"), Term::var("y")]);

        // Compare exact strings: every hint contains "<expr>", so a plain
        // `contains("x")` check would pass for any variable.
        let suggestions = suggest_quantifiers(&expr).unwrap();
        assert_eq!(
            suggestions,
            vec![
                "Consider adding a universal quantifier: ∀x. <expr>".to_string(),
                "Consider adding a universal quantifier: ∀y. <expr>".to_string(),
            ]
        );
    }
}