1use std::collections::{HashMap, HashSet};
4
5use anyhow::{bail, Result};
6use tensorlogic_ir::{IrError, TLExpr, Term, TypeAnnotation};
7
/// Scope information recorded for a single variable name during scope analysis.
///
/// Only one entry is kept per name: the first occurrence encountered during the
/// traversal (a quantifier binding or a free use) determines `bound_in`.
#[derive(Debug, Clone)]
pub struct VariableScope {
    /// The variable's name (identical to its key in
    /// `ScopeAnalysisResult::variables`).
    pub name: String,
    /// Whether the variable was first seen bound by a quantifier or free.
    pub bound_in: ScopeType,
    /// The first type annotation observed for this variable name, if any.
    pub type_annotation: Option<TypeAnnotation>,
}
15
/// How a variable was introduced when the analysis first encountered it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScopeType {
    /// Bound by a quantifier-like binder. The scope analysis in this module
    /// uses the kinds `"exists"`, `"forall"`, `"soft_exists"`, `"soft_forall"`
    /// and `"aggregate"`.
    Quantifier { quantifier_type: String },
    /// First seen as an occurrence with no enclosing binder.
    Free,
}
21
/// Aggregated output of [`analyze_scopes`].
#[derive(Debug, Clone, Default)]
pub struct ScopeAnalysisResult {
    /// Every variable recorded during the traversal, keyed by name. Only the
    /// first occurrence of each name is recorded.
    pub variables: HashMap<String, VariableScope>,
    /// Names reported as free (unbound), each listed once, in the order in
    /// which they were first reported.
    pub unbound_variables: Vec<String>,
    /// One entry per type annotation that disagrees with the annotation
    /// already recorded for the same variable name.
    pub type_conflicts: Vec<TypeConflict>,
}
29
/// Two different type annotations observed for the same variable name.
///
/// Derives `PartialEq`/`Eq` so conflicts can be compared directly (for
/// example in tests or when de-duplicating reports).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypeConflict {
    /// The variable whose annotations disagree.
    pub variable: String,
    /// Name of the type recorded first for the variable.
    pub type1: String,
    /// Name of the later, conflicting type.
    pub type2: String,
}
36
37pub fn analyze_scopes(expr: &TLExpr) -> Result<ScopeAnalysisResult> {
39 let mut result = ScopeAnalysisResult::default();
40 let mut bound_vars = HashSet::new();
41
42 analyze_expr(expr, &mut bound_vars, &mut result)?;
43
44 Ok(result)
45}
46
47fn analyze_expr(
48 expr: &TLExpr,
49 bound_vars: &mut HashSet<String>,
50 result: &mut ScopeAnalysisResult,
51) -> Result<()> {
52 match expr {
53 #[allow(unreachable_patterns)]
54 TLExpr::Pred { name: _, args } => {
55 for term in args {
57 check_term(term, bound_vars, result);
58 }
59 }
60 TLExpr::And(left, right)
61 | TLExpr::Or(left, right)
62 | TLExpr::Imply(left, right)
63 | TLExpr::Add(left, right)
64 | TLExpr::Sub(left, right)
65 | TLExpr::Mul(left, right)
66 | TLExpr::Div(left, right)
67 | TLExpr::Pow(left, right)
68 | TLExpr::Mod(left, right)
69 | TLExpr::Min(left, right)
70 | TLExpr::Max(left, right)
71 | TLExpr::Eq(left, right)
72 | TLExpr::Lt(left, right)
73 | TLExpr::Gt(left, right)
74 | TLExpr::Lte(left, right)
75 | TLExpr::Gte(left, right)
76 | TLExpr::TNorm { left, right, .. }
77 | TLExpr::TCoNorm { left, right, .. }
78 | TLExpr::FuzzyImplication {
79 premise: left,
80 conclusion: right,
81 ..
82 } => {
83 analyze_expr(left, bound_vars, result)?;
84 analyze_expr(right, bound_vars, result)?;
85 }
86 TLExpr::Not(inner)
87 | TLExpr::Score(inner)
88 | TLExpr::Abs(inner)
89 | TLExpr::Floor(inner)
90 | TLExpr::Ceil(inner)
91 | TLExpr::Round(inner)
92 | TLExpr::Sqrt(inner)
93 | TLExpr::Exp(inner)
94 | TLExpr::Log(inner)
95 | TLExpr::Sin(inner)
96 | TLExpr::Cos(inner)
97 | TLExpr::Tan(inner)
98 | TLExpr::FuzzyNot { expr: inner, .. }
99 | TLExpr::WeightedRule { rule: inner, .. } => {
100 analyze_expr(inner, bound_vars, result)?;
101 }
102 TLExpr::IfThenElse {
103 condition,
104 then_branch,
105 else_branch,
106 } => {
107 analyze_expr(condition, bound_vars, result)?;
108 analyze_expr(then_branch, bound_vars, result)?;
109 analyze_expr(else_branch, bound_vars, result)?;
110 }
111 TLExpr::Constant(_) => {
112 }
114 TLExpr::Exists {
115 var,
116 domain: _,
117 body,
118 }
119 | TLExpr::ForAll {
120 var,
121 domain: _,
122 body,
123 }
124 | TLExpr::SoftExists {
125 var,
126 domain: _,
127 body,
128 ..
129 }
130 | TLExpr::SoftForAll {
131 var,
132 domain: _,
133 body,
134 ..
135 }
136 | TLExpr::Aggregate {
137 var,
138 domain: _,
139 body,
140 ..
141 } => {
142 let was_bound = bound_vars.contains(var);
144 bound_vars.insert(var.clone());
145
146 if !result.variables.contains_key(var) {
148 result.variables.insert(
149 var.clone(),
150 VariableScope {
151 name: var.clone(),
152 bound_in: ScopeType::Quantifier {
153 quantifier_type: match expr {
154 TLExpr::Exists { .. } => "exists".to_string(),
155 TLExpr::ForAll { .. } => "forall".to_string(),
156 TLExpr::SoftExists { .. } => "soft_exists".to_string(),
157 TLExpr::SoftForAll { .. } => "soft_forall".to_string(),
158 TLExpr::Aggregate { .. } => "aggregate".to_string(),
159 _ => unreachable!(),
160 },
161 },
162 type_annotation: None,
163 },
164 );
165 }
166
167 analyze_expr(body, bound_vars, result)?;
169
170 if !was_bound {
172 bound_vars.remove(var);
173 }
174 }
175 TLExpr::Let { var, value, body } => {
176 analyze_expr(value, bound_vars, result)?;
178 let was_bound = bound_vars.contains(var);
180 bound_vars.insert(var.clone());
181 analyze_expr(body, bound_vars, result)?;
182 if !was_bound {
183 bound_vars.remove(var);
184 }
185 }
186
187 TLExpr::Box(inner)
189 | TLExpr::Diamond(inner)
190 | TLExpr::Next(inner)
191 | TLExpr::Eventually(inner)
192 | TLExpr::Always(inner) => {
193 analyze_expr(inner, bound_vars, result)?;
194 }
195 TLExpr::Until { before, after }
196 | TLExpr::Release {
197 released: before,
198 releaser: after,
199 }
200 | TLExpr::WeakUntil { before, after }
201 | TLExpr::StrongRelease {
202 released: before,
203 releaser: after,
204 } => {
205 analyze_expr(before, bound_vars, result)?;
206 analyze_expr(after, bound_vars, result)?;
207 }
208 TLExpr::ProbabilisticChoice { alternatives } => {
209 for (_weight, alt_expr) in alternatives {
210 analyze_expr(alt_expr, bound_vars, result)?;
211 }
212 }
213 }
214
215 Ok(())
216}
217
218fn check_term(term: &Term, bound_vars: &HashSet<String>, result: &mut ScopeAnalysisResult) {
219 match term {
220 Term::Var(var_name) => {
221 if !bound_vars.contains(var_name) && !result.variables.contains_key(var_name) {
222 result.variables.insert(
224 var_name.clone(),
225 VariableScope {
226 name: var_name.clone(),
227 bound_in: ScopeType::Free,
228 type_annotation: None,
229 },
230 );
231 result.unbound_variables.push(var_name.clone());
232 }
233
234 if let Some(type_ann) = term.get_type() {
236 if let Some(existing_scope) = result.variables.get_mut(var_name) {
237 if let Some(ref existing_type) = existing_scope.type_annotation {
238 if existing_type != type_ann {
239 result.type_conflicts.push(TypeConflict {
240 variable: var_name.clone(),
241 type1: existing_type.type_name.clone(),
242 type2: type_ann.type_name.clone(),
243 });
244 }
245 } else {
246 existing_scope.type_annotation = Some(type_ann.clone());
247 }
248 }
249 }
250 }
251 Term::Typed {
252 value,
253 type_annotation,
254 } => {
255 check_term(value, bound_vars, result);
257
258 if let Term::Var(var_name) = value.untyped() {
260 if let Some(existing_scope) = result.variables.get_mut(var_name) {
261 if let Some(ref existing_type) = existing_scope.type_annotation {
262 if existing_type != type_annotation {
263 result.type_conflicts.push(TypeConflict {
264 variable: var_name.clone(),
265 type1: existing_type.type_name.clone(),
266 type2: type_annotation.type_name.clone(),
267 });
268 }
269 } else {
270 existing_scope.type_annotation = Some(type_annotation.clone());
271 }
272 }
273 }
274 }
275 Term::Const(_) => {
276 }
278 }
279}
280
281pub fn validate_scopes(expr: &TLExpr) -> Result<()> {
283 let result = analyze_scopes(expr)?;
284
285 if !result.unbound_variables.is_empty() {
286 bail!(
287 "Unbound variables found: {}",
288 result.unbound_variables.join(", ")
289 );
290 }
291
292 if !result.type_conflicts.is_empty() {
293 let conflict = &result.type_conflicts[0];
294 return Err(IrError::InconsistentTypes {
295 var: conflict.variable.clone(),
296 type1: conflict.type1.clone(),
297 type2: conflict.type2.clone(),
298 }
299 .into());
300 }
301
302 Ok(())
303}
304
305pub fn suggest_quantifiers(expr: &TLExpr) -> Result<Vec<String>> {
307 let result = analyze_scopes(expr)?;
308 let mut suggestions = Vec::new();
309
310 for unbound_var in &result.unbound_variables {
311 suggestions.push(format!(
312 "Consider adding a universal quantifier: ∀{}. <expr>",
313 unbound_var
314 ));
315 }
316
317 Ok(suggestions)
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs [`analyze_scopes`] and fails the test if it returns an error.
    fn analyze(expr: &TLExpr) -> ScopeAnalysisResult {
        analyze_scopes(expr).expect("scope analysis should not fail")
    }

    #[test]
    fn test_bound_variable() {
        let body = TLExpr::pred("p", vec![Term::var("x")]);
        let expr = TLExpr::exists("x", "Domain", body);

        let analysis = analyze(&expr);
        assert!(analysis.unbound_variables.is_empty());
        assert_eq!(analysis.variables.len(), 1);
        assert_eq!(analysis.variables["x"].name, "x");
    }

    #[test]
    fn test_unbound_variable() {
        let expr = TLExpr::pred("p", vec![Term::var("x")]);

        let analysis = analyze(&expr);
        assert_eq!(analysis.unbound_variables, vec!["x".to_string()]);
    }

    #[test]
    fn test_mixed_bound_unbound() {
        let body = TLExpr::pred("p", vec![Term::var("x"), Term::var("y")]);
        let expr = TLExpr::exists("x", "Domain", body);

        let analysis = analyze(&expr);
        assert_eq!(analysis.unbound_variables, vec!["y".to_string()]);
        assert_eq!(analysis.variables.len(), 2);
    }

    #[test]
    fn test_nested_quantifiers() {
        let atom = TLExpr::pred("p", vec![Term::var("x"), Term::var("y"), Term::var("z")]);
        let inner = TLExpr::forall("y", "Domain", atom);
        let expr = TLExpr::exists("x", "Domain", inner);

        let analysis = analyze(&expr);
        assert_eq!(analysis.unbound_variables, vec!["z".to_string()]);
    }

    #[test]
    fn test_validate_scopes_success() {
        let body = TLExpr::pred("p", vec![Term::var("x")]);
        let expr = TLExpr::exists("x", "Domain", body);

        assert!(validate_scopes(&expr).is_ok());
    }

    #[test]
    fn test_validate_scopes_failure() {
        let expr = TLExpr::pred("p", vec![Term::var("x")]);

        assert!(validate_scopes(&expr).is_err());
    }

    #[test]
    fn test_type_annotations() {
        let args = vec![
            Term::typed_var("x", "Person"),
            Term::typed_var("x", "Person"),
        ];
        let expr = TLExpr::pred("p", args);

        let analysis = analyze(&expr);
        assert!(analysis.type_conflicts.is_empty());
    }

    #[test]
    fn test_type_conflicts() {
        let args = vec![
            Term::typed_var("x", "Person"),
            Term::typed_var("x", "Thing"),
        ];
        let expr = TLExpr::pred("p", args);

        let analysis = analyze(&expr);
        assert_eq!(analysis.type_conflicts.len(), 1);
        assert_eq!(analysis.type_conflicts[0].variable, "x");
    }

    #[test]
    fn test_suggest_quantifiers() {
        let expr = TLExpr::pred("p", vec![Term::var("x"), Term::var("y")]);

        let suggestions = suggest_quantifiers(&expr).expect("suggestions should be produced");
        assert_eq!(suggestions.len(), 2);
        assert!(suggestions[0].contains('x'));
        assert!(suggestions[1].contains('y'));
    }
}