1use std::collections::HashMap;
4
5use crate::domain::DomainRegistry;
6use crate::error::IrError;
7use crate::expr::TLExpr;
8
9impl TLExpr {
10 pub fn validate_domains(&self, registry: &DomainRegistry) -> Result<(), IrError> {
16 let mut var_domains = HashMap::new();
17 self.collect_and_validate_domains(registry, &mut var_domains)
18 }
19
20 fn collect_and_validate_domains(
21 &self,
22 registry: &DomainRegistry,
23 var_domains: &mut HashMap<String, String>,
24 ) -> Result<(), IrError> {
25 match self {
26 TLExpr::Exists { var, domain, body }
27 | TLExpr::ForAll { var, domain, body }
28 | TLExpr::SoftExists {
29 var, domain, body, ..
30 }
31 | TLExpr::SoftForAll {
32 var, domain, body, ..
33 } => {
34 registry.validate_domain(domain)?;
36
37 if let Some(existing_domain) = var_domains.get(var) {
39 if existing_domain != domain {
40 if !registry.are_compatible(existing_domain, domain)? {
42 return Err(IrError::VariableDomainMismatch {
43 var: var.clone(),
44 expected: existing_domain.clone(),
45 actual: domain.clone(),
46 });
47 }
48 }
49 } else {
50 var_domains.insert(var.clone(), domain.clone());
51 }
52
53 body.collect_and_validate_domains(registry, var_domains)?;
54 }
55 TLExpr::Aggregate {
56 var, domain, body, ..
57 } => {
58 registry.validate_domain(domain)?;
60
61 if let Some(existing_domain) = var_domains.get(var) {
63 if existing_domain != domain {
64 if !registry.are_compatible(existing_domain, domain)? {
66 return Err(IrError::VariableDomainMismatch {
67 var: var.clone(),
68 expected: existing_domain.clone(),
69 actual: domain.clone(),
70 });
71 }
72 }
73 } else {
74 var_domains.insert(var.clone(), domain.clone());
75 }
76
77 body.collect_and_validate_domains(registry, var_domains)?;
78 }
79 TLExpr::And(l, r)
80 | TLExpr::Or(l, r)
81 | TLExpr::Imply(l, r)
82 | TLExpr::Add(l, r)
83 | TLExpr::Sub(l, r)
84 | TLExpr::Mul(l, r)
85 | TLExpr::Div(l, r)
86 | TLExpr::Pow(l, r)
87 | TLExpr::Mod(l, r)
88 | TLExpr::Min(l, r)
89 | TLExpr::Max(l, r)
90 | TLExpr::Eq(l, r)
91 | TLExpr::Lt(l, r)
92 | TLExpr::Gt(l, r)
93 | TLExpr::Lte(l, r)
94 | TLExpr::Gte(l, r) => {
95 l.collect_and_validate_domains(registry, var_domains)?;
96 r.collect_and_validate_domains(registry, var_domains)?;
97 }
98 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
99 left.collect_and_validate_domains(registry, var_domains)?;
100 right.collect_and_validate_domains(registry, var_domains)?;
101 }
102 TLExpr::FuzzyImplication {
103 premise,
104 conclusion,
105 ..
106 } => {
107 premise.collect_and_validate_domains(registry, var_domains)?;
108 conclusion.collect_and_validate_domains(registry, var_domains)?;
109 }
110 TLExpr::Not(e)
111 | TLExpr::Score(e)
112 | TLExpr::Abs(e)
113 | TLExpr::Floor(e)
114 | TLExpr::Ceil(e)
115 | TLExpr::Round(e)
116 | TLExpr::Sqrt(e)
117 | TLExpr::Exp(e)
118 | TLExpr::Log(e)
119 | TLExpr::Sin(e)
120 | TLExpr::Cos(e)
121 | TLExpr::Tan(e)
122 | TLExpr::Box(e)
123 | TLExpr::Diamond(e)
124 | TLExpr::Next(e)
125 | TLExpr::Eventually(e)
126 | TLExpr::Always(e) => {
127 e.collect_and_validate_domains(registry, var_domains)?;
128 }
129 TLExpr::FuzzyNot { expr, .. } => {
130 expr.collect_and_validate_domains(registry, var_domains)?;
131 }
132 TLExpr::WeightedRule { rule, .. } => {
133 rule.collect_and_validate_domains(registry, var_domains)?;
134 }
135 TLExpr::Until { before, after }
136 | TLExpr::Release {
137 released: before,
138 releaser: after,
139 }
140 | TLExpr::WeakUntil { before, after }
141 | TLExpr::StrongRelease {
142 released: before,
143 releaser: after,
144 } => {
145 before.collect_and_validate_domains(registry, var_domains)?;
146 after.collect_and_validate_domains(registry, var_domains)?;
147 }
148 TLExpr::ProbabilisticChoice { alternatives } => {
149 for (_, expr) in alternatives {
150 expr.collect_and_validate_domains(registry, var_domains)?;
151 }
152 }
153 TLExpr::IfThenElse {
154 condition,
155 then_branch,
156 else_branch,
157 } => {
158 condition.collect_and_validate_domains(registry, var_domains)?;
159 then_branch.collect_and_validate_domains(registry, var_domains)?;
160 else_branch.collect_and_validate_domains(registry, var_domains)?;
161 }
162 TLExpr::Let { value, body, .. } => {
163 value.collect_and_validate_domains(registry, var_domains)?;
164 body.collect_and_validate_domains(registry, var_domains)?;
165 }
166 TLExpr::Lambda { body, .. } => {
168 body.collect_and_validate_domains(registry, var_domains)?;
170 }
171 TLExpr::Apply { function, argument } => {
172 function.collect_and_validate_domains(registry, var_domains)?;
173 argument.collect_and_validate_domains(registry, var_domains)?;
174 }
175 TLExpr::SetMembership { element, set }
176 | TLExpr::SetUnion {
177 left: element,
178 right: set,
179 }
180 | TLExpr::SetIntersection {
181 left: element,
182 right: set,
183 }
184 | TLExpr::SetDifference {
185 left: element,
186 right: set,
187 } => {
188 element.collect_and_validate_domains(registry, var_domains)?;
189 set.collect_and_validate_domains(registry, var_domains)?;
190 }
191 TLExpr::SetCardinality { set } => {
192 set.collect_and_validate_domains(registry, var_domains)?;
193 }
194 TLExpr::EmptySet => {
195 }
197 TLExpr::SetComprehension {
198 var,
199 domain,
200 condition,
201 } => {
202 registry.validate_domain(domain)?;
203 if let Some(existing_domain) = var_domains.get(var) {
204 if existing_domain != domain
205 && !registry.are_compatible(existing_domain, domain)?
206 {
207 return Err(IrError::VariableDomainMismatch {
208 var: var.clone(),
209 expected: existing_domain.clone(),
210 actual: domain.clone(),
211 });
212 }
213 } else {
214 var_domains.insert(var.clone(), domain.clone());
215 }
216 condition.collect_and_validate_domains(registry, var_domains)?;
217 }
218 TLExpr::CountingExists {
219 var, domain, body, ..
220 }
221 | TLExpr::CountingForAll {
222 var, domain, body, ..
223 }
224 | TLExpr::ExactCount {
225 var, domain, body, ..
226 }
227 | TLExpr::Majority { var, domain, body } => {
228 registry.validate_domain(domain)?;
229 if let Some(existing_domain) = var_domains.get(var) {
230 if existing_domain != domain
231 && !registry.are_compatible(existing_domain, domain)?
232 {
233 return Err(IrError::VariableDomainMismatch {
234 var: var.clone(),
235 expected: existing_domain.clone(),
236 actual: domain.clone(),
237 });
238 }
239 } else {
240 var_domains.insert(var.clone(), domain.clone());
241 }
242 body.collect_and_validate_domains(registry, var_domains)?;
243 }
244 TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
245 body.collect_and_validate_domains(registry, var_domains)?;
246 }
247 TLExpr::Nominal { .. } => {
248 }
250 TLExpr::At { formula, .. } => {
251 formula.collect_and_validate_domains(registry, var_domains)?;
252 }
253 TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
254 formula.collect_and_validate_domains(registry, var_domains)?;
255 }
256 TLExpr::AllDifferent { .. } => {
257 }
259 TLExpr::GlobalCardinality { values, .. } => {
260 for val in values {
261 val.collect_and_validate_domains(registry, var_domains)?;
262 }
263 }
264 TLExpr::Abducible { .. } => {
265 }
267 TLExpr::Explain { formula } => {
268 formula.collect_and_validate_domains(registry, var_domains)?;
269 }
270 TLExpr::SymbolLiteral(_) => {
271 }
273 TLExpr::Match { scrutinee, arms } => {
274 scrutinee.collect_and_validate_domains(registry, var_domains)?;
275 for (_, body) in arms {
276 body.collect_and_validate_domains(registry, var_domains)?;
277 }
278 }
279 TLExpr::Pred { .. } | TLExpr::Constant(_) => {
280 }
282 }
283 Ok(())
284 }
285
286 pub fn referenced_domains(&self) -> Vec<String> {
288 let mut domains = Vec::new();
289 self.collect_domains(&mut domains);
290 domains.sort();
291 domains.dedup();
292 domains
293 }
294
295 fn collect_domains(&self, domains: &mut Vec<String>) {
296 match self {
297 TLExpr::Exists { domain, body, .. }
298 | TLExpr::ForAll { domain, body, .. }
299 | TLExpr::SoftExists { domain, body, .. }
300 | TLExpr::SoftForAll { domain, body, .. }
301 | TLExpr::Aggregate { domain, body, .. } => {
302 domains.push(domain.clone());
303 body.collect_domains(domains);
304 }
305 TLExpr::And(l, r)
306 | TLExpr::Or(l, r)
307 | TLExpr::Imply(l, r)
308 | TLExpr::Add(l, r)
309 | TLExpr::Sub(l, r)
310 | TLExpr::Mul(l, r)
311 | TLExpr::Div(l, r)
312 | TLExpr::Pow(l, r)
313 | TLExpr::Mod(l, r)
314 | TLExpr::Min(l, r)
315 | TLExpr::Max(l, r)
316 | TLExpr::Eq(l, r)
317 | TLExpr::Lt(l, r)
318 | TLExpr::Gt(l, r)
319 | TLExpr::Lte(l, r)
320 | TLExpr::Gte(l, r) => {
321 l.collect_domains(domains);
322 r.collect_domains(domains);
323 }
324 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
325 left.collect_domains(domains);
326 right.collect_domains(domains);
327 }
328 TLExpr::FuzzyImplication {
329 premise,
330 conclusion,
331 ..
332 } => {
333 premise.collect_domains(domains);
334 conclusion.collect_domains(domains);
335 }
336 TLExpr::Not(e)
337 | TLExpr::Score(e)
338 | TLExpr::Abs(e)
339 | TLExpr::Floor(e)
340 | TLExpr::Ceil(e)
341 | TLExpr::Round(e)
342 | TLExpr::Sqrt(e)
343 | TLExpr::Exp(e)
344 | TLExpr::Log(e)
345 | TLExpr::Sin(e)
346 | TLExpr::Cos(e)
347 | TLExpr::Tan(e)
348 | TLExpr::Box(e)
349 | TLExpr::Diamond(e)
350 | TLExpr::Next(e)
351 | TLExpr::Eventually(e)
352 | TLExpr::Always(e) => {
353 e.collect_domains(domains);
354 }
355 TLExpr::FuzzyNot { expr, .. } => {
356 expr.collect_domains(domains);
357 }
358 TLExpr::WeightedRule { rule, .. } => {
359 rule.collect_domains(domains);
360 }
361 TLExpr::Until { before, after }
362 | TLExpr::Release {
363 released: before,
364 releaser: after,
365 }
366 | TLExpr::WeakUntil { before, after }
367 | TLExpr::StrongRelease {
368 released: before,
369 releaser: after,
370 } => {
371 before.collect_domains(domains);
372 after.collect_domains(domains);
373 }
374 TLExpr::ProbabilisticChoice { alternatives } => {
375 for (_, expr) in alternatives {
376 expr.collect_domains(domains);
377 }
378 }
379 TLExpr::IfThenElse {
380 condition,
381 then_branch,
382 else_branch,
383 } => {
384 condition.collect_domains(domains);
385 then_branch.collect_domains(domains);
386 else_branch.collect_domains(domains);
387 }
388 TLExpr::Let { value, body, .. } => {
389 value.collect_domains(domains);
390 body.collect_domains(domains);
391 }
392 TLExpr::Lambda { body, .. } => {
394 body.collect_domains(domains);
395 }
396 TLExpr::Apply { function, argument } => {
397 function.collect_domains(domains);
398 argument.collect_domains(domains);
399 }
400 TLExpr::SetMembership { element, set }
401 | TLExpr::SetUnion {
402 left: element,
403 right: set,
404 }
405 | TLExpr::SetIntersection {
406 left: element,
407 right: set,
408 }
409 | TLExpr::SetDifference {
410 left: element,
411 right: set,
412 } => {
413 element.collect_domains(domains);
414 set.collect_domains(domains);
415 }
416 TLExpr::SetCardinality { set } => {
417 set.collect_domains(domains);
418 }
419 TLExpr::EmptySet => {}
420 TLExpr::SetComprehension {
421 domain, condition, ..
422 } => {
423 domains.push(domain.clone());
424 condition.collect_domains(domains);
425 }
426 TLExpr::CountingExists { domain, body, .. }
427 | TLExpr::CountingForAll { domain, body, .. }
428 | TLExpr::ExactCount { domain, body, .. }
429 | TLExpr::Majority { domain, body, .. } => {
430 domains.push(domain.clone());
431 body.collect_domains(domains);
432 }
433 TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
434 body.collect_domains(domains);
435 }
436 TLExpr::Nominal { .. } => {}
437 TLExpr::At { formula, .. } => {
438 formula.collect_domains(domains);
439 }
440 TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
441 formula.collect_domains(domains);
442 }
443 TLExpr::AllDifferent { .. } => {}
444 TLExpr::GlobalCardinality { values, .. } => {
445 for val in values {
446 val.collect_domains(domains);
447 }
448 }
449 TLExpr::Abducible { .. } => {}
450 TLExpr::Explain { formula } => {
451 formula.collect_domains(domains);
452 }
453 TLExpr::SymbolLiteral(_) => {}
454 TLExpr::Match { scrutinee, arms } => {
455 scrutinee.collect_domains(domains);
456 for (_, body) in arms {
457 body.collect_domains(domains);
458 }
459 }
460 TLExpr::Pred { .. } | TLExpr::Constant(_) => {}
461 }
462 }
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::term::Term;

    #[test]
    fn test_validate_domains_success() {
        let registry = DomainRegistry::with_builtins();

        let expr = TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")]));

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_validate_domains_not_found() {
        let registry = DomainRegistry::new();

        let expr = TLExpr::exists(
            "x",
            "UnknownDomain",
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        assert!(expr.validate_domains(&registry).is_err());
    }

    #[test]
    fn test_validate_domains_nested_not_found() {
        // The unknown domain sits under a binary node, so validation must recurse.
        let registry = DomainRegistry::new();

        let expr = TLExpr::and(
            TLExpr::pred("Q", vec![Term::var("y")]),
            TLExpr::exists(
                "x",
                "UnknownDomain",
                TLExpr::pred("P", vec![Term::var("x")]),
            ),
        );

        assert!(expr.validate_domains(&registry).is_err());
    }

    #[test]
    fn test_validate_domains_consistent_usage() {
        let registry = DomainRegistry::with_builtins();

        // `x` is re-bound with the same domain, which is always accepted.
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
        );

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_validate_domains_incompatible() {
        let registry = DomainRegistry::with_builtins();

        // `x` is re-bound from Int to Bool, which the builtin registry rejects.
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("x", "Bool", TLExpr::pred("P", vec![Term::var("x")])),
        );

        assert!(expr.validate_domains(&registry).is_err());
    }

    #[test]
    fn test_referenced_domains() {
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("y", "Real", TLExpr::pred("P", vec![Term::var("x")])),
        );

        let domains = expr.referenced_domains();
        assert_eq!(domains, vec!["Int", "Real"]);
    }

    #[test]
    fn test_referenced_domains_dedup() {
        let expr = TLExpr::and(
            TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
            TLExpr::exists("y", "Int", TLExpr::pred("Q", vec![Term::var("y")])),
        );

        let domains = expr.referenced_domains();
        assert_eq!(domains, vec!["Int"]);
    }

    #[test]
    fn test_referenced_domains_none() {
        // A bare predicate has no binders and therefore no domains.
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        assert!(expr.referenced_domains().is_empty());
    }
}