1use std::collections::HashMap;
4
5use crate::domain::DomainRegistry;
6use crate::error::IrError;
7use crate::expr::TLExpr;
8
9impl TLExpr {
10 pub fn validate_domains(&self, registry: &DomainRegistry) -> Result<(), IrError> {
16 let mut var_domains = HashMap::new();
17 self.collect_and_validate_domains(registry, &mut var_domains)
18 }
19
20 fn collect_and_validate_domains(
21 &self,
22 registry: &DomainRegistry,
23 var_domains: &mut HashMap<String, String>,
24 ) -> Result<(), IrError> {
25 match self {
26 TLExpr::Exists { var, domain, body }
27 | TLExpr::ForAll { var, domain, body }
28 | TLExpr::SoftExists {
29 var, domain, body, ..
30 }
31 | TLExpr::SoftForAll {
32 var, domain, body, ..
33 } => {
34 registry.validate_domain(domain)?;
36
37 if let Some(existing_domain) = var_domains.get(var) {
39 if existing_domain != domain {
40 if !registry.are_compatible(existing_domain, domain)? {
42 return Err(IrError::VariableDomainMismatch {
43 var: var.clone(),
44 expected: existing_domain.clone(),
45 actual: domain.clone(),
46 });
47 }
48 }
49 } else {
50 var_domains.insert(var.clone(), domain.clone());
51 }
52
53 body.collect_and_validate_domains(registry, var_domains)?;
54 }
55 TLExpr::Aggregate {
56 var, domain, body, ..
57 } => {
58 registry.validate_domain(domain)?;
60
61 if let Some(existing_domain) = var_domains.get(var) {
63 if existing_domain != domain {
64 if !registry.are_compatible(existing_domain, domain)? {
66 return Err(IrError::VariableDomainMismatch {
67 var: var.clone(),
68 expected: existing_domain.clone(),
69 actual: domain.clone(),
70 });
71 }
72 }
73 } else {
74 var_domains.insert(var.clone(), domain.clone());
75 }
76
77 body.collect_and_validate_domains(registry, var_domains)?;
78 }
79 TLExpr::And(l, r)
80 | TLExpr::Or(l, r)
81 | TLExpr::Imply(l, r)
82 | TLExpr::Add(l, r)
83 | TLExpr::Sub(l, r)
84 | TLExpr::Mul(l, r)
85 | TLExpr::Div(l, r)
86 | TLExpr::Pow(l, r)
87 | TLExpr::Mod(l, r)
88 | TLExpr::Min(l, r)
89 | TLExpr::Max(l, r)
90 | TLExpr::Eq(l, r)
91 | TLExpr::Lt(l, r)
92 | TLExpr::Gt(l, r)
93 | TLExpr::Lte(l, r)
94 | TLExpr::Gte(l, r) => {
95 l.collect_and_validate_domains(registry, var_domains)?;
96 r.collect_and_validate_domains(registry, var_domains)?;
97 }
98 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
99 left.collect_and_validate_domains(registry, var_domains)?;
100 right.collect_and_validate_domains(registry, var_domains)?;
101 }
102 TLExpr::FuzzyImplication {
103 premise,
104 conclusion,
105 ..
106 } => {
107 premise.collect_and_validate_domains(registry, var_domains)?;
108 conclusion.collect_and_validate_domains(registry, var_domains)?;
109 }
110 TLExpr::Not(e)
111 | TLExpr::Score(e)
112 | TLExpr::Abs(e)
113 | TLExpr::Floor(e)
114 | TLExpr::Ceil(e)
115 | TLExpr::Round(e)
116 | TLExpr::Sqrt(e)
117 | TLExpr::Exp(e)
118 | TLExpr::Log(e)
119 | TLExpr::Sin(e)
120 | TLExpr::Cos(e)
121 | TLExpr::Tan(e)
122 | TLExpr::Box(e)
123 | TLExpr::Diamond(e)
124 | TLExpr::Next(e)
125 | TLExpr::Eventually(e)
126 | TLExpr::Always(e) => {
127 e.collect_and_validate_domains(registry, var_domains)?;
128 }
129 TLExpr::FuzzyNot { expr, .. } => {
130 expr.collect_and_validate_domains(registry, var_domains)?;
131 }
132 TLExpr::WeightedRule { rule, .. } => {
133 rule.collect_and_validate_domains(registry, var_domains)?;
134 }
135 TLExpr::Until { before, after }
136 | TLExpr::Release {
137 released: before,
138 releaser: after,
139 }
140 | TLExpr::WeakUntil { before, after }
141 | TLExpr::StrongRelease {
142 released: before,
143 releaser: after,
144 } => {
145 before.collect_and_validate_domains(registry, var_domains)?;
146 after.collect_and_validate_domains(registry, var_domains)?;
147 }
148 TLExpr::ProbabilisticChoice { alternatives } => {
149 for (_, expr) in alternatives {
150 expr.collect_and_validate_domains(registry, var_domains)?;
151 }
152 }
153 TLExpr::IfThenElse {
154 condition,
155 then_branch,
156 else_branch,
157 } => {
158 condition.collect_and_validate_domains(registry, var_domains)?;
159 then_branch.collect_and_validate_domains(registry, var_domains)?;
160 else_branch.collect_and_validate_domains(registry, var_domains)?;
161 }
162 TLExpr::Let { value, body, .. } => {
163 value.collect_and_validate_domains(registry, var_domains)?;
164 body.collect_and_validate_domains(registry, var_domains)?;
165 }
166 TLExpr::Pred { .. } | TLExpr::Constant(_) => {
167 }
169 }
170 Ok(())
171 }
172
173 pub fn referenced_domains(&self) -> Vec<String> {
175 let mut domains = Vec::new();
176 self.collect_domains(&mut domains);
177 domains.sort();
178 domains.dedup();
179 domains
180 }
181
182 fn collect_domains(&self, domains: &mut Vec<String>) {
183 match self {
184 TLExpr::Exists { domain, body, .. }
185 | TLExpr::ForAll { domain, body, .. }
186 | TLExpr::SoftExists { domain, body, .. }
187 | TLExpr::SoftForAll { domain, body, .. }
188 | TLExpr::Aggregate { domain, body, .. } => {
189 domains.push(domain.clone());
190 body.collect_domains(domains);
191 }
192 TLExpr::And(l, r)
193 | TLExpr::Or(l, r)
194 | TLExpr::Imply(l, r)
195 | TLExpr::Add(l, r)
196 | TLExpr::Sub(l, r)
197 | TLExpr::Mul(l, r)
198 | TLExpr::Div(l, r)
199 | TLExpr::Pow(l, r)
200 | TLExpr::Mod(l, r)
201 | TLExpr::Min(l, r)
202 | TLExpr::Max(l, r)
203 | TLExpr::Eq(l, r)
204 | TLExpr::Lt(l, r)
205 | TLExpr::Gt(l, r)
206 | TLExpr::Lte(l, r)
207 | TLExpr::Gte(l, r) => {
208 l.collect_domains(domains);
209 r.collect_domains(domains);
210 }
211 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
212 left.collect_domains(domains);
213 right.collect_domains(domains);
214 }
215 TLExpr::FuzzyImplication {
216 premise,
217 conclusion,
218 ..
219 } => {
220 premise.collect_domains(domains);
221 conclusion.collect_domains(domains);
222 }
223 TLExpr::Not(e)
224 | TLExpr::Score(e)
225 | TLExpr::Abs(e)
226 | TLExpr::Floor(e)
227 | TLExpr::Ceil(e)
228 | TLExpr::Round(e)
229 | TLExpr::Sqrt(e)
230 | TLExpr::Exp(e)
231 | TLExpr::Log(e)
232 | TLExpr::Sin(e)
233 | TLExpr::Cos(e)
234 | TLExpr::Tan(e)
235 | TLExpr::Box(e)
236 | TLExpr::Diamond(e)
237 | TLExpr::Next(e)
238 | TLExpr::Eventually(e)
239 | TLExpr::Always(e) => {
240 e.collect_domains(domains);
241 }
242 TLExpr::FuzzyNot { expr, .. } => {
243 expr.collect_domains(domains);
244 }
245 TLExpr::WeightedRule { rule, .. } => {
246 rule.collect_domains(domains);
247 }
248 TLExpr::Until { before, after }
249 | TLExpr::Release {
250 released: before,
251 releaser: after,
252 }
253 | TLExpr::WeakUntil { before, after }
254 | TLExpr::StrongRelease {
255 released: before,
256 releaser: after,
257 } => {
258 before.collect_domains(domains);
259 after.collect_domains(domains);
260 }
261 TLExpr::ProbabilisticChoice { alternatives } => {
262 for (_, expr) in alternatives {
263 expr.collect_domains(domains);
264 }
265 }
266 TLExpr::IfThenElse {
267 condition,
268 then_branch,
269 else_branch,
270 } => {
271 condition.collect_domains(domains);
272 then_branch.collect_domains(domains);
273 else_branch.collect_domains(domains);
274 }
275 TLExpr::Let { value, body, .. } => {
276 value.collect_domains(domains);
277 body.collect_domains(domains);
278 }
279 TLExpr::Pred { .. } | TLExpr::Constant(_) => {}
280 }
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::term::Term;

    #[test]
    fn test_validate_domains_success() {
        let registry = DomainRegistry::with_builtins();

        let expr = TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")]));

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_validate_domains_not_found() {
        let registry = DomainRegistry::new();

        let expr = TLExpr::exists(
            "x",
            "UnknownDomain",
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        assert!(expr.validate_domains(&registry).is_err());
    }

    #[test]
    fn test_validate_domains_consistent_usage() {
        let registry = DomainRegistry::with_builtins();

        // Re-binding `x` with the same domain is accepted.
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
        );

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_validate_domains_incompatible() {
        let registry = DomainRegistry::with_builtins();

        // Re-binding `x` with an incompatible domain is rejected.
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("x", "Bool", TLExpr::pred("P", vec![Term::var("x")])),
        );

        assert!(expr.validate_domains(&registry).is_err());
    }

    #[test]
    fn test_validate_domains_sibling_binders_share_name() {
        let registry = DomainRegistry::with_builtins();

        // Pins current behaviour: variable domains are tracked by name across the
        // whole expression, not per scope, so sibling binders reusing `x` must agree.
        // Change this test deliberately if per-scope tracking is ever introduced.
        let expr = TLExpr::and(
            TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
            TLExpr::exists("x", "Bool", TLExpr::pred("Q", vec![Term::var("x")])),
        );

        assert!(expr.validate_domains(&registry).is_err());
    }

    #[test]
    fn test_validate_domains_no_binders() {
        // Without binders no domain is referenced, so even an empty registry accepts
        // the expression.
        let registry = DomainRegistry::new();

        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        assert!(expr.validate_domains(&registry).is_ok());
        assert!(expr.referenced_domains().is_empty());
    }

    #[test]
    fn test_referenced_domains() {
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("y", "Real", TLExpr::pred("P", vec![Term::var("x")])),
        );

        let domains = expr.referenced_domains();
        assert_eq!(domains, vec!["Int", "Real"]);
    }

    #[test]
    fn test_referenced_domains_dedup() {
        let expr = TLExpr::and(
            TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
            TLExpr::exists("y", "Int", TLExpr::pred("Q", vec![Term::var("y")])),
        );

        let domains = expr.referenced_domains();
        assert_eq!(domains, vec!["Int"]);
    }
}