tensorlogic_compiler/passes/
scope_analysis.rs

1//! Variable scope analysis pass.
2
3use std::collections::{HashMap, HashSet};
4
5use anyhow::{bail, Result};
6use tensorlogic_ir::{IrError, TLExpr, Term, TypeAnnotation};
7
8/// Scope information for a variable
9#[derive(Debug, Clone)]
10pub struct VariableScope {
11    pub name: String,
12    pub bound_in: ScopeType,
13    pub type_annotation: Option<TypeAnnotation>,
14}
15
/// The way a variable was introduced into an expression.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ScopeType {
    /// Introduced by a binder. `quantifier_type` is the binder's label as
    /// written by the analysis: `"exists"`, `"forall"`, `"soft_exists"`,
    /// `"soft_forall"` or `"aggregate"`.
    Quantifier { quantifier_type: String },
    /// Not bound by any enclosing binder at the point it was first seen.
    Free,
}
21
22/// Result of scope analysis
23#[derive(Debug, Clone, Default)]
24pub struct ScopeAnalysisResult {
25    pub variables: HashMap<String, VariableScope>,
26    pub unbound_variables: Vec<String>,
27    pub type_conflicts: Vec<TypeConflict>,
28}
29
/// Two different type annotations attached to the same variable name.
#[derive(Clone, Debug)]
pub struct TypeConflict {
    /// Name of the variable carrying the conflicting annotations.
    pub variable: String,
    /// Type name of the annotation that was recorded first.
    pub type1: String,
    /// Type name of the later annotation that disagrees with `type1`.
    pub type2: String,
}
36
37/// Analyze variable scopes in an expression
38pub fn analyze_scopes(expr: &TLExpr) -> Result<ScopeAnalysisResult> {
39    let mut result = ScopeAnalysisResult::default();
40    let mut bound_vars = HashSet::new();
41
42    analyze_expr(expr, &mut bound_vars, &mut result)?;
43
44    Ok(result)
45}
46
47fn analyze_expr(
48    expr: &TLExpr,
49    bound_vars: &mut HashSet<String>,
50    result: &mut ScopeAnalysisResult,
51) -> Result<()> {
52    match expr {
53        #[allow(unreachable_patterns)]
54        TLExpr::Pred { name: _, args } => {
55            // Check all variables in predicate arguments
56            for term in args {
57                check_term(term, bound_vars, result);
58            }
59        }
60        TLExpr::And(left, right)
61        | TLExpr::Or(left, right)
62        | TLExpr::Imply(left, right)
63        | TLExpr::Add(left, right)
64        | TLExpr::Sub(left, right)
65        | TLExpr::Mul(left, right)
66        | TLExpr::Div(left, right)
67        | TLExpr::Pow(left, right)
68        | TLExpr::Mod(left, right)
69        | TLExpr::Min(left, right)
70        | TLExpr::Max(left, right)
71        | TLExpr::Eq(left, right)
72        | TLExpr::Lt(left, right)
73        | TLExpr::Gt(left, right)
74        | TLExpr::Lte(left, right)
75        | TLExpr::Gte(left, right)
76        | TLExpr::TNorm { left, right, .. }
77        | TLExpr::TCoNorm { left, right, .. }
78        | TLExpr::FuzzyImplication {
79            premise: left,
80            conclusion: right,
81            ..
82        } => {
83            analyze_expr(left, bound_vars, result)?;
84            analyze_expr(right, bound_vars, result)?;
85        }
86        TLExpr::Not(inner)
87        | TLExpr::Score(inner)
88        | TLExpr::Abs(inner)
89        | TLExpr::Floor(inner)
90        | TLExpr::Ceil(inner)
91        | TLExpr::Round(inner)
92        | TLExpr::Sqrt(inner)
93        | TLExpr::Exp(inner)
94        | TLExpr::Log(inner)
95        | TLExpr::Sin(inner)
96        | TLExpr::Cos(inner)
97        | TLExpr::Tan(inner)
98        | TLExpr::FuzzyNot { expr: inner, .. }
99        | TLExpr::WeightedRule { rule: inner, .. } => {
100            analyze_expr(inner, bound_vars, result)?;
101        }
102        TLExpr::IfThenElse {
103            condition,
104            then_branch,
105            else_branch,
106        } => {
107            analyze_expr(condition, bound_vars, result)?;
108            analyze_expr(then_branch, bound_vars, result)?;
109            analyze_expr(else_branch, bound_vars, result)?;
110        }
111        TLExpr::Constant(_) => {
112            // Constants have no variables to analyze
113        }
114        TLExpr::Exists {
115            var,
116            domain: _,
117            body,
118        }
119        | TLExpr::ForAll {
120            var,
121            domain: _,
122            body,
123        }
124        | TLExpr::SoftExists {
125            var,
126            domain: _,
127            body,
128            ..
129        }
130        | TLExpr::SoftForAll {
131            var,
132            domain: _,
133            body,
134            ..
135        }
136        | TLExpr::Aggregate {
137            var,
138            domain: _,
139            body,
140            ..
141        } => {
142            // Variable is bound in this scope
143            let was_bound = bound_vars.contains(var);
144            bound_vars.insert(var.clone());
145
146            // Record the binding
147            if !result.variables.contains_key(var) {
148                result.variables.insert(
149                    var.clone(),
150                    VariableScope {
151                        name: var.clone(),
152                        bound_in: ScopeType::Quantifier {
153                            quantifier_type: match expr {
154                                TLExpr::Exists { .. } => "exists".to_string(),
155                                TLExpr::ForAll { .. } => "forall".to_string(),
156                                TLExpr::SoftExists { .. } => "soft_exists".to_string(),
157                                TLExpr::SoftForAll { .. } => "soft_forall".to_string(),
158                                TLExpr::Aggregate { .. } => "aggregate".to_string(),
159                                _ => unreachable!(),
160                            },
161                        },
162                        type_annotation: None,
163                    },
164                );
165            }
166
167            // Analyze the body
168            analyze_expr(body, bound_vars, result)?;
169
170            // Unbind if it wasn't previously bound
171            if !was_bound {
172                bound_vars.remove(var);
173            }
174        }
175        TLExpr::Let { var, value, body } => {
176            // Analyze value expression first (without the new variable bound)
177            analyze_expr(value, bound_vars, result)?;
178            // Then analyze body with the variable bound
179            let was_bound = bound_vars.contains(var);
180            bound_vars.insert(var.clone());
181            analyze_expr(body, bound_vars, result)?;
182            if !was_bound {
183                bound_vars.remove(var);
184            }
185        }
186
187        // Modal/temporal logic operators - not yet implemented, pass through with recursion
188        TLExpr::Box(inner)
189        | TLExpr::Diamond(inner)
190        | TLExpr::Next(inner)
191        | TLExpr::Eventually(inner)
192        | TLExpr::Always(inner) => {
193            analyze_expr(inner, bound_vars, result)?;
194        }
195        TLExpr::Until { before, after }
196        | TLExpr::Release {
197            released: before,
198            releaser: after,
199        }
200        | TLExpr::WeakUntil { before, after }
201        | TLExpr::StrongRelease {
202            released: before,
203            releaser: after,
204        } => {
205            analyze_expr(before, bound_vars, result)?;
206            analyze_expr(after, bound_vars, result)?;
207        }
208        TLExpr::ProbabilisticChoice { alternatives } => {
209            for (_weight, alt_expr) in alternatives {
210                analyze_expr(alt_expr, bound_vars, result)?;
211            }
212        }
213    }
214
215    Ok(())
216}
217
218fn check_term(term: &Term, bound_vars: &HashSet<String>, result: &mut ScopeAnalysisResult) {
219    match term {
220        Term::Var(var_name) => {
221            if !bound_vars.contains(var_name) && !result.variables.contains_key(var_name) {
222                // This is a free variable
223                result.variables.insert(
224                    var_name.clone(),
225                    VariableScope {
226                        name: var_name.clone(),
227                        bound_in: ScopeType::Free,
228                        type_annotation: None,
229                    },
230                );
231                result.unbound_variables.push(var_name.clone());
232            }
233
234            // Check for type annotation
235            if let Some(type_ann) = term.get_type() {
236                if let Some(existing_scope) = result.variables.get_mut(var_name) {
237                    if let Some(ref existing_type) = existing_scope.type_annotation {
238                        if existing_type != type_ann {
239                            result.type_conflicts.push(TypeConflict {
240                                variable: var_name.clone(),
241                                type1: existing_type.type_name.clone(),
242                                type2: type_ann.type_name.clone(),
243                            });
244                        }
245                    } else {
246                        existing_scope.type_annotation = Some(type_ann.clone());
247                    }
248                }
249            }
250        }
251        Term::Typed {
252            value,
253            type_annotation,
254        } => {
255            // Check the underlying term
256            check_term(value, bound_vars, result);
257
258            // Record type annotation
259            if let Term::Var(var_name) = value.untyped() {
260                if let Some(existing_scope) = result.variables.get_mut(var_name) {
261                    if let Some(ref existing_type) = existing_scope.type_annotation {
262                        if existing_type != type_annotation {
263                            result.type_conflicts.push(TypeConflict {
264                                variable: var_name.clone(),
265                                type1: existing_type.type_name.clone(),
266                                type2: type_annotation.type_name.clone(),
267                            });
268                        }
269                    } else {
270                        existing_scope.type_annotation = Some(type_annotation.clone());
271                    }
272                }
273            }
274        }
275        Term::Const(_) => {
276            // Constants don't need scope checking
277        }
278    }
279}
280
281/// Validate that all variables are properly bound
282pub fn validate_scopes(expr: &TLExpr) -> Result<()> {
283    let result = analyze_scopes(expr)?;
284
285    if !result.unbound_variables.is_empty() {
286        bail!(
287            "Unbound variables found: {}",
288            result.unbound_variables.join(", ")
289        );
290    }
291
292    if !result.type_conflicts.is_empty() {
293        let conflict = &result.type_conflicts[0];
294        return Err(IrError::InconsistentTypes {
295            var: conflict.variable.clone(),
296            type1: conflict.type1.clone(),
297            type2: conflict.type2.clone(),
298        }
299        .into());
300    }
301
302    Ok(())
303}
304
305/// Suggest quantifiers for unbound variables
306pub fn suggest_quantifiers(expr: &TLExpr) -> Result<Vec<String>> {
307    let result = analyze_scopes(expr)?;
308    let mut suggestions = Vec::new();
309
310    for unbound_var in &result.unbound_variables {
311        suggestions.push(format!(
312            "Consider adding a universal quantifier: ∀{}. <expr>",
313            unbound_var
314        ));
315    }
316
317    Ok(suggestions)
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bound_variable() {
        let expr = TLExpr::exists("x", "Domain", TLExpr::pred("p", vec![Term::var("x")]));

        let result = analyze_scopes(&expr).unwrap();
        assert!(result.unbound_variables.is_empty());
        assert_eq!(result.variables.len(), 1);
        assert_eq!(result.variables["x"].name, "x");
        // The record must say *how* the variable was bound, not just that it exists.
        assert_eq!(
            result.variables["x"].bound_in,
            ScopeType::Quantifier {
                quantifier_type: "exists".to_string()
            }
        );
    }

    #[test]
    fn test_unbound_variable() {
        let expr = TLExpr::pred("p", vec![Term::var("x")]);

        let result = analyze_scopes(&expr).unwrap();
        assert_eq!(result.unbound_variables.len(), 1);
        assert_eq!(result.unbound_variables[0], "x");
        assert_eq!(result.variables["x"].bound_in, ScopeType::Free);
    }

    #[test]
    fn test_mixed_bound_unbound() {
        // ∃x. p(x, y) - x is bound, y is free
        let expr = TLExpr::exists(
            "x",
            "Domain",
            TLExpr::pred("p", vec![Term::var("x"), Term::var("y")]),
        );

        let result = analyze_scopes(&expr).unwrap();
        assert_eq!(result.unbound_variables.len(), 1);
        assert_eq!(result.unbound_variables[0], "y");
        assert_eq!(result.variables.len(), 2);
        assert_eq!(result.variables["y"].bound_in, ScopeType::Free);
    }

    #[test]
    fn test_nested_quantifiers() {
        // ∃x. ∀y. p(x, y, z) - x and y are bound, z is free
        let expr = TLExpr::exists(
            "x",
            "Domain",
            TLExpr::forall(
                "y",
                "Domain",
                TLExpr::pred("p", vec![Term::var("x"), Term::var("y"), Term::var("z")]),
            ),
        );

        let result = analyze_scopes(&expr).unwrap();
        assert_eq!(result.unbound_variables.len(), 1);
        assert_eq!(result.unbound_variables[0], "z");
        assert_eq!(result.variables.len(), 3);
    }

    #[test]
    fn test_validate_scopes_success() {
        let expr = TLExpr::exists("x", "Domain", TLExpr::pred("p", vec![Term::var("x")]));

        assert!(validate_scopes(&expr).is_ok());
    }

    #[test]
    fn test_validate_scopes_failure() {
        let expr = TLExpr::pred("p", vec![Term::var("x")]);

        let err = validate_scopes(&expr).unwrap_err();
        assert!(err.to_string().contains("Unbound variables found: x"));
    }

    #[test]
    fn test_type_annotations() {
        let expr = TLExpr::pred(
            "p",
            vec![
                Term::typed_var("x", "Person"),
                Term::typed_var("x", "Person"), // Same type, OK
            ],
        );

        let result = analyze_scopes(&expr).unwrap();
        assert!(result.type_conflicts.is_empty());
    }

    #[test]
    fn test_type_conflicts() {
        let expr = TLExpr::pred(
            "p",
            vec![
                Term::typed_var("x", "Person"),
                Term::typed_var("x", "Thing"), // Different type, conflict!
            ],
        );

        let result = analyze_scopes(&expr).unwrap();
        assert_eq!(result.type_conflicts.len(), 1);
        assert_eq!(result.type_conflicts[0].variable, "x");
    }

    #[test]
    fn test_suggest_quantifiers() {
        let expr = TLExpr::pred("p", vec![Term::var("x"), Term::var("y")]);

        let suggestions = suggest_quantifiers(&expr).unwrap();
        assert_eq!(suggestions.len(), 2);
        // Match on "∀<name>." rather than the bare name: every suggestion
        // contains the letter "x" via the literal "<expr>" placeholder.
        assert!(suggestions[0].contains("∀x."));
        assert!(suggestions[1].contains("∀y."));
    }
}