1use std::collections::HashMap;
4
5use crate::domain::DomainRegistry;
6use crate::error::IrError;
7use crate::expr::TLExpr;
8
9impl TLExpr {
10 pub fn validate_domains(&self, registry: &DomainRegistry) -> Result<(), IrError> {
16 let mut var_domains = HashMap::new();
17 self.collect_and_validate_domains(registry, &mut var_domains)
18 }
19
20 fn collect_and_validate_domains(
21 &self,
22 registry: &DomainRegistry,
23 var_domains: &mut HashMap<String, String>,
24 ) -> Result<(), IrError> {
25 match self {
26 TLExpr::Exists { var, domain, body }
27 | TLExpr::ForAll { var, domain, body }
28 | TLExpr::SoftExists {
29 var, domain, body, ..
30 }
31 | TLExpr::SoftForAll {
32 var, domain, body, ..
33 } => {
34 registry.validate_domain(domain)?;
36
37 if let Some(existing_domain) = var_domains.get(var) {
39 if existing_domain != domain {
40 if !registry.are_compatible(existing_domain, domain)? {
42 return Err(IrError::VariableDomainMismatch {
43 var: var.clone(),
44 expected: existing_domain.clone(),
45 actual: domain.clone(),
46 });
47 }
48 }
49 } else {
50 var_domains.insert(var.clone(), domain.clone());
51 }
52
53 body.collect_and_validate_domains(registry, var_domains)?;
54 }
55 TLExpr::Aggregate {
56 var, domain, body, ..
57 } => {
58 registry.validate_domain(domain)?;
60
61 if let Some(existing_domain) = var_domains.get(var) {
63 if existing_domain != domain {
64 if !registry.are_compatible(existing_domain, domain)? {
66 return Err(IrError::VariableDomainMismatch {
67 var: var.clone(),
68 expected: existing_domain.clone(),
69 actual: domain.clone(),
70 });
71 }
72 }
73 } else {
74 var_domains.insert(var.clone(), domain.clone());
75 }
76
77 body.collect_and_validate_domains(registry, var_domains)?;
78 }
79 TLExpr::And(l, r)
80 | TLExpr::Or(l, r)
81 | TLExpr::Imply(l, r)
82 | TLExpr::Add(l, r)
83 | TLExpr::Sub(l, r)
84 | TLExpr::Mul(l, r)
85 | TLExpr::Div(l, r)
86 | TLExpr::Pow(l, r)
87 | TLExpr::Mod(l, r)
88 | TLExpr::Min(l, r)
89 | TLExpr::Max(l, r)
90 | TLExpr::Eq(l, r)
91 | TLExpr::Lt(l, r)
92 | TLExpr::Gt(l, r)
93 | TLExpr::Lte(l, r)
94 | TLExpr::Gte(l, r) => {
95 l.collect_and_validate_domains(registry, var_domains)?;
96 r.collect_and_validate_domains(registry, var_domains)?;
97 }
98 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
99 left.collect_and_validate_domains(registry, var_domains)?;
100 right.collect_and_validate_domains(registry, var_domains)?;
101 }
102 TLExpr::FuzzyImplication {
103 premise,
104 conclusion,
105 ..
106 } => {
107 premise.collect_and_validate_domains(registry, var_domains)?;
108 conclusion.collect_and_validate_domains(registry, var_domains)?;
109 }
110 TLExpr::Not(e)
111 | TLExpr::Score(e)
112 | TLExpr::Abs(e)
113 | TLExpr::Floor(e)
114 | TLExpr::Ceil(e)
115 | TLExpr::Round(e)
116 | TLExpr::Sqrt(e)
117 | TLExpr::Exp(e)
118 | TLExpr::Log(e)
119 | TLExpr::Sin(e)
120 | TLExpr::Cos(e)
121 | TLExpr::Tan(e)
122 | TLExpr::Box(e)
123 | TLExpr::Diamond(e)
124 | TLExpr::Next(e)
125 | TLExpr::Eventually(e)
126 | TLExpr::Always(e) => {
127 e.collect_and_validate_domains(registry, var_domains)?;
128 }
129 TLExpr::FuzzyNot { expr, .. } => {
130 expr.collect_and_validate_domains(registry, var_domains)?;
131 }
132 TLExpr::WeightedRule { rule, .. } => {
133 rule.collect_and_validate_domains(registry, var_domains)?;
134 }
135 TLExpr::Until { before, after }
136 | TLExpr::Release {
137 released: before,
138 releaser: after,
139 }
140 | TLExpr::WeakUntil { before, after }
141 | TLExpr::StrongRelease {
142 released: before,
143 releaser: after,
144 } => {
145 before.collect_and_validate_domains(registry, var_domains)?;
146 after.collect_and_validate_domains(registry, var_domains)?;
147 }
148 TLExpr::ProbabilisticChoice { alternatives } => {
149 for (_, expr) in alternatives {
150 expr.collect_and_validate_domains(registry, var_domains)?;
151 }
152 }
153 TLExpr::IfThenElse {
154 condition,
155 then_branch,
156 else_branch,
157 } => {
158 condition.collect_and_validate_domains(registry, var_domains)?;
159 then_branch.collect_and_validate_domains(registry, var_domains)?;
160 else_branch.collect_and_validate_domains(registry, var_domains)?;
161 }
162 TLExpr::Let { value, body, .. } => {
163 value.collect_and_validate_domains(registry, var_domains)?;
164 body.collect_and_validate_domains(registry, var_domains)?;
165 }
166 TLExpr::Lambda { body, .. } => {
168 body.collect_and_validate_domains(registry, var_domains)?;
170 }
171 TLExpr::Apply { function, argument } => {
172 function.collect_and_validate_domains(registry, var_domains)?;
173 argument.collect_and_validate_domains(registry, var_domains)?;
174 }
175 TLExpr::SetMembership { element, set }
176 | TLExpr::SetUnion {
177 left: element,
178 right: set,
179 }
180 | TLExpr::SetIntersection {
181 left: element,
182 right: set,
183 }
184 | TLExpr::SetDifference {
185 left: element,
186 right: set,
187 } => {
188 element.collect_and_validate_domains(registry, var_domains)?;
189 set.collect_and_validate_domains(registry, var_domains)?;
190 }
191 TLExpr::SetCardinality { set } => {
192 set.collect_and_validate_domains(registry, var_domains)?;
193 }
194 TLExpr::EmptySet => {
195 }
197 TLExpr::SetComprehension {
198 var,
199 domain,
200 condition,
201 } => {
202 registry.validate_domain(domain)?;
203 if let Some(existing_domain) = var_domains.get(var) {
204 if existing_domain != domain
205 && !registry.are_compatible(existing_domain, domain)?
206 {
207 return Err(IrError::VariableDomainMismatch {
208 var: var.clone(),
209 expected: existing_domain.clone(),
210 actual: domain.clone(),
211 });
212 }
213 } else {
214 var_domains.insert(var.clone(), domain.clone());
215 }
216 condition.collect_and_validate_domains(registry, var_domains)?;
217 }
218 TLExpr::CountingExists {
219 var, domain, body, ..
220 }
221 | TLExpr::CountingForAll {
222 var, domain, body, ..
223 }
224 | TLExpr::ExactCount {
225 var, domain, body, ..
226 }
227 | TLExpr::Majority { var, domain, body } => {
228 registry.validate_domain(domain)?;
229 if let Some(existing_domain) = var_domains.get(var) {
230 if existing_domain != domain
231 && !registry.are_compatible(existing_domain, domain)?
232 {
233 return Err(IrError::VariableDomainMismatch {
234 var: var.clone(),
235 expected: existing_domain.clone(),
236 actual: domain.clone(),
237 });
238 }
239 } else {
240 var_domains.insert(var.clone(), domain.clone());
241 }
242 body.collect_and_validate_domains(registry, var_domains)?;
243 }
244 TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
245 body.collect_and_validate_domains(registry, var_domains)?;
246 }
247 TLExpr::Nominal { .. } => {
248 }
250 TLExpr::At { formula, .. } => {
251 formula.collect_and_validate_domains(registry, var_domains)?;
252 }
253 TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
254 formula.collect_and_validate_domains(registry, var_domains)?;
255 }
256 TLExpr::AllDifferent { .. } => {
257 }
259 TLExpr::GlobalCardinality { values, .. } => {
260 for val in values {
261 val.collect_and_validate_domains(registry, var_domains)?;
262 }
263 }
264 TLExpr::Abducible { .. } => {
265 }
267 TLExpr::Explain { formula } => {
268 formula.collect_and_validate_domains(registry, var_domains)?;
269 }
270 TLExpr::Pred { .. } | TLExpr::Constant(_) => {
271 }
273 }
274 Ok(())
275 }
276
277 pub fn referenced_domains(&self) -> Vec<String> {
279 let mut domains = Vec::new();
280 self.collect_domains(&mut domains);
281 domains.sort();
282 domains.dedup();
283 domains
284 }
285
286 fn collect_domains(&self, domains: &mut Vec<String>) {
287 match self {
288 TLExpr::Exists { domain, body, .. }
289 | TLExpr::ForAll { domain, body, .. }
290 | TLExpr::SoftExists { domain, body, .. }
291 | TLExpr::SoftForAll { domain, body, .. }
292 | TLExpr::Aggregate { domain, body, .. } => {
293 domains.push(domain.clone());
294 body.collect_domains(domains);
295 }
296 TLExpr::And(l, r)
297 | TLExpr::Or(l, r)
298 | TLExpr::Imply(l, r)
299 | TLExpr::Add(l, r)
300 | TLExpr::Sub(l, r)
301 | TLExpr::Mul(l, r)
302 | TLExpr::Div(l, r)
303 | TLExpr::Pow(l, r)
304 | TLExpr::Mod(l, r)
305 | TLExpr::Min(l, r)
306 | TLExpr::Max(l, r)
307 | TLExpr::Eq(l, r)
308 | TLExpr::Lt(l, r)
309 | TLExpr::Gt(l, r)
310 | TLExpr::Lte(l, r)
311 | TLExpr::Gte(l, r) => {
312 l.collect_domains(domains);
313 r.collect_domains(domains);
314 }
315 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
316 left.collect_domains(domains);
317 right.collect_domains(domains);
318 }
319 TLExpr::FuzzyImplication {
320 premise,
321 conclusion,
322 ..
323 } => {
324 premise.collect_domains(domains);
325 conclusion.collect_domains(domains);
326 }
327 TLExpr::Not(e)
328 | TLExpr::Score(e)
329 | TLExpr::Abs(e)
330 | TLExpr::Floor(e)
331 | TLExpr::Ceil(e)
332 | TLExpr::Round(e)
333 | TLExpr::Sqrt(e)
334 | TLExpr::Exp(e)
335 | TLExpr::Log(e)
336 | TLExpr::Sin(e)
337 | TLExpr::Cos(e)
338 | TLExpr::Tan(e)
339 | TLExpr::Box(e)
340 | TLExpr::Diamond(e)
341 | TLExpr::Next(e)
342 | TLExpr::Eventually(e)
343 | TLExpr::Always(e) => {
344 e.collect_domains(domains);
345 }
346 TLExpr::FuzzyNot { expr, .. } => {
347 expr.collect_domains(domains);
348 }
349 TLExpr::WeightedRule { rule, .. } => {
350 rule.collect_domains(domains);
351 }
352 TLExpr::Until { before, after }
353 | TLExpr::Release {
354 released: before,
355 releaser: after,
356 }
357 | TLExpr::WeakUntil { before, after }
358 | TLExpr::StrongRelease {
359 released: before,
360 releaser: after,
361 } => {
362 before.collect_domains(domains);
363 after.collect_domains(domains);
364 }
365 TLExpr::ProbabilisticChoice { alternatives } => {
366 for (_, expr) in alternatives {
367 expr.collect_domains(domains);
368 }
369 }
370 TLExpr::IfThenElse {
371 condition,
372 then_branch,
373 else_branch,
374 } => {
375 condition.collect_domains(domains);
376 then_branch.collect_domains(domains);
377 else_branch.collect_domains(domains);
378 }
379 TLExpr::Let { value, body, .. } => {
380 value.collect_domains(domains);
381 body.collect_domains(domains);
382 }
383 TLExpr::Lambda { body, .. } => {
385 body.collect_domains(domains);
386 }
387 TLExpr::Apply { function, argument } => {
388 function.collect_domains(domains);
389 argument.collect_domains(domains);
390 }
391 TLExpr::SetMembership { element, set }
392 | TLExpr::SetUnion {
393 left: element,
394 right: set,
395 }
396 | TLExpr::SetIntersection {
397 left: element,
398 right: set,
399 }
400 | TLExpr::SetDifference {
401 left: element,
402 right: set,
403 } => {
404 element.collect_domains(domains);
405 set.collect_domains(domains);
406 }
407 TLExpr::SetCardinality { set } => {
408 set.collect_domains(domains);
409 }
410 TLExpr::EmptySet => {}
411 TLExpr::SetComprehension {
412 domain, condition, ..
413 } => {
414 domains.push(domain.clone());
415 condition.collect_domains(domains);
416 }
417 TLExpr::CountingExists { domain, body, .. }
418 | TLExpr::CountingForAll { domain, body, .. }
419 | TLExpr::ExactCount { domain, body, .. }
420 | TLExpr::Majority { domain, body, .. } => {
421 domains.push(domain.clone());
422 body.collect_domains(domains);
423 }
424 TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
425 body.collect_domains(domains);
426 }
427 TLExpr::Nominal { .. } => {}
428 TLExpr::At { formula, .. } => {
429 formula.collect_domains(domains);
430 }
431 TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
432 formula.collect_domains(domains);
433 }
434 TLExpr::AllDifferent { .. } => {}
435 TLExpr::GlobalCardinality { values, .. } => {
436 for val in values {
437 val.collect_domains(domains);
438 }
439 }
440 TLExpr::Abducible { .. } => {}
441 TLExpr::Explain { formula } => {
442 formula.collect_domains(domains);
443 }
444 TLExpr::Pred { .. } | TLExpr::Constant(_) => {}
445 }
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use crate::term::Term;

    #[test]
    fn test_validate_domains_success() {
        let registry = DomainRegistry::with_builtins();

        let expr = TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")]));

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_validate_domains_not_found() {
        // An empty registry knows no domains, so any binder must be rejected.
        let registry = DomainRegistry::new();

        let expr = TLExpr::exists(
            "x",
            "UnknownDomain",
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        assert!(expr.validate_domains(&registry).is_err());
    }

    #[test]
    fn test_validate_domains_consistent_usage() {
        let registry = DomainRegistry::with_builtins();

        // `x` is re-bound by a nested quantifier over the same domain.
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
        );

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_validate_domains_incompatible() {
        let registry = DomainRegistry::with_builtins();

        // `x` is re-bound by a nested quantifier over a different domain; the
        // builtin registry is expected not to consider Int and Bool compatible.
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("x", "Bool", TLExpr::pred("P", vec![Term::var("x")])),
        );

        assert!(expr.validate_domains(&registry).is_err());
    }

    #[test]
    fn test_validate_domains_without_binders() {
        // A bare predicate binds no variables, so even an empty registry accepts it.
        let registry = DomainRegistry::new();

        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        assert!(expr.validate_domains(&registry).is_ok());
    }

    #[test]
    fn test_referenced_domains() {
        let expr = TLExpr::exists(
            "x",
            "Int",
            TLExpr::forall("y", "Real", TLExpr::pred("P", vec![Term::var("x")])),
        );

        let domains = expr.referenced_domains();
        assert_eq!(domains, vec!["Int", "Real"]);
    }

    #[test]
    fn test_referenced_domains_dedup() {
        let expr = TLExpr::and(
            TLExpr::exists("x", "Int", TLExpr::pred("P", vec![Term::var("x")])),
            TLExpr::exists("y", "Int", TLExpr::pred("Q", vec![Term::var("y")])),
        );

        let domains = expr.referenced_domains();
        assert_eq!(domains, vec!["Int"]);
    }

    #[test]
    fn test_referenced_domains_without_binders() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        assert!(expr.referenced_domains().is_empty());
    }
}