tensorlogic_ir/expr/domain_validation.rs

1//! Domain validation for expressions.
2
3use std::collections::HashMap;
4
5use crate::domain::DomainRegistry;
6use crate::error::IrError;
7use crate::expr::TLExpr;
8
9impl TLExpr {
10    /// Validate all domain references in this expression.
11    ///
12    /// Checks that:
13    /// - All quantifier domains exist in the registry
14    /// - Variables used consistently across different quantifiers
15    pub fn validate_domains(&self, registry: &DomainRegistry) -> Result<(), IrError> {
16        let mut var_domains = HashMap::new();
17        self.collect_and_validate_domains(registry, &mut var_domains)
18    }
19
20    fn collect_and_validate_domains(
21        &self,
22        registry: &DomainRegistry,
23        var_domains: &mut HashMap<String, String>,
24    ) -> Result<(), IrError> {
25        match self {
26            TLExpr::Exists { var, domain, body }
27            | TLExpr::ForAll { var, domain, body }
28            | TLExpr::SoftExists {
29                var, domain, body, ..
30            }
31            | TLExpr::SoftForAll {
32                var, domain, body, ..
33            } => {
34                // Check domain exists
35                registry.validate_domain(domain)?;
36
37                // Check for consistent variable usage
38                if let Some(existing_domain) = var_domains.get(var) {
39                    if existing_domain != domain {
40                        // Check if domains are compatible
41                        if !registry.are_compatible(existing_domain, domain)? {
42                            return Err(IrError::VariableDomainMismatch {
43                                var: var.clone(),
44                                expected: existing_domain.clone(),
45                                actual: domain.clone(),
46                            });
47                        }
48                    }
49                } else {
50                    var_domains.insert(var.clone(), domain.clone());
51                }
52
53                body.collect_and_validate_domains(registry, var_domains)?;
54            }
55            TLExpr::Aggregate {
56                var, domain, body, ..
57            } => {
58                // Check domain exists
59                registry.validate_domain(domain)?;
60
61                // Check for consistent variable usage
62                if let Some(existing_domain) = var_domains.get(var) {
63                    if existing_domain != domain {
64                        // Check if domains are compatible
65                        if !registry.are_compatible(existing_domain, domain)? {
66                            return Err(IrError::VariableDomainMismatch {
67                                var: var.clone(),
68                                expected: existing_domain.clone(),
69                                actual: domain.clone(),
70                            });
71                        }
72                    }
73                } else {
74                    var_domains.insert(var.clone(), domain.clone());
75                }
76
77                body.collect_and_validate_domains(registry, var_domains)?;
78            }
79            TLExpr::And(l, r)
80            | TLExpr::Or(l, r)
81            | TLExpr::Imply(l, r)
82            | TLExpr::Add(l, r)
83            | TLExpr::Sub(l, r)
84            | TLExpr::Mul(l, r)
85            | TLExpr::Div(l, r)
86            | TLExpr::Pow(l, r)
87            | TLExpr::Mod(l, r)
88            | TLExpr::Min(l, r)
89            | TLExpr::Max(l, r)
90            | TLExpr::Eq(l, r)
91            | TLExpr::Lt(l, r)
92            | TLExpr::Gt(l, r)
93            | TLExpr::Lte(l, r)
94            | TLExpr::Gte(l, r) => {
95                l.collect_and_validate_domains(registry, var_domains)?;
96                r.collect_and_validate_domains(registry, var_domains)?;
97            }
98            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
99                left.collect_and_validate_domains(registry, var_domains)?;
100                right.collect_and_validate_domains(registry, var_domains)?;
101            }
102            TLExpr::FuzzyImplication {
103                premise,
104                conclusion,
105                ..
106            } => {
107                premise.collect_and_validate_domains(registry, var_domains)?;
108                conclusion.collect_and_validate_domains(registry, var_domains)?;
109            }
110            TLExpr::Not(e)
111            | TLExpr::Score(e)
112            | TLExpr::Abs(e)
113            | TLExpr::Floor(e)
114            | TLExpr::Ceil(e)
115            | TLExpr::Round(e)
116            | TLExpr::Sqrt(e)
117            | TLExpr::Exp(e)
118            | TLExpr::Log(e)
119            | TLExpr::Sin(e)
120            | TLExpr::Cos(e)
121            | TLExpr::Tan(e)
122            | TLExpr::Box(e)
123            | TLExpr::Diamond(e)
124            | TLExpr::Next(e)
125            | TLExpr::Eventually(e)
126            | TLExpr::Always(e) => {
127                e.collect_and_validate_domains(registry, var_domains)?;
128            }
129            TLExpr::FuzzyNot { expr, .. } => {
130                expr.collect_and_validate_domains(registry, var_domains)?;
131            }
132            TLExpr::WeightedRule { rule, .. } => {
133                rule.collect_and_validate_domains(registry, var_domains)?;
134            }
135            TLExpr::Until { before, after }
136            | TLExpr::Release {
137                released: before,
138                releaser: after,
139            }
140            | TLExpr::WeakUntil { before, after }
141            | TLExpr::StrongRelease {
142                released: before,
143                releaser: after,
144            } => {
145                before.collect_and_validate_domains(registry, var_domains)?;
146                after.collect_and_validate_domains(registry, var_domains)?;
147            }
148            TLExpr::ProbabilisticChoice { alternatives } => {
149                for (_, expr) in alternatives {
150                    expr.collect_and_validate_domains(registry, var_domains)?;
151                }
152            }
153            TLExpr::IfThenElse {
154                condition,
155                then_branch,
156                else_branch,
157            } => {
158                condition.collect_and_validate_domains(registry, var_domains)?;
159                then_branch.collect_and_validate_domains(registry, var_domains)?;
160                else_branch.collect_and_validate_domains(registry, var_domains)?;
161            }
162            TLExpr::Let { value, body, .. } => {
163                value.collect_and_validate_domains(registry, var_domains)?;
164                body.collect_and_validate_domains(registry, var_domains)?;
165            }
166            TLExpr::Pred { .. } | TLExpr::Constant(_) => {
167                // No domain validation needed for predicates and constants
168            }
169        }
170        Ok(())
171    }
172
173    /// Extract all domains referenced in this expression.
174    pub fn referenced_domains(&self) -> Vec<String> {
175        let mut domains = Vec::new();
176        self.collect_domains(&mut domains);
177        domains.sort();
178        domains.dedup();
179        domains
180    }
181
182    fn collect_domains(&self, domains: &mut Vec<String>) {
183        match self {
184            TLExpr::Exists { domain, body, .. }
185            | TLExpr::ForAll { domain, body, .. }
186            | TLExpr::SoftExists { domain, body, .. }
187            | TLExpr::SoftForAll { domain, body, .. }
188            | TLExpr::Aggregate { domain, body, .. } => {
189                domains.push(domain.clone());
190                body.collect_domains(domains);
191            }
192            TLExpr::And(l, r)
193            | TLExpr::Or(l, r)
194            | TLExpr::Imply(l, r)
195            | TLExpr::Add(l, r)
196            | TLExpr::Sub(l, r)
197            | TLExpr::Mul(l, r)
198            | TLExpr::Div(l, r)
199            | TLExpr::Pow(l, r)
200            | TLExpr::Mod(l, r)
201            | TLExpr::Min(l, r)
202            | TLExpr::Max(l, r)
203            | TLExpr::Eq(l, r)
204            | TLExpr::Lt(l, r)
205            | TLExpr::Gt(l, r)
206            | TLExpr::Lte(l, r)
207            | TLExpr::Gte(l, r) => {
208                l.collect_domains(domains);
209                r.collect_domains(domains);
210            }
211            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
212                left.collect_domains(domains);
213                right.collect_domains(domains);
214            }
215            TLExpr::FuzzyImplication {
216                premise,
217                conclusion,
218                ..
219            } => {
220                premise.collect_domains(domains);
221                conclusion.collect_domains(domains);
222            }
223            TLExpr::Not(e)
224            | TLExpr::Score(e)
225            | TLExpr::Abs(e)
226            | TLExpr::Floor(e)
227            | TLExpr::Ceil(e)
228            | TLExpr::Round(e)
229            | TLExpr::Sqrt(e)
230            | TLExpr::Exp(e)
231            | TLExpr::Log(e)
232            | TLExpr::Sin(e)
233            | TLExpr::Cos(e)
234            | TLExpr::Tan(e)
235            | TLExpr::Box(e)
236            | TLExpr::Diamond(e)
237            | TLExpr::Next(e)
238            | TLExpr::Eventually(e)
239            | TLExpr::Always(e) => {
240                e.collect_domains(domains);
241            }
242            TLExpr::FuzzyNot { expr, .. } => {
243                expr.collect_domains(domains);
244            }
245            TLExpr::WeightedRule { rule, .. } => {
246                rule.collect_domains(domains);
247            }
248            TLExpr::Until { before, after }
249            | TLExpr::Release {
250                released: before,
251                releaser: after,
252            }
253            | TLExpr::WeakUntil { before, after }
254            | TLExpr::StrongRelease {
255                released: before,
256                releaser: after,
257            } => {
258                before.collect_domains(domains);
259                after.collect_domains(domains);
260            }
261            TLExpr::ProbabilisticChoice { alternatives } => {
262                for (_, expr) in alternatives {
263                    expr.collect_domains(domains);
264                }
265            }
266            TLExpr::IfThenElse {
267                condition,
268                then_branch,
269                else_branch,
270            } => {
271                condition.collect_domains(domains);
272                then_branch.collect_domains(domains);
273                else_branch.collect_domains(domains);
274            }
275            TLExpr::Let { value, body, .. } => {
276                value.collect_domains(domains);
277                body.collect_domains(domains);
278            }
279            TLExpr::Pred { .. } | TLExpr::Constant(_) => {}
280        }
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::term::Term;

    #[test]
    fn test_validate_domains_success() {
        let registry = DomainRegistry::with_builtins();

        let expr = TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")]));

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_validate_domains_not_found() {
        let registry = DomainRegistry::new();

        let expr = TLExpr::exists(
            "x",
            "UnknownDomain",
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        assert!(expr.validate_domains(&registry).is_err());
    }

    #[test]
    fn test_validate_domains_not_found_nested() {
        let registry = DomainRegistry::with_builtins();

        // The unknown domain sits in the right operand of a conjunction; the
        // traversal must still reach it.
        let expr = TLExpr::and(
            TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
            TLExpr::exists(
                "y",
                "UnknownDomain",
                TLExpr::pred("Q", vec![Term::var("y")]),
            ),
        );

        assert!(expr.validate_domains(&registry).is_err());
    }

    #[test]
    fn test_validate_domains_consistent_usage() {
        let registry = DomainRegistry::with_builtins();

        // ∃x:Int. ∀x:Int. P(x) - same variable, same domain
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
        );

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_validate_domains_sibling_same_domain() {
        let registry = DomainRegistry::with_builtins();

        // (∃x:Int. P(x)) ∧ (∀x:Int. Q(x)) - sibling binders, same domain
        let expr = TLExpr::and(
            TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
            TLExpr::forall("x", "Int", TLExpr::pred("Q", vec![Term::var("x")])),
        );

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_validate_domains_incompatible() {
        let registry = DomainRegistry::with_builtins();

        // ∃x:Int. ∀x:Bool. P(x) - same variable, incompatible domains
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("x", "Bool", TLExpr::pred("P", vec![Term::var("x")])),
        );

        // Pin the specific error so an unrelated failure (e.g. an unknown
        // domain) cannot make this test pass.
        let result = expr.validate_domains(&registry);
        assert!(matches!(
            &result,
            Err(IrError::VariableDomainMismatch { var, expected, actual })
                if var == "x" && expected == "Int" && actual == "Bool"
        ));
    }

    #[test]
    fn test_referenced_domains() {
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("y", "Real", TLExpr::pred("P", vec![Term::var("x")])),
        );

        let domains = expr.referenced_domains();
        assert_eq!(domains, vec!["Int", "Real"]);
    }

    #[test]
    fn test_referenced_domains_dedup() {
        let expr = TLExpr::and(
            TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
            TLExpr::exists("y", "Int", TLExpr::pred("Q", vec![Term::var("y")])),
        );

        let domains = expr.referenced_domains();
        assert_eq!(domains, vec!["Int"]);
    }

    #[test]
    fn test_referenced_domains_none() {
        // A bare predicate has no binders, hence no domains.
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        assert!(expr.referenced_domains().is_empty());
    }
}