Skip to main content

tensorlogic_ir/expr/
validation.rs

1//! Expression validation (arity checking).
2
3use std::collections::HashMap;
4
5use super::TLExpr;
6
7impl TLExpr {
8    /// Validate that all predicates with the same name have consistent arity
9    pub fn validate_arity(&self) -> Result<(), String> {
10        self.validate_arity_recursive(&HashMap::new())
11    }
12
13    fn validate_arity_recursive(&self, seen: &HashMap<String, usize>) -> Result<(), String> {
14        match self {
15            TLExpr::Pred { name, args } => {
16                if let Some(&expected_arity) = seen.get(name) {
17                    if expected_arity != args.len() {
18                        return Err(format!(
19                            "Predicate '{}' has inconsistent arity: expected {}, found {}",
20                            name,
21                            expected_arity,
22                            args.len()
23                        ));
24                    }
25                }
26                let mut new_seen = seen.clone();
27                new_seen.insert(name.clone(), args.len());
28                Ok(())
29            }
30            TLExpr::And(l, r)
31            | TLExpr::Or(l, r)
32            | TLExpr::Imply(l, r)
33            | TLExpr::Add(l, r)
34            | TLExpr::Sub(l, r)
35            | TLExpr::Mul(l, r)
36            | TLExpr::Div(l, r)
37            | TLExpr::Pow(l, r)
38            | TLExpr::Mod(l, r)
39            | TLExpr::Min(l, r)
40            | TLExpr::Max(l, r)
41            | TLExpr::Eq(l, r)
42            | TLExpr::Lt(l, r)
43            | TLExpr::Gt(l, r)
44            | TLExpr::Lte(l, r)
45            | TLExpr::Gte(l, r) => {
46                let mut new_seen = seen.clone();
47
48                l.collect_and_check_arity(&mut new_seen)?;
49                r.collect_and_check_arity(&mut new_seen)?;
50
51                Ok(())
52            }
53            TLExpr::Not(e)
54            | TLExpr::Score(e)
55            | TLExpr::Abs(e)
56            | TLExpr::Floor(e)
57            | TLExpr::Ceil(e)
58            | TLExpr::Round(e)
59            | TLExpr::Sqrt(e)
60            | TLExpr::Exp(e)
61            | TLExpr::Log(e)
62            | TLExpr::Sin(e)
63            | TLExpr::Cos(e)
64            | TLExpr::Tan(e)
65            | TLExpr::Box(e)
66            | TLExpr::Diamond(e)
67            | TLExpr::Next(e)
68            | TLExpr::Eventually(e)
69            | TLExpr::Always(e) => e.validate_arity_recursive(seen),
70            TLExpr::Until { before, after } => {
71                let mut new_seen = seen.clone();
72                before.collect_and_check_arity(&mut new_seen)?;
73                after.collect_and_check_arity(&mut new_seen)?;
74                Ok(())
75            }
76            TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
77                body.validate_arity_recursive(seen)
78            }
79            TLExpr::Aggregate { body, .. } => body.validate_arity_recursive(seen),
80            TLExpr::IfThenElse {
81                condition,
82                then_branch,
83                else_branch,
84            } => {
85                let mut new_seen = seen.clone();
86                condition.collect_and_check_arity(&mut new_seen)?;
87                then_branch.collect_and_check_arity(&mut new_seen)?;
88                else_branch.collect_and_check_arity(&mut new_seen)?;
89                Ok(())
90            }
91            TLExpr::Let { value, body, .. } => {
92                let mut new_seen = seen.clone();
93                value.collect_and_check_arity(&mut new_seen)?;
94                body.collect_and_check_arity(&mut new_seen)?;
95                Ok(())
96            }
97
98            // Fuzzy logic operators
99            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
100                let mut new_seen = seen.clone();
101                left.collect_and_check_arity(&mut new_seen)?;
102                right.collect_and_check_arity(&mut new_seen)?;
103                Ok(())
104            }
105            TLExpr::FuzzyNot { expr, .. } => expr.validate_arity_recursive(seen),
106            TLExpr::FuzzyImplication {
107                premise,
108                conclusion,
109                ..
110            } => {
111                let mut new_seen = seen.clone();
112                premise.collect_and_check_arity(&mut new_seen)?;
113                conclusion.collect_and_check_arity(&mut new_seen)?;
114                Ok(())
115            }
116
117            // Probabilistic operators
118            TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
119                body.validate_arity_recursive(seen)
120            }
121            TLExpr::WeightedRule { rule, .. } => rule.validate_arity_recursive(seen),
122            TLExpr::ProbabilisticChoice { alternatives } => {
123                let mut new_seen = seen.clone();
124                for (_, expr) in alternatives {
125                    expr.collect_and_check_arity(&mut new_seen)?;
126                }
127                Ok(())
128            }
129
130            // Extended temporal logic
131            TLExpr::Release { released, releaser }
132            | TLExpr::WeakUntil {
133                before: released,
134                after: releaser,
135            }
136            | TLExpr::StrongRelease { released, releaser } => {
137                let mut new_seen = seen.clone();
138                released.collect_and_check_arity(&mut new_seen)?;
139                releaser.collect_and_check_arity(&mut new_seen)?;
140                Ok(())
141            }
142
143            // Beta.1 enhancements
144            TLExpr::Lambda { body, .. } => body.validate_arity_recursive(seen),
145            TLExpr::Apply { function, argument } => {
146                let mut new_seen = seen.clone();
147                function.collect_and_check_arity(&mut new_seen)?;
148                argument.collect_and_check_arity(&mut new_seen)?;
149                Ok(())
150            }
151            TLExpr::SetMembership { element, set }
152            | TLExpr::SetUnion {
153                left: element,
154                right: set,
155            }
156            | TLExpr::SetIntersection {
157                left: element,
158                right: set,
159            }
160            | TLExpr::SetDifference {
161                left: element,
162                right: set,
163            } => {
164                let mut new_seen = seen.clone();
165                element.collect_and_check_arity(&mut new_seen)?;
166                set.collect_and_check_arity(&mut new_seen)?;
167                Ok(())
168            }
169            TLExpr::SetCardinality { set } => set.validate_arity_recursive(seen),
170            TLExpr::EmptySet => Ok(()),
171            TLExpr::SetComprehension { condition, .. } => condition.validate_arity_recursive(seen),
172            TLExpr::CountingExists { body, .. }
173            | TLExpr::CountingForAll { body, .. }
174            | TLExpr::ExactCount { body, .. }
175            | TLExpr::Majority { body, .. } => body.validate_arity_recursive(seen),
176            TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
177                body.validate_arity_recursive(seen)
178            }
179            TLExpr::Nominal { .. } => Ok(()),
180            TLExpr::At { formula, .. } => formula.validate_arity_recursive(seen),
181            TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
182                formula.validate_arity_recursive(seen)
183            }
184            TLExpr::AllDifferent { .. } => Ok(()),
185            TLExpr::GlobalCardinality { values, .. } => {
186                let mut new_seen = seen.clone();
187                for val in values {
188                    val.collect_and_check_arity(&mut new_seen)?;
189                }
190                Ok(())
191            }
192            TLExpr::Abducible { .. } => Ok(()),
193            TLExpr::Explain { formula } => formula.validate_arity_recursive(seen),
194            TLExpr::SymbolLiteral(_) => Ok(()),
195            TLExpr::Match { scrutinee, arms } => {
196                if arms.is_empty() {
197                    return Err("Match expression must have at least one arm".into());
198                }
199                let last = &arms[arms.len() - 1].0;
200                if !matches!(last, crate::pattern::MatchPattern::Wildcard) {
201                    return Err("Last arm of Match expression must be a Wildcard pattern".into());
202                }
203                scrutinee.validate_arity_recursive(seen)?;
204                for (_, body) in arms {
205                    body.validate_arity_recursive(seen)?;
206                }
207                Ok(())
208            }
209
210            TLExpr::Constant(_) => Ok(()),
211        }
212    }
213
214    pub(crate) fn collect_and_check_arity(
215        &self,
216        seen: &mut HashMap<String, usize>,
217    ) -> Result<(), String> {
218        match self {
219            TLExpr::Pred { name, args } => {
220                if let Some(&expected_arity) = seen.get(name) {
221                    if expected_arity != args.len() {
222                        return Err(format!(
223                            "Predicate '{}' has inconsistent arity: expected {}, found {}",
224                            name,
225                            expected_arity,
226                            args.len()
227                        ));
228                    }
229                } else {
230                    seen.insert(name.clone(), args.len());
231                }
232                Ok(())
233            }
234            TLExpr::And(l, r)
235            | TLExpr::Or(l, r)
236            | TLExpr::Imply(l, r)
237            | TLExpr::Add(l, r)
238            | TLExpr::Sub(l, r)
239            | TLExpr::Mul(l, r)
240            | TLExpr::Div(l, r)
241            | TLExpr::Pow(l, r)
242            | TLExpr::Mod(l, r)
243            | TLExpr::Min(l, r)
244            | TLExpr::Max(l, r)
245            | TLExpr::Eq(l, r)
246            | TLExpr::Lt(l, r)
247            | TLExpr::Gt(l, r)
248            | TLExpr::Lte(l, r)
249            | TLExpr::Gte(l, r) => {
250                l.collect_and_check_arity(seen)?;
251                r.collect_and_check_arity(seen)?;
252                Ok(())
253            }
254            TLExpr::Not(e)
255            | TLExpr::Score(e)
256            | TLExpr::Abs(e)
257            | TLExpr::Floor(e)
258            | TLExpr::Ceil(e)
259            | TLExpr::Round(e)
260            | TLExpr::Sqrt(e)
261            | TLExpr::Exp(e)
262            | TLExpr::Log(e)
263            | TLExpr::Sin(e)
264            | TLExpr::Cos(e)
265            | TLExpr::Tan(e)
266            | TLExpr::Box(e)
267            | TLExpr::Diamond(e)
268            | TLExpr::Next(e)
269            | TLExpr::Eventually(e)
270            | TLExpr::Always(e) => e.collect_and_check_arity(seen),
271            TLExpr::Until { before, after } => {
272                before.collect_and_check_arity(seen)?;
273                after.collect_and_check_arity(seen)?;
274                Ok(())
275            }
276            TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
277                body.collect_and_check_arity(seen)
278            }
279            TLExpr::Aggregate { body, .. } => body.collect_and_check_arity(seen),
280            TLExpr::IfThenElse {
281                condition,
282                then_branch,
283                else_branch,
284            } => {
285                condition.collect_and_check_arity(seen)?;
286                then_branch.collect_and_check_arity(seen)?;
287                else_branch.collect_and_check_arity(seen)?;
288                Ok(())
289            }
290            TLExpr::Let { value, body, .. } => {
291                value.collect_and_check_arity(seen)?;
292                body.collect_and_check_arity(seen)?;
293                Ok(())
294            }
295
296            // Fuzzy logic operators
297            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
298                left.collect_and_check_arity(seen)?;
299                right.collect_and_check_arity(seen)?;
300                Ok(())
301            }
302            TLExpr::FuzzyNot { expr, .. } => expr.collect_and_check_arity(seen),
303            TLExpr::FuzzyImplication {
304                premise,
305                conclusion,
306                ..
307            } => {
308                premise.collect_and_check_arity(seen)?;
309                conclusion.collect_and_check_arity(seen)?;
310                Ok(())
311            }
312
313            // Probabilistic operators
314            TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
315                body.collect_and_check_arity(seen)
316            }
317            TLExpr::WeightedRule { rule, .. } => rule.collect_and_check_arity(seen),
318            TLExpr::ProbabilisticChoice { alternatives } => {
319                for (_, expr) in alternatives {
320                    expr.collect_and_check_arity(seen)?;
321                }
322                Ok(())
323            }
324
325            // Extended temporal logic
326            TLExpr::Release { released, releaser }
327            | TLExpr::WeakUntil {
328                before: released,
329                after: releaser,
330            }
331            | TLExpr::StrongRelease { released, releaser } => {
332                released.collect_and_check_arity(seen)?;
333                releaser.collect_and_check_arity(seen)?;
334                Ok(())
335            }
336
337            // Beta.1 enhancements
338            TLExpr::Lambda { body, .. } => body.collect_and_check_arity(seen),
339            TLExpr::Apply { function, argument } => {
340                function.collect_and_check_arity(seen)?;
341                argument.collect_and_check_arity(seen)?;
342                Ok(())
343            }
344            TLExpr::SetMembership { element, set }
345            | TLExpr::SetUnion {
346                left: element,
347                right: set,
348            }
349            | TLExpr::SetIntersection {
350                left: element,
351                right: set,
352            }
353            | TLExpr::SetDifference {
354                left: element,
355                right: set,
356            } => {
357                element.collect_and_check_arity(seen)?;
358                set.collect_and_check_arity(seen)?;
359                Ok(())
360            }
361            TLExpr::SetCardinality { set } => set.collect_and_check_arity(seen),
362            TLExpr::EmptySet => Ok(()),
363            TLExpr::SetComprehension { condition, .. } => condition.collect_and_check_arity(seen),
364            TLExpr::CountingExists { body, .. }
365            | TLExpr::CountingForAll { body, .. }
366            | TLExpr::ExactCount { body, .. }
367            | TLExpr::Majority { body, .. } => body.collect_and_check_arity(seen),
368            TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
369                body.collect_and_check_arity(seen)
370            }
371            TLExpr::Nominal { .. } => Ok(()),
372            TLExpr::At { formula, .. } => formula.collect_and_check_arity(seen),
373            TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
374                formula.collect_and_check_arity(seen)
375            }
376            TLExpr::AllDifferent { .. } => Ok(()),
377            TLExpr::GlobalCardinality { values, .. } => {
378                for val in values {
379                    val.collect_and_check_arity(seen)?;
380                }
381                Ok(())
382            }
383            TLExpr::Abducible { .. } => Ok(()),
384            TLExpr::Explain { formula } => formula.collect_and_check_arity(seen),
385            TLExpr::SymbolLiteral(_) => Ok(()),
386            TLExpr::Match { scrutinee, arms } => {
387                scrutinee.collect_and_check_arity(seen)?;
388                for (_, body) in arms {
389                    body.collect_and_check_arity(seen)?;
390                }
391                Ok(())
392            }
393
394            TLExpr::Constant(_) => Ok(()),
395        }
396    }
397}