1use std::collections::{HashMap, HashSet};
4
5use anyhow::{bail, Result};
6use tensorlogic_ir::{IrError, TLExpr, Term, TypeAnnotation};
7
8#[derive(Debug, Clone)]
10pub struct VariableScope {
11 pub name: String,
12 pub bound_in: ScopeType,
13 pub type_annotation: Option<TypeAnnotation>,
14}
15
/// Describes where a variable's binding comes from.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ScopeType {
    /// Bound by a quantifier-like binder; `quantifier_type` is the binder's label
    /// (for example `"exists"`, `"forall"`, `"aggregate"` or `"majority"`).
    Quantifier { quantifier_type: String },
    /// Used with no enclosing binder.
    Free,
}
21
22#[derive(Debug, Clone, Default)]
24pub struct ScopeAnalysisResult {
25 pub variables: HashMap<String, VariableScope>,
26 pub unbound_variables: Vec<String>,
27 pub type_conflicts: Vec<TypeConflict>,
28}
29
/// Two different type annotations observed for the same variable.
#[derive(Clone, Debug)]
pub struct TypeConflict {
    /// Name of the variable whose annotations disagree.
    pub variable: String,
    /// Type name that was recorded first.
    pub type1: String,
    /// Type name encountered later that differs from `type1`.
    pub type2: String,
}
36
37pub fn analyze_scopes(expr: &TLExpr) -> Result<ScopeAnalysisResult> {
39 let mut result = ScopeAnalysisResult::default();
40 let mut bound_vars = HashSet::new();
41
42 analyze_expr(expr, &mut bound_vars, &mut result)?;
43
44 Ok(result)
45}
46
47fn analyze_expr(
48 expr: &TLExpr,
49 bound_vars: &mut HashSet<String>,
50 result: &mut ScopeAnalysisResult,
51) -> Result<()> {
52 match expr {
53 #[allow(unreachable_patterns)]
54 TLExpr::Pred { name: _, args } => {
55 for term in args {
57 check_term(term, bound_vars, result);
58 }
59 }
60 TLExpr::And(left, right)
61 | TLExpr::Or(left, right)
62 | TLExpr::Imply(left, right)
63 | TLExpr::Add(left, right)
64 | TLExpr::Sub(left, right)
65 | TLExpr::Mul(left, right)
66 | TLExpr::Div(left, right)
67 | TLExpr::Pow(left, right)
68 | TLExpr::Mod(left, right)
69 | TLExpr::Min(left, right)
70 | TLExpr::Max(left, right)
71 | TLExpr::Eq(left, right)
72 | TLExpr::Lt(left, right)
73 | TLExpr::Gt(left, right)
74 | TLExpr::Lte(left, right)
75 | TLExpr::Gte(left, right)
76 | TLExpr::TNorm { left, right, .. }
77 | TLExpr::TCoNorm { left, right, .. }
78 | TLExpr::FuzzyImplication {
79 premise: left,
80 conclusion: right,
81 ..
82 } => {
83 analyze_expr(left, bound_vars, result)?;
84 analyze_expr(right, bound_vars, result)?;
85 }
86 TLExpr::Not(inner)
87 | TLExpr::Score(inner)
88 | TLExpr::Abs(inner)
89 | TLExpr::Floor(inner)
90 | TLExpr::Ceil(inner)
91 | TLExpr::Round(inner)
92 | TLExpr::Sqrt(inner)
93 | TLExpr::Exp(inner)
94 | TLExpr::Log(inner)
95 | TLExpr::Sin(inner)
96 | TLExpr::Cos(inner)
97 | TLExpr::Tan(inner)
98 | TLExpr::FuzzyNot { expr: inner, .. }
99 | TLExpr::WeightedRule { rule: inner, .. } => {
100 analyze_expr(inner, bound_vars, result)?;
101 }
102 TLExpr::IfThenElse {
103 condition,
104 then_branch,
105 else_branch,
106 } => {
107 analyze_expr(condition, bound_vars, result)?;
108 analyze_expr(then_branch, bound_vars, result)?;
109 analyze_expr(else_branch, bound_vars, result)?;
110 }
111 TLExpr::Constant(_) => {
112 }
114 TLExpr::Exists {
115 var,
116 domain: _,
117 body,
118 }
119 | TLExpr::ForAll {
120 var,
121 domain: _,
122 body,
123 }
124 | TLExpr::SoftExists {
125 var,
126 domain: _,
127 body,
128 ..
129 }
130 | TLExpr::SoftForAll {
131 var,
132 domain: _,
133 body,
134 ..
135 }
136 | TLExpr::Aggregate {
137 var,
138 domain: _,
139 body,
140 ..
141 } => {
142 let was_bound = bound_vars.contains(var);
144 bound_vars.insert(var.clone());
145
146 if !result.variables.contains_key(var) {
148 result.variables.insert(
149 var.clone(),
150 VariableScope {
151 name: var.clone(),
152 bound_in: ScopeType::Quantifier {
153 quantifier_type: match expr {
154 TLExpr::Exists { .. } => "exists".to_string(),
155 TLExpr::ForAll { .. } => "forall".to_string(),
156 TLExpr::SoftExists { .. } => "soft_exists".to_string(),
157 TLExpr::SoftForAll { .. } => "soft_forall".to_string(),
158 TLExpr::Aggregate { .. } => "aggregate".to_string(),
159 _ => unreachable!(),
160 },
161 },
162 type_annotation: None,
163 },
164 );
165 }
166
167 analyze_expr(body, bound_vars, result)?;
169
170 if !was_bound {
172 bound_vars.remove(var);
173 }
174 }
175 TLExpr::Let { var, value, body } => {
176 analyze_expr(value, bound_vars, result)?;
178 let was_bound = bound_vars.contains(var);
180 bound_vars.insert(var.clone());
181 analyze_expr(body, bound_vars, result)?;
182 if !was_bound {
183 bound_vars.remove(var);
184 }
185 }
186
187 TLExpr::Box(inner)
189 | TLExpr::Diamond(inner)
190 | TLExpr::Next(inner)
191 | TLExpr::Eventually(inner)
192 | TLExpr::Always(inner) => {
193 analyze_expr(inner, bound_vars, result)?;
194 }
195 TLExpr::Until { before, after }
196 | TLExpr::Release {
197 released: before,
198 releaser: after,
199 }
200 | TLExpr::WeakUntil { before, after }
201 | TLExpr::StrongRelease {
202 released: before,
203 releaser: after,
204 } => {
205 analyze_expr(before, bound_vars, result)?;
206 analyze_expr(after, bound_vars, result)?;
207 }
208 TLExpr::ProbabilisticChoice { alternatives } => {
209 for (_weight, alt_expr) in alternatives {
210 analyze_expr(alt_expr, bound_vars, result)?;
211 }
212 }
213 TLExpr::CountingExists {
215 var,
216 domain: _,
217 body,
218 ..
219 }
220 | TLExpr::CountingForAll {
221 var,
222 domain: _,
223 body,
224 ..
225 }
226 | TLExpr::ExactCount {
227 var,
228 domain: _,
229 body,
230 ..
231 }
232 | TLExpr::Majority {
233 var,
234 domain: _,
235 body,
236 } => {
237 let was_bound = bound_vars.contains(var);
239 bound_vars.insert(var.clone());
240
241 if !result.variables.contains_key(var) {
243 result.variables.insert(
244 var.clone(),
245 VariableScope {
246 name: var.clone(),
247 bound_in: ScopeType::Quantifier {
248 quantifier_type: match expr {
249 TLExpr::CountingExists { .. } => "counting_exists".to_string(),
250 TLExpr::CountingForAll { .. } => "counting_forall".to_string(),
251 TLExpr::ExactCount { .. } => "exact_count".to_string(),
252 TLExpr::Majority { .. } => "majority".to_string(),
253 _ => unreachable!(),
254 },
255 },
256 type_annotation: None,
257 },
258 );
259 }
260
261 analyze_expr(body, bound_vars, result)?;
263
264 if !was_bound {
266 bound_vars.remove(var);
267 }
268 }
269 _ => {
271 }
273 }
274
275 Ok(())
276}
277
278fn check_term(term: &Term, bound_vars: &HashSet<String>, result: &mut ScopeAnalysisResult) {
279 match term {
280 Term::Var(var_name) => {
281 if !bound_vars.contains(var_name) && !result.variables.contains_key(var_name) {
282 result.variables.insert(
284 var_name.clone(),
285 VariableScope {
286 name: var_name.clone(),
287 bound_in: ScopeType::Free,
288 type_annotation: None,
289 },
290 );
291 result.unbound_variables.push(var_name.clone());
292 }
293
294 if let Some(type_ann) = term.get_type() {
296 if let Some(existing_scope) = result.variables.get_mut(var_name) {
297 if let Some(ref existing_type) = existing_scope.type_annotation {
298 if existing_type != type_ann {
299 result.type_conflicts.push(TypeConflict {
300 variable: var_name.clone(),
301 type1: existing_type.type_name.clone(),
302 type2: type_ann.type_name.clone(),
303 });
304 }
305 } else {
306 existing_scope.type_annotation = Some(type_ann.clone());
307 }
308 }
309 }
310 }
311 Term::Typed {
312 value,
313 type_annotation,
314 } => {
315 check_term(value, bound_vars, result);
317
318 if let Term::Var(var_name) = value.untyped() {
320 if let Some(existing_scope) = result.variables.get_mut(var_name) {
321 if let Some(ref existing_type) = existing_scope.type_annotation {
322 if existing_type != type_annotation {
323 result.type_conflicts.push(TypeConflict {
324 variable: var_name.clone(),
325 type1: existing_type.type_name.clone(),
326 type2: type_annotation.type_name.clone(),
327 });
328 }
329 } else {
330 existing_scope.type_annotation = Some(type_annotation.clone());
331 }
332 }
333 }
334 }
335 Term::Const(_) => {
336 }
338 }
339}
340
341pub fn validate_scopes(expr: &TLExpr) -> Result<()> {
343 let result = analyze_scopes(expr)?;
344
345 if !result.unbound_variables.is_empty() {
346 bail!(
347 "Unbound variables found: {}",
348 result.unbound_variables.join(", ")
349 );
350 }
351
352 if !result.type_conflicts.is_empty() {
353 let conflict = &result.type_conflicts[0];
354 return Err(IrError::InconsistentTypes {
355 var: conflict.variable.clone(),
356 type1: conflict.type1.clone(),
357 type2: conflict.type2.clone(),
358 }
359 .into());
360 }
361
362 Ok(())
363}
364
365pub fn suggest_quantifiers(expr: &TLExpr) -> Result<Vec<String>> {
367 let result = analyze_scopes(expr)?;
368 let mut suggestions = Vec::new();
369
370 for unbound_var in &result.unbound_variables {
371 suggestions.push(format!(
372 "Consider adding a universal quantifier: ∀{}. <expr>",
373 unbound_var
374 ));
375 }
376
377 Ok(suggestions)
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the predicate `p(args...)` used throughout these tests.
    fn p(args: Vec<Term>) -> TLExpr {
        TLExpr::pred("p", args)
    }

    /// Runs the scope analysis; it is not expected to fail for these inputs.
    fn analyze(expr: &TLExpr) -> ScopeAnalysisResult {
        analyze_scopes(expr).unwrap()
    }

    #[test]
    fn test_bound_variable() {
        let analysis = analyze(&TLExpr::exists("x", "Domain", p(vec![Term::var("x")])));

        assert!(analysis.unbound_variables.is_empty());
        assert_eq!(analysis.variables.len(), 1);
        assert_eq!(analysis.variables["x"].name, "x");
    }

    #[test]
    fn test_unbound_variable() {
        let analysis = analyze(&p(vec![Term::var("x")]));

        assert_eq!(analysis.unbound_variables, vec!["x".to_string()]);
    }

    #[test]
    fn test_mixed_bound_unbound() {
        let body = p(vec![Term::var("x"), Term::var("y")]);
        let analysis = analyze(&TLExpr::exists("x", "Domain", body));

        assert_eq!(analysis.unbound_variables, vec!["y".to_string()]);
        assert_eq!(analysis.variables.len(), 2);
    }

    #[test]
    fn test_nested_quantifiers() {
        let innermost = p(vec![Term::var("x"), Term::var("y"), Term::var("z")]);
        let inner = TLExpr::forall("y", "Domain", innermost);
        let analysis = analyze(&TLExpr::exists("x", "Domain", inner));

        assert_eq!(analysis.unbound_variables, vec!["z".to_string()]);
    }

    #[test]
    fn test_validate_scopes_success() {
        let closed = TLExpr::exists("x", "Domain", p(vec![Term::var("x")]));

        assert!(validate_scopes(&closed).is_ok());
    }

    #[test]
    fn test_validate_scopes_failure() {
        let open = p(vec![Term::var("x")]);

        assert!(validate_scopes(&open).is_err());
    }

    #[test]
    fn test_type_annotations() {
        // The same variable annotated twice with the same type is consistent.
        let analysis = analyze(&p(vec![
            Term::typed_var("x", "Person"),
            Term::typed_var("x", "Person"),
        ]));

        assert!(analysis.type_conflicts.is_empty());
    }

    #[test]
    fn test_type_conflicts() {
        // Two different types for the same variable must be reported once.
        let analysis = analyze(&p(vec![
            Term::typed_var("x", "Person"),
            Term::typed_var("x", "Thing"),
        ]));

        assert_eq!(analysis.type_conflicts.len(), 1);
        assert_eq!(analysis.type_conflicts[0].variable, "x");
    }

    #[test]
    fn test_suggest_quantifiers() {
        let open = p(vec![Term::var("x"), Term::var("y")]);
        let suggestions = suggest_quantifiers(&open).unwrap();

        assert_eq!(suggestions.len(), 2);
        assert!(suggestions[0].contains("x"));
        assert!(suggestions[1].contains("y"));
    }
}