Skip to main content

tensorlogic_ir/expr/
domain_validation.rs

1//! Domain validation for expressions.
2
3use std::collections::HashMap;
4
5use crate::domain::DomainRegistry;
6use crate::error::IrError;
7use crate::expr::TLExpr;
8
9impl TLExpr {
10    /// Validate all domain references in this expression.
11    ///
12    /// Checks that:
13    /// - All quantifier domains exist in the registry
14    /// - Variables used consistently across different quantifiers
15    pub fn validate_domains(&self, registry: &DomainRegistry) -> Result<(), IrError> {
16        let mut var_domains = HashMap::new();
17        self.collect_and_validate_domains(registry, &mut var_domains)
18    }
19
20    fn collect_and_validate_domains(
21        &self,
22        registry: &DomainRegistry,
23        var_domains: &mut HashMap<String, String>,
24    ) -> Result<(), IrError> {
25        match self {
26            TLExpr::Exists { var, domain, body }
27            | TLExpr::ForAll { var, domain, body }
28            | TLExpr::SoftExists {
29                var, domain, body, ..
30            }
31            | TLExpr::SoftForAll {
32                var, domain, body, ..
33            } => {
34                // Check domain exists
35                registry.validate_domain(domain)?;
36
37                // Check for consistent variable usage
38                if let Some(existing_domain) = var_domains.get(var) {
39                    if existing_domain != domain {
40                        // Check if domains are compatible
41                        if !registry.are_compatible(existing_domain, domain)? {
42                            return Err(IrError::VariableDomainMismatch {
43                                var: var.clone(),
44                                expected: existing_domain.clone(),
45                                actual: domain.clone(),
46                            });
47                        }
48                    }
49                } else {
50                    var_domains.insert(var.clone(), domain.clone());
51                }
52
53                body.collect_and_validate_domains(registry, var_domains)?;
54            }
55            TLExpr::Aggregate {
56                var, domain, body, ..
57            } => {
58                // Check domain exists
59                registry.validate_domain(domain)?;
60
61                // Check for consistent variable usage
62                if let Some(existing_domain) = var_domains.get(var) {
63                    if existing_domain != domain {
64                        // Check if domains are compatible
65                        if !registry.are_compatible(existing_domain, domain)? {
66                            return Err(IrError::VariableDomainMismatch {
67                                var: var.clone(),
68                                expected: existing_domain.clone(),
69                                actual: domain.clone(),
70                            });
71                        }
72                    }
73                } else {
74                    var_domains.insert(var.clone(), domain.clone());
75                }
76
77                body.collect_and_validate_domains(registry, var_domains)?;
78            }
79            TLExpr::And(l, r)
80            | TLExpr::Or(l, r)
81            | TLExpr::Imply(l, r)
82            | TLExpr::Add(l, r)
83            | TLExpr::Sub(l, r)
84            | TLExpr::Mul(l, r)
85            | TLExpr::Div(l, r)
86            | TLExpr::Pow(l, r)
87            | TLExpr::Mod(l, r)
88            | TLExpr::Min(l, r)
89            | TLExpr::Max(l, r)
90            | TLExpr::Eq(l, r)
91            | TLExpr::Lt(l, r)
92            | TLExpr::Gt(l, r)
93            | TLExpr::Lte(l, r)
94            | TLExpr::Gte(l, r) => {
95                l.collect_and_validate_domains(registry, var_domains)?;
96                r.collect_and_validate_domains(registry, var_domains)?;
97            }
98            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
99                left.collect_and_validate_domains(registry, var_domains)?;
100                right.collect_and_validate_domains(registry, var_domains)?;
101            }
102            TLExpr::FuzzyImplication {
103                premise,
104                conclusion,
105                ..
106            } => {
107                premise.collect_and_validate_domains(registry, var_domains)?;
108                conclusion.collect_and_validate_domains(registry, var_domains)?;
109            }
110            TLExpr::Not(e)
111            | TLExpr::Score(e)
112            | TLExpr::Abs(e)
113            | TLExpr::Floor(e)
114            | TLExpr::Ceil(e)
115            | TLExpr::Round(e)
116            | TLExpr::Sqrt(e)
117            | TLExpr::Exp(e)
118            | TLExpr::Log(e)
119            | TLExpr::Sin(e)
120            | TLExpr::Cos(e)
121            | TLExpr::Tan(e)
122            | TLExpr::Box(e)
123            | TLExpr::Diamond(e)
124            | TLExpr::Next(e)
125            | TLExpr::Eventually(e)
126            | TLExpr::Always(e) => {
127                e.collect_and_validate_domains(registry, var_domains)?;
128            }
129            TLExpr::FuzzyNot { expr, .. } => {
130                expr.collect_and_validate_domains(registry, var_domains)?;
131            }
132            TLExpr::WeightedRule { rule, .. } => {
133                rule.collect_and_validate_domains(registry, var_domains)?;
134            }
135            TLExpr::Until { before, after }
136            | TLExpr::Release {
137                released: before,
138                releaser: after,
139            }
140            | TLExpr::WeakUntil { before, after }
141            | TLExpr::StrongRelease {
142                released: before,
143                releaser: after,
144            } => {
145                before.collect_and_validate_domains(registry, var_domains)?;
146                after.collect_and_validate_domains(registry, var_domains)?;
147            }
148            TLExpr::ProbabilisticChoice { alternatives } => {
149                for (_, expr) in alternatives {
150                    expr.collect_and_validate_domains(registry, var_domains)?;
151                }
152            }
153            TLExpr::IfThenElse {
154                condition,
155                then_branch,
156                else_branch,
157            } => {
158                condition.collect_and_validate_domains(registry, var_domains)?;
159                then_branch.collect_and_validate_domains(registry, var_domains)?;
160                else_branch.collect_and_validate_domains(registry, var_domains)?;
161            }
162            TLExpr::Let { value, body, .. } => {
163                value.collect_and_validate_domains(registry, var_domains)?;
164                body.collect_and_validate_domains(registry, var_domains)?;
165            }
166            // Beta.1 enhancements
167            TLExpr::Lambda { body, .. } => {
168                // Lambda introduces a local binding, no domain validation
169                body.collect_and_validate_domains(registry, var_domains)?;
170            }
171            TLExpr::Apply { function, argument } => {
172                function.collect_and_validate_domains(registry, var_domains)?;
173                argument.collect_and_validate_domains(registry, var_domains)?;
174            }
175            TLExpr::SetMembership { element, set }
176            | TLExpr::SetUnion {
177                left: element,
178                right: set,
179            }
180            | TLExpr::SetIntersection {
181                left: element,
182                right: set,
183            }
184            | TLExpr::SetDifference {
185                left: element,
186                right: set,
187            } => {
188                element.collect_and_validate_domains(registry, var_domains)?;
189                set.collect_and_validate_domains(registry, var_domains)?;
190            }
191            TLExpr::SetCardinality { set } => {
192                set.collect_and_validate_domains(registry, var_domains)?;
193            }
194            TLExpr::EmptySet => {
195                // No domain validation needed
196            }
197            TLExpr::SetComprehension {
198                var,
199                domain,
200                condition,
201            } => {
202                registry.validate_domain(domain)?;
203                if let Some(existing_domain) = var_domains.get(var) {
204                    if existing_domain != domain
205                        && !registry.are_compatible(existing_domain, domain)?
206                    {
207                        return Err(IrError::VariableDomainMismatch {
208                            var: var.clone(),
209                            expected: existing_domain.clone(),
210                            actual: domain.clone(),
211                        });
212                    }
213                } else {
214                    var_domains.insert(var.clone(), domain.clone());
215                }
216                condition.collect_and_validate_domains(registry, var_domains)?;
217            }
218            TLExpr::CountingExists {
219                var, domain, body, ..
220            }
221            | TLExpr::CountingForAll {
222                var, domain, body, ..
223            }
224            | TLExpr::ExactCount {
225                var, domain, body, ..
226            }
227            | TLExpr::Majority { var, domain, body } => {
228                registry.validate_domain(domain)?;
229                if let Some(existing_domain) = var_domains.get(var) {
230                    if existing_domain != domain
231                        && !registry.are_compatible(existing_domain, domain)?
232                    {
233                        return Err(IrError::VariableDomainMismatch {
234                            var: var.clone(),
235                            expected: existing_domain.clone(),
236                            actual: domain.clone(),
237                        });
238                    }
239                } else {
240                    var_domains.insert(var.clone(), domain.clone());
241                }
242                body.collect_and_validate_domains(registry, var_domains)?;
243            }
244            TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
245                body.collect_and_validate_domains(registry, var_domains)?;
246            }
247            TLExpr::Nominal { .. } => {
248                // No domain validation needed
249            }
250            TLExpr::At { formula, .. } => {
251                formula.collect_and_validate_domains(registry, var_domains)?;
252            }
253            TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
254                formula.collect_and_validate_domains(registry, var_domains)?;
255            }
256            TLExpr::AllDifferent { .. } => {
257                // Variables, no domain validation here
258            }
259            TLExpr::GlobalCardinality { values, .. } => {
260                for val in values {
261                    val.collect_and_validate_domains(registry, var_domains)?;
262                }
263            }
264            TLExpr::Abducible { .. } => {
265                // No domain validation needed
266            }
267            TLExpr::Explain { formula } => {
268                formula.collect_and_validate_domains(registry, var_domains)?;
269            }
270            TLExpr::Pred { .. } | TLExpr::Constant(_) => {
271                // No domain validation needed for predicates and constants
272            }
273        }
274        Ok(())
275    }
276
277    /// Extract all domains referenced in this expression.
278    pub fn referenced_domains(&self) -> Vec<String> {
279        let mut domains = Vec::new();
280        self.collect_domains(&mut domains);
281        domains.sort();
282        domains.dedup();
283        domains
284    }
285
286    fn collect_domains(&self, domains: &mut Vec<String>) {
287        match self {
288            TLExpr::Exists { domain, body, .. }
289            | TLExpr::ForAll { domain, body, .. }
290            | TLExpr::SoftExists { domain, body, .. }
291            | TLExpr::SoftForAll { domain, body, .. }
292            | TLExpr::Aggregate { domain, body, .. } => {
293                domains.push(domain.clone());
294                body.collect_domains(domains);
295            }
296            TLExpr::And(l, r)
297            | TLExpr::Or(l, r)
298            | TLExpr::Imply(l, r)
299            | TLExpr::Add(l, r)
300            | TLExpr::Sub(l, r)
301            | TLExpr::Mul(l, r)
302            | TLExpr::Div(l, r)
303            | TLExpr::Pow(l, r)
304            | TLExpr::Mod(l, r)
305            | TLExpr::Min(l, r)
306            | TLExpr::Max(l, r)
307            | TLExpr::Eq(l, r)
308            | TLExpr::Lt(l, r)
309            | TLExpr::Gt(l, r)
310            | TLExpr::Lte(l, r)
311            | TLExpr::Gte(l, r) => {
312                l.collect_domains(domains);
313                r.collect_domains(domains);
314            }
315            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
316                left.collect_domains(domains);
317                right.collect_domains(domains);
318            }
319            TLExpr::FuzzyImplication {
320                premise,
321                conclusion,
322                ..
323            } => {
324                premise.collect_domains(domains);
325                conclusion.collect_domains(domains);
326            }
327            TLExpr::Not(e)
328            | TLExpr::Score(e)
329            | TLExpr::Abs(e)
330            | TLExpr::Floor(e)
331            | TLExpr::Ceil(e)
332            | TLExpr::Round(e)
333            | TLExpr::Sqrt(e)
334            | TLExpr::Exp(e)
335            | TLExpr::Log(e)
336            | TLExpr::Sin(e)
337            | TLExpr::Cos(e)
338            | TLExpr::Tan(e)
339            | TLExpr::Box(e)
340            | TLExpr::Diamond(e)
341            | TLExpr::Next(e)
342            | TLExpr::Eventually(e)
343            | TLExpr::Always(e) => {
344                e.collect_domains(domains);
345            }
346            TLExpr::FuzzyNot { expr, .. } => {
347                expr.collect_domains(domains);
348            }
349            TLExpr::WeightedRule { rule, .. } => {
350                rule.collect_domains(domains);
351            }
352            TLExpr::Until { before, after }
353            | TLExpr::Release {
354                released: before,
355                releaser: after,
356            }
357            | TLExpr::WeakUntil { before, after }
358            | TLExpr::StrongRelease {
359                released: before,
360                releaser: after,
361            } => {
362                before.collect_domains(domains);
363                after.collect_domains(domains);
364            }
365            TLExpr::ProbabilisticChoice { alternatives } => {
366                for (_, expr) in alternatives {
367                    expr.collect_domains(domains);
368                }
369            }
370            TLExpr::IfThenElse {
371                condition,
372                then_branch,
373                else_branch,
374            } => {
375                condition.collect_domains(domains);
376                then_branch.collect_domains(domains);
377                else_branch.collect_domains(domains);
378            }
379            TLExpr::Let { value, body, .. } => {
380                value.collect_domains(domains);
381                body.collect_domains(domains);
382            }
383            // Beta.1 enhancements
384            TLExpr::Lambda { body, .. } => {
385                body.collect_domains(domains);
386            }
387            TLExpr::Apply { function, argument } => {
388                function.collect_domains(domains);
389                argument.collect_domains(domains);
390            }
391            TLExpr::SetMembership { element, set }
392            | TLExpr::SetUnion {
393                left: element,
394                right: set,
395            }
396            | TLExpr::SetIntersection {
397                left: element,
398                right: set,
399            }
400            | TLExpr::SetDifference {
401                left: element,
402                right: set,
403            } => {
404                element.collect_domains(domains);
405                set.collect_domains(domains);
406            }
407            TLExpr::SetCardinality { set } => {
408                set.collect_domains(domains);
409            }
410            TLExpr::EmptySet => {}
411            TLExpr::SetComprehension {
412                domain, condition, ..
413            } => {
414                domains.push(domain.clone());
415                condition.collect_domains(domains);
416            }
417            TLExpr::CountingExists { domain, body, .. }
418            | TLExpr::CountingForAll { domain, body, .. }
419            | TLExpr::ExactCount { domain, body, .. }
420            | TLExpr::Majority { domain, body, .. } => {
421                domains.push(domain.clone());
422                body.collect_domains(domains);
423            }
424            TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
425                body.collect_domains(domains);
426            }
427            TLExpr::Nominal { .. } => {}
428            TLExpr::At { formula, .. } => {
429                formula.collect_domains(domains);
430            }
431            TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
432                formula.collect_domains(domains);
433            }
434            TLExpr::AllDifferent { .. } => {}
435            TLExpr::GlobalCardinality { values, .. } => {
436                for val in values {
437                    val.collect_domains(domains);
438                }
439            }
440            TLExpr::Abducible { .. } => {}
441            TLExpr::Explain { formula } => {
442                formula.collect_domains(domains);
443            }
444            TLExpr::Pred { .. } | TLExpr::Constant(_) => {}
445        }
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use crate::term::Term;

    #[test]
    fn test_validate_domains_success() {
        let registry = DomainRegistry::with_builtins();

        let expr = TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")]));

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_validate_domains_not_found() {
        let registry = DomainRegistry::new();

        let expr = TLExpr::exists(
            "x",
            "UnknownDomain",
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        assert!(expr.validate_domains(&registry).is_err());
    }

    #[test]
    fn test_validate_domains_consistent_usage() {
        let registry = DomainRegistry::with_builtins();

        // ∃x:Int. ∀x:Int. P(x) - same variable, same domain
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
        );

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_validate_domains_incompatible() {
        let registry = DomainRegistry::with_builtins();

        // ∃x:Int. ∀x:Bool. P(x) - same variable, incompatible domains
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("x", "Bool", TLExpr::pred("P", vec![Term::var("x")])),
        );

        assert!(expr.validate_domains(&registry).is_err());
    }

    #[test]
    fn test_validate_domains_sibling_same_domain() {
        let registry = DomainRegistry::with_builtins();

        // (∃x:Int. P(x)) ∧ (∀x:Int. Q(x)) - name reused in sibling scopes, same domain
        let expr = TLExpr::and(
            TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
            TLExpr::forall("x", "Int", TLExpr::pred("Q", vec![Term::var("x")])),
        );

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_validate_domains_sibling_incompatible() {
        let registry = DomainRegistry::with_builtins();

        // (∃x:Int. P(x)) ∧ (∃x:Bool. Q(x)) - variable tracking is expression-wide,
        // so binders of `x` in sibling subtrees must agree as well.
        let expr = TLExpr::and(
            TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
            TLExpr::exists("x", "Bool", TLExpr::pred("Q", vec![Term::var("x")])),
        );

        assert!(expr.validate_domains(&registry).is_err());
    }

    #[test]
    fn test_validate_domains_no_binders() {
        // A bare predicate references no domains, so even an empty registry accepts it.
        let registry = DomainRegistry::new();

        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_referenced_domains() {
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("y", "Real", TLExpr::pred("P", vec![Term::var("x")])),
        );

        let domains = expr.referenced_domains();
        assert_eq!(domains, vec!["Int", "Real"]);
    }

    #[test]
    fn test_referenced_domains_dedup() {
        let expr = TLExpr::and(
            TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
            TLExpr::exists("y", "Int", TLExpr::pred("Q", vec![Term::var("y")])),
        );

        let domains = expr.referenced_domains();
        assert_eq!(domains, vec!["Int"]);
    }

    #[test]
    fn test_referenced_domains_none() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        assert!(expr.referenced_domains().is_empty());
    }
}