Skip to main content

tensorlogic_ir/expr/
domain_validation.rs

1//! Domain validation for expressions.
2
3use std::collections::HashMap;
4
5use crate::domain::DomainRegistry;
6use crate::error::IrError;
7use crate::expr::TLExpr;
8
9impl TLExpr {
10    /// Validate all domain references in this expression.
11    ///
12    /// Checks that:
13    /// - All quantifier domains exist in the registry
14    /// - Variables used consistently across different quantifiers
15    pub fn validate_domains(&self, registry: &DomainRegistry) -> Result<(), IrError> {
16        let mut var_domains = HashMap::new();
17        self.collect_and_validate_domains(registry, &mut var_domains)
18    }
19
20    fn collect_and_validate_domains(
21        &self,
22        registry: &DomainRegistry,
23        var_domains: &mut HashMap<String, String>,
24    ) -> Result<(), IrError> {
25        match self {
26            TLExpr::Exists { var, domain, body }
27            | TLExpr::ForAll { var, domain, body }
28            | TLExpr::SoftExists {
29                var, domain, body, ..
30            }
31            | TLExpr::SoftForAll {
32                var, domain, body, ..
33            } => {
34                // Check domain exists
35                registry.validate_domain(domain)?;
36
37                // Check for consistent variable usage
38                if let Some(existing_domain) = var_domains.get(var) {
39                    if existing_domain != domain {
40                        // Check if domains are compatible
41                        if !registry.are_compatible(existing_domain, domain)? {
42                            return Err(IrError::VariableDomainMismatch {
43                                var: var.clone(),
44                                expected: existing_domain.clone(),
45                                actual: domain.clone(),
46                            });
47                        }
48                    }
49                } else {
50                    var_domains.insert(var.clone(), domain.clone());
51                }
52
53                body.collect_and_validate_domains(registry, var_domains)?;
54            }
55            TLExpr::Aggregate {
56                var, domain, body, ..
57            } => {
58                // Check domain exists
59                registry.validate_domain(domain)?;
60
61                // Check for consistent variable usage
62                if let Some(existing_domain) = var_domains.get(var) {
63                    if existing_domain != domain {
64                        // Check if domains are compatible
65                        if !registry.are_compatible(existing_domain, domain)? {
66                            return Err(IrError::VariableDomainMismatch {
67                                var: var.clone(),
68                                expected: existing_domain.clone(),
69                                actual: domain.clone(),
70                            });
71                        }
72                    }
73                } else {
74                    var_domains.insert(var.clone(), domain.clone());
75                }
76
77                body.collect_and_validate_domains(registry, var_domains)?;
78            }
79            TLExpr::And(l, r)
80            | TLExpr::Or(l, r)
81            | TLExpr::Imply(l, r)
82            | TLExpr::Add(l, r)
83            | TLExpr::Sub(l, r)
84            | TLExpr::Mul(l, r)
85            | TLExpr::Div(l, r)
86            | TLExpr::Pow(l, r)
87            | TLExpr::Mod(l, r)
88            | TLExpr::Min(l, r)
89            | TLExpr::Max(l, r)
90            | TLExpr::Eq(l, r)
91            | TLExpr::Lt(l, r)
92            | TLExpr::Gt(l, r)
93            | TLExpr::Lte(l, r)
94            | TLExpr::Gte(l, r) => {
95                l.collect_and_validate_domains(registry, var_domains)?;
96                r.collect_and_validate_domains(registry, var_domains)?;
97            }
98            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
99                left.collect_and_validate_domains(registry, var_domains)?;
100                right.collect_and_validate_domains(registry, var_domains)?;
101            }
102            TLExpr::FuzzyImplication {
103                premise,
104                conclusion,
105                ..
106            } => {
107                premise.collect_and_validate_domains(registry, var_domains)?;
108                conclusion.collect_and_validate_domains(registry, var_domains)?;
109            }
110            TLExpr::Not(e)
111            | TLExpr::Score(e)
112            | TLExpr::Abs(e)
113            | TLExpr::Floor(e)
114            | TLExpr::Ceil(e)
115            | TLExpr::Round(e)
116            | TLExpr::Sqrt(e)
117            | TLExpr::Exp(e)
118            | TLExpr::Log(e)
119            | TLExpr::Sin(e)
120            | TLExpr::Cos(e)
121            | TLExpr::Tan(e)
122            | TLExpr::Box(e)
123            | TLExpr::Diamond(e)
124            | TLExpr::Next(e)
125            | TLExpr::Eventually(e)
126            | TLExpr::Always(e) => {
127                e.collect_and_validate_domains(registry, var_domains)?;
128            }
129            TLExpr::FuzzyNot { expr, .. } => {
130                expr.collect_and_validate_domains(registry, var_domains)?;
131            }
132            TLExpr::WeightedRule { rule, .. } => {
133                rule.collect_and_validate_domains(registry, var_domains)?;
134            }
135            TLExpr::Until { before, after }
136            | TLExpr::Release {
137                released: before,
138                releaser: after,
139            }
140            | TLExpr::WeakUntil { before, after }
141            | TLExpr::StrongRelease {
142                released: before,
143                releaser: after,
144            } => {
145                before.collect_and_validate_domains(registry, var_domains)?;
146                after.collect_and_validate_domains(registry, var_domains)?;
147            }
148            TLExpr::ProbabilisticChoice { alternatives } => {
149                for (_, expr) in alternatives {
150                    expr.collect_and_validate_domains(registry, var_domains)?;
151                }
152            }
153            TLExpr::IfThenElse {
154                condition,
155                then_branch,
156                else_branch,
157            } => {
158                condition.collect_and_validate_domains(registry, var_domains)?;
159                then_branch.collect_and_validate_domains(registry, var_domains)?;
160                else_branch.collect_and_validate_domains(registry, var_domains)?;
161            }
162            TLExpr::Let { value, body, .. } => {
163                value.collect_and_validate_domains(registry, var_domains)?;
164                body.collect_and_validate_domains(registry, var_domains)?;
165            }
166            // Beta.1 enhancements
167            TLExpr::Lambda { body, .. } => {
168                // Lambda introduces a local binding, no domain validation
169                body.collect_and_validate_domains(registry, var_domains)?;
170            }
171            TLExpr::Apply { function, argument } => {
172                function.collect_and_validate_domains(registry, var_domains)?;
173                argument.collect_and_validate_domains(registry, var_domains)?;
174            }
175            TLExpr::SetMembership { element, set }
176            | TLExpr::SetUnion {
177                left: element,
178                right: set,
179            }
180            | TLExpr::SetIntersection {
181                left: element,
182                right: set,
183            }
184            | TLExpr::SetDifference {
185                left: element,
186                right: set,
187            } => {
188                element.collect_and_validate_domains(registry, var_domains)?;
189                set.collect_and_validate_domains(registry, var_domains)?;
190            }
191            TLExpr::SetCardinality { set } => {
192                set.collect_and_validate_domains(registry, var_domains)?;
193            }
194            TLExpr::EmptySet => {
195                // No domain validation needed
196            }
197            TLExpr::SetComprehension {
198                var,
199                domain,
200                condition,
201            } => {
202                registry.validate_domain(domain)?;
203                if let Some(existing_domain) = var_domains.get(var) {
204                    if existing_domain != domain
205                        && !registry.are_compatible(existing_domain, domain)?
206                    {
207                        return Err(IrError::VariableDomainMismatch {
208                            var: var.clone(),
209                            expected: existing_domain.clone(),
210                            actual: domain.clone(),
211                        });
212                    }
213                } else {
214                    var_domains.insert(var.clone(), domain.clone());
215                }
216                condition.collect_and_validate_domains(registry, var_domains)?;
217            }
218            TLExpr::CountingExists {
219                var, domain, body, ..
220            }
221            | TLExpr::CountingForAll {
222                var, domain, body, ..
223            }
224            | TLExpr::ExactCount {
225                var, domain, body, ..
226            }
227            | TLExpr::Majority { var, domain, body } => {
228                registry.validate_domain(domain)?;
229                if let Some(existing_domain) = var_domains.get(var) {
230                    if existing_domain != domain
231                        && !registry.are_compatible(existing_domain, domain)?
232                    {
233                        return Err(IrError::VariableDomainMismatch {
234                            var: var.clone(),
235                            expected: existing_domain.clone(),
236                            actual: domain.clone(),
237                        });
238                    }
239                } else {
240                    var_domains.insert(var.clone(), domain.clone());
241                }
242                body.collect_and_validate_domains(registry, var_domains)?;
243            }
244            TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
245                body.collect_and_validate_domains(registry, var_domains)?;
246            }
247            TLExpr::Nominal { .. } => {
248                // No domain validation needed
249            }
250            TLExpr::At { formula, .. } => {
251                formula.collect_and_validate_domains(registry, var_domains)?;
252            }
253            TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
254                formula.collect_and_validate_domains(registry, var_domains)?;
255            }
256            TLExpr::AllDifferent { .. } => {
257                // Variables, no domain validation here
258            }
259            TLExpr::GlobalCardinality { values, .. } => {
260                for val in values {
261                    val.collect_and_validate_domains(registry, var_domains)?;
262                }
263            }
264            TLExpr::Abducible { .. } => {
265                // No domain validation needed
266            }
267            TLExpr::Explain { formula } => {
268                formula.collect_and_validate_domains(registry, var_domains)?;
269            }
270            TLExpr::SymbolLiteral(_) => {
271                // No domain validation needed for symbol literals
272            }
273            TLExpr::Match { scrutinee, arms } => {
274                scrutinee.collect_and_validate_domains(registry, var_domains)?;
275                for (_, body) in arms {
276                    body.collect_and_validate_domains(registry, var_domains)?;
277                }
278            }
279            TLExpr::Pred { .. } | TLExpr::Constant(_) => {
280                // No domain validation needed for predicates and constants
281            }
282        }
283        Ok(())
284    }
285
286    /// Extract all domains referenced in this expression.
287    pub fn referenced_domains(&self) -> Vec<String> {
288        let mut domains = Vec::new();
289        self.collect_domains(&mut domains);
290        domains.sort();
291        domains.dedup();
292        domains
293    }
294
295    fn collect_domains(&self, domains: &mut Vec<String>) {
296        match self {
297            TLExpr::Exists { domain, body, .. }
298            | TLExpr::ForAll { domain, body, .. }
299            | TLExpr::SoftExists { domain, body, .. }
300            | TLExpr::SoftForAll { domain, body, .. }
301            | TLExpr::Aggregate { domain, body, .. } => {
302                domains.push(domain.clone());
303                body.collect_domains(domains);
304            }
305            TLExpr::And(l, r)
306            | TLExpr::Or(l, r)
307            | TLExpr::Imply(l, r)
308            | TLExpr::Add(l, r)
309            | TLExpr::Sub(l, r)
310            | TLExpr::Mul(l, r)
311            | TLExpr::Div(l, r)
312            | TLExpr::Pow(l, r)
313            | TLExpr::Mod(l, r)
314            | TLExpr::Min(l, r)
315            | TLExpr::Max(l, r)
316            | TLExpr::Eq(l, r)
317            | TLExpr::Lt(l, r)
318            | TLExpr::Gt(l, r)
319            | TLExpr::Lte(l, r)
320            | TLExpr::Gte(l, r) => {
321                l.collect_domains(domains);
322                r.collect_domains(domains);
323            }
324            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
325                left.collect_domains(domains);
326                right.collect_domains(domains);
327            }
328            TLExpr::FuzzyImplication {
329                premise,
330                conclusion,
331                ..
332            } => {
333                premise.collect_domains(domains);
334                conclusion.collect_domains(domains);
335            }
336            TLExpr::Not(e)
337            | TLExpr::Score(e)
338            | TLExpr::Abs(e)
339            | TLExpr::Floor(e)
340            | TLExpr::Ceil(e)
341            | TLExpr::Round(e)
342            | TLExpr::Sqrt(e)
343            | TLExpr::Exp(e)
344            | TLExpr::Log(e)
345            | TLExpr::Sin(e)
346            | TLExpr::Cos(e)
347            | TLExpr::Tan(e)
348            | TLExpr::Box(e)
349            | TLExpr::Diamond(e)
350            | TLExpr::Next(e)
351            | TLExpr::Eventually(e)
352            | TLExpr::Always(e) => {
353                e.collect_domains(domains);
354            }
355            TLExpr::FuzzyNot { expr, .. } => {
356                expr.collect_domains(domains);
357            }
358            TLExpr::WeightedRule { rule, .. } => {
359                rule.collect_domains(domains);
360            }
361            TLExpr::Until { before, after }
362            | TLExpr::Release {
363                released: before,
364                releaser: after,
365            }
366            | TLExpr::WeakUntil { before, after }
367            | TLExpr::StrongRelease {
368                released: before,
369                releaser: after,
370            } => {
371                before.collect_domains(domains);
372                after.collect_domains(domains);
373            }
374            TLExpr::ProbabilisticChoice { alternatives } => {
375                for (_, expr) in alternatives {
376                    expr.collect_domains(domains);
377                }
378            }
379            TLExpr::IfThenElse {
380                condition,
381                then_branch,
382                else_branch,
383            } => {
384                condition.collect_domains(domains);
385                then_branch.collect_domains(domains);
386                else_branch.collect_domains(domains);
387            }
388            TLExpr::Let { value, body, .. } => {
389                value.collect_domains(domains);
390                body.collect_domains(domains);
391            }
392            // Beta.1 enhancements
393            TLExpr::Lambda { body, .. } => {
394                body.collect_domains(domains);
395            }
396            TLExpr::Apply { function, argument } => {
397                function.collect_domains(domains);
398                argument.collect_domains(domains);
399            }
400            TLExpr::SetMembership { element, set }
401            | TLExpr::SetUnion {
402                left: element,
403                right: set,
404            }
405            | TLExpr::SetIntersection {
406                left: element,
407                right: set,
408            }
409            | TLExpr::SetDifference {
410                left: element,
411                right: set,
412            } => {
413                element.collect_domains(domains);
414                set.collect_domains(domains);
415            }
416            TLExpr::SetCardinality { set } => {
417                set.collect_domains(domains);
418            }
419            TLExpr::EmptySet => {}
420            TLExpr::SetComprehension {
421                domain, condition, ..
422            } => {
423                domains.push(domain.clone());
424                condition.collect_domains(domains);
425            }
426            TLExpr::CountingExists { domain, body, .. }
427            | TLExpr::CountingForAll { domain, body, .. }
428            | TLExpr::ExactCount { domain, body, .. }
429            | TLExpr::Majority { domain, body, .. } => {
430                domains.push(domain.clone());
431                body.collect_domains(domains);
432            }
433            TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
434                body.collect_domains(domains);
435            }
436            TLExpr::Nominal { .. } => {}
437            TLExpr::At { formula, .. } => {
438                formula.collect_domains(domains);
439            }
440            TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
441                formula.collect_domains(domains);
442            }
443            TLExpr::AllDifferent { .. } => {}
444            TLExpr::GlobalCardinality { values, .. } => {
445                for val in values {
446                    val.collect_domains(domains);
447                }
448            }
449            TLExpr::Abducible { .. } => {}
450            TLExpr::Explain { formula } => {
451                formula.collect_domains(domains);
452            }
453            TLExpr::SymbolLiteral(_) => {}
454            TLExpr::Match { scrutinee, arms } => {
455                scrutinee.collect_domains(domains);
456                for (_, body) in arms {
457                    body.collect_domains(domains);
458                }
459            }
460            TLExpr::Pred { .. } | TLExpr::Constant(_) => {}
461        }
462    }
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::term::Term;

    /// Build a one-argument predicate application such as `P(x)`.
    fn unary(pred: &'static str, var: &'static str) -> TLExpr {
        TLExpr::pred(pred, vec![Term::var(var)])
    }

    #[test]
    fn test_validate_domains_success() {
        let builtins = DomainRegistry::with_builtins();
        let formula = TLExpr::exists("x", "Int", unary("P", "x"));

        let outcome = formula.validate_domains(&builtins);
        assert!(outcome.is_ok(), "builtin domain `Int` should validate");
    }

    #[test]
    fn test_validate_domains_not_found() {
        let empty = DomainRegistry::new();
        let formula = TLExpr::exists("x", "UnknownDomain", unary("P", "x"));

        let outcome = formula.validate_domains(&empty);
        assert!(outcome.is_err(), "unregistered domain must be rejected");
    }

    #[test]
    fn test_validate_domains_consistent_usage() {
        let builtins = DomainRegistry::with_builtins();

        // ∃x:Int. ∀x:Int. P(x) — the name `x` is rebound to the same domain.
        let inner = TLExpr::forall("x", "Int", unary("P", "x"));
        let formula = TLExpr::exists("x", "Int", inner);

        let outcome = formula.validate_domains(&builtins);
        assert!(outcome.is_ok(), "rebinding to the same domain is allowed");
    }

    #[test]
    fn test_validate_domains_incompatible() {
        let builtins = DomainRegistry::with_builtins();

        // ∃x:Int. ∀x:Bool. P(x) — the name `x` is rebound to an incompatible domain.
        let inner = TLExpr::forall("x", "Bool", unary("P", "x"));
        let formula = TLExpr::exists("x", "Int", inner);

        let outcome = formula.validate_domains(&builtins);
        assert!(outcome.is_err(), "Int and Bool bindings of `x` must clash");
    }

    #[test]
    fn test_referenced_domains() {
        let inner = TLExpr::forall("y", "Real", unary("P", "x"));
        let formula = TLExpr::exists("x", "Int", inner);

        assert_eq!(formula.referenced_domains(), ["Int", "Real"]);
    }

    #[test]
    fn test_referenced_domains_dedup() {
        let left = TLExpr::exists("x", "Int", unary("P", "x"));
        let right = TLExpr::exists("y", "Int", unary("Q", "y"));
        let formula = TLExpr::and(left, right);

        assert_eq!(formula.referenced_domains(), ["Int"]);
    }
}