1use std::collections::HashMap;
12
13use super::TLExpr;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
17pub enum ACOperator {
18 And,
20 Or,
22 Add,
24 Mul,
26 Min,
28 Max,
30}
31
32impl ACOperator {
33 pub fn matches_expr(&self, expr: &TLExpr) -> bool {
35 matches!(
36 (self, expr),
37 (ACOperator::And, TLExpr::And(_, _))
38 | (ACOperator::Or, TLExpr::Or(_, _))
39 | (ACOperator::Add, TLExpr::Add(_, _))
40 | (ACOperator::Mul, TLExpr::Mul(_, _))
41 | (ACOperator::Min, TLExpr::Min(_, _))
42 | (ACOperator::Max, TLExpr::Max(_, _))
43 )
44 }
45
46 pub fn extract_operands<'a>(&self, expr: &'a TLExpr) -> Option<(&'a TLExpr, &'a TLExpr)> {
48 match (self, expr) {
49 (ACOperator::And, TLExpr::And(l, r)) => Some((l, r)),
50 (ACOperator::Or, TLExpr::Or(l, r)) => Some((l, r)),
51 (ACOperator::Add, TLExpr::Add(l, r)) => Some((l, r)),
52 (ACOperator::Mul, TLExpr::Mul(l, r)) => Some((l, r)),
53 (ACOperator::Min, TLExpr::Min(l, r)) => Some((l, r)),
54 (ACOperator::Max, TLExpr::Max(l, r)) => Some((l, r)),
55 _ => None,
56 }
57 }
58}
59
60pub fn flatten_ac(expr: &TLExpr, op: ACOperator) -> Vec<TLExpr> {
64 let mut result = Vec::new();
65 flatten_ac_recursive(expr, op, &mut result);
66 result
67}
68
69fn flatten_ac_recursive(expr: &TLExpr, op: ACOperator, acc: &mut Vec<TLExpr>) {
70 if let Some((left, right)) = op.extract_operands(expr) {
71 flatten_ac_recursive(left, op, acc);
72 flatten_ac_recursive(right, op, acc);
73 } else {
74 acc.push(expr.clone());
75 }
76}
77
78pub fn normalize_ac(expr: &TLExpr, op: ACOperator) -> TLExpr {
82 if !op.matches_expr(expr) {
83 return expr.clone();
84 }
85
86 let mut operands = flatten_ac(expr, op);
87
88 operands.sort_by_cached_key(|e| format!("{:?}", e));
90
91 if operands.is_empty() {
93 return expr.clone();
94 }
95
96 let mut result = operands
97 .pop()
98 .expect("operands must be non-empty after validation");
99 while let Some(operand) = operands.pop() {
100 result = match op {
101 ACOperator::And => TLExpr::and(operand, result),
102 ACOperator::Or => TLExpr::or(operand, result),
103 ACOperator::Add => TLExpr::add(operand, result),
104 ACOperator::Mul => TLExpr::mul(operand, result),
105 ACOperator::Min => TLExpr::min(operand, result),
106 ACOperator::Max => TLExpr::max(operand, result),
107 };
108 }
109
110 result
111}
112
113pub fn ac_equivalent(expr1: &TLExpr, expr2: &TLExpr) -> bool {
117 for op in &[
119 ACOperator::And,
120 ACOperator::Or,
121 ACOperator::Add,
122 ACOperator::Mul,
123 ACOperator::Min,
124 ACOperator::Max,
125 ] {
126 if op.matches_expr(expr1) || op.matches_expr(expr2) {
127 let norm1 = normalize_ac(expr1, *op);
128 let norm2 = normalize_ac(expr2, *op);
129 return norm1 == norm2;
130 }
131 }
132
133 expr1 == expr2
135}
136
137#[derive(Debug, Clone)]
142pub struct ACPattern {
143 pub operator: ACOperator,
145 pub fixed_operands: Vec<TLExpr>,
147 pub variable_operands: Vec<String>,
149}
150
151impl ACPattern {
152 pub fn new(operator: ACOperator) -> Self {
154 Self {
155 operator,
156 fixed_operands: Vec::new(),
157 variable_operands: Vec::new(),
158 }
159 }
160
161 pub fn with_fixed(mut self, operand: TLExpr) -> Self {
163 self.fixed_operands.push(operand);
164 self
165 }
166
167 pub fn with_variable(mut self, var: impl Into<String>) -> Self {
169 self.variable_operands.push(var.into());
170 self
171 }
172
173 pub fn matches(&self, expr: &TLExpr) -> Option<HashMap<String, Vec<TLExpr>>> {
177 let expr_operands = flatten_ac(expr, self.operator);
179
180 let mut remaining = expr_operands.clone();
182 for fixed in &self.fixed_operands {
183 if let Some(pos) = remaining.iter().position(|e| e == fixed) {
184 remaining.remove(pos);
185 } else {
186 return None; }
188 }
189
190 if self.variable_operands.is_empty() {
192 if remaining.is_empty() {
193 return Some(HashMap::new());
194 } else {
195 return None;
196 }
197 }
198
199 if self.variable_operands.len() == 1 {
201 let mut bindings = HashMap::new();
202 bindings.insert(self.variable_operands[0].clone(), remaining);
203 return Some(bindings);
204 }
205
206 if remaining.len() < self.variable_operands.len() {
210 return None; }
212
213 let mut bindings = HashMap::new();
214 let chunk_size = remaining.len() / self.variable_operands.len();
215 let mut start = 0;
216
217 for (i, var) in self.variable_operands.iter().enumerate() {
218 let end = if i == self.variable_operands.len() - 1 {
219 remaining.len() } else {
221 start + chunk_size
222 };
223
224 let chunk = remaining[start..end].to_vec();
225 bindings.insert(var.clone(), chunk);
226 start = end;
227 }
228
229 Some(bindings)
230 }
231}
232
/// A multiset (bag): an unordered collection that tracks how many times each
/// distinct element has been inserted.
///
/// Only elements with a positive multiplicity are stored. The map is therefore
/// canonical, and equality can compare the underlying maps directly.
#[derive(Debug, Clone)]
pub struct Multiset<T> {
    /// Distinct element -> multiplicity. Invariant: every stored count is >= 1.
    elements: HashMap<T, usize>,
}

// Only `Eq + Hash` is required: no method clones elements, so `T: Clone` is not needed.
impl<T: Eq + std::hash::Hash> Multiset<T> {
    /// Creates an empty multiset.
    pub fn new() -> Self {
        Self {
            elements: HashMap::new(),
        }
    }

    /// Builds a multiset holding every element of `vec`, duplicates included.
    pub fn from_vec(vec: Vec<T>) -> Self {
        vec.into_iter().collect()
    }

    /// Adds one occurrence of `elem`.
    pub fn insert(&mut self, elem: T) {
        *self.elements.entry(elem).or_insert(0) += 1;
    }

    /// Removes one occurrence of `elem`.
    ///
    /// Returns `true` if an occurrence was removed, or `false` if `elem` was absent.
    pub fn remove(&mut self, elem: &T) -> bool {
        let Some(count) = self.elements.get_mut(elem) else {
            return false;
        };
        // Stored counts are always >= 1 (see the invariant on `elements`).
        *count -= 1;
        if *count == 0 {
            // Drop exhausted entries so the representation stays canonical.
            self.elements.remove(elem);
        }
        true
    }

    /// Returns `true` if `elem` occurs at least once.
    pub fn contains(&self, elem: &T) -> bool {
        self.count(elem) > 0
    }

    /// Returns `true` if the multiset holds no elements.
    pub fn is_empty(&self) -> bool {
        self.elements.is_empty()
    }

    /// Returns the total number of elements, counting multiplicity.
    pub fn len(&self) -> usize {
        self.elements.values().sum()
    }

    /// Returns the multiplicity of `elem` (0 if absent).
    pub fn count(&self, elem: &T) -> usize {
        self.elements.get(elem).copied().unwrap_or(0)
    }

    /// Returns `true` if every element occurs in `other` at least as often as in `self`.
    pub fn is_subset(&self, other: &Multiset<T>) -> bool {
        self.elements
            .iter()
            .all(|(elem, &count)| other.count(elem) >= count)
    }
}

impl<T: Eq + std::hash::Hash> Default for Multiset<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Eq + std::hash::Hash> std::iter::FromIterator<T> for Multiset<T> {
    /// Collects an iterator into a multiset, inserting each item once.
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        let mut multiset = Self::new();
        for elem in iter {
            multiset.insert(elem);
        }
        multiset
    }
}

impl<T: Eq + std::hash::Hash> PartialEq for Multiset<T> {
    fn eq(&self, other: &Self) -> bool {
        // Valid because zero counts are never stored.
        self.elements == other.elements
    }
}

impl<T: Eq + std::hash::Hash> Eq for Multiset<T> {}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;

    /// Builds the unary predicate `name(var)` used throughout these tests.
    fn p(name: &'static str, var: &'static str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var(var)])
    }

    #[test]
    fn test_flatten_ac_and() {
        let expr = TLExpr::and(TLExpr::and(p("A", "x"), p("B", "y")), p("C", "z"));

        let operands = flatten_ac(&expr, ACOperator::And);
        assert_eq!(operands.len(), 3);
        // Operands come back in left-to-right order.
        assert_eq!(operands, vec![p("A", "x"), p("B", "y"), p("C", "z")]);
    }

    #[test]
    fn test_flatten_ac_non_matching_root() {
        let expr = TLExpr::or(p("A", "x"), p("B", "y"));
        assert_eq!(flatten_ac(&expr, ACOperator::And), vec![expr.clone()]);
    }

    #[test]
    fn test_normalize_ac() {
        let expr1 = TLExpr::and(p("B", "y"), p("A", "x"));
        let expr2 = TLExpr::and(p("A", "x"), p("B", "y"));

        let norm1 = normalize_ac(&expr1, ACOperator::And);
        let norm2 = normalize_ac(&expr2, ACOperator::And);

        assert_eq!(norm1, norm2);
    }

    #[test]
    fn test_normalize_ac_is_idempotent() {
        let expr = TLExpr::and(p("C", "z"), TLExpr::and(p("B", "y"), p("A", "x")));

        let once = normalize_ac(&expr, ACOperator::And);
        let twice = normalize_ac(&once, ACOperator::And);

        assert_eq!(once, twice);
        assert_eq!(flatten_ac(&once, ACOperator::And).len(), 3);
    }

    #[test]
    fn test_normalize_ac_ignores_other_operators() {
        let expr = TLExpr::or(p("B", "y"), p("A", "x"));
        assert_eq!(normalize_ac(&expr, ACOperator::And), expr);
    }

    #[test]
    fn test_ac_equivalent() {
        let expr1 = TLExpr::and(TLExpr::and(p("A", "x"), p("B", "y")), p("C", "z"));
        let expr2 = TLExpr::and(p("C", "z"), TLExpr::and(p("B", "y"), p("A", "x")));

        assert!(ac_equivalent(&expr1, &expr2));
    }

    #[test]
    fn test_ac_equivalent_rejects_different_operators() {
        let conj = TLExpr::and(p("A", "x"), p("B", "y"));
        let disj = TLExpr::or(p("A", "x"), p("B", "y"));
        assert!(!ac_equivalent(&conj, &disj));
    }

    #[test]
    fn test_ac_equivalent_rejects_different_operands() {
        let left = TLExpr::and(p("A", "x"), p("B", "y"));
        let right = TLExpr::and(p("A", "x"), p("C", "z"));
        assert!(!ac_equivalent(&left, &right));
    }

    #[test]
    fn test_ac_pattern_simple() {
        let pattern = ACPattern::new(ACOperator::And)
            .with_fixed(p("A", "x"))
            .with_variable("rest");

        let expr = TLExpr::and(TLExpr::and(p("A", "x"), p("B", "y")), p("C", "z"));

        let bindings = pattern.matches(&expr).expect("unwrap");
        assert!(bindings.contains_key("rest"));
        assert_eq!(bindings.get("rest").expect("unwrap").len(), 2);
    }

    #[test]
    fn test_ac_pattern_missing_fixed_operand() {
        let pattern = ACPattern::new(ACOperator::And)
            .with_fixed(p("D", "w"))
            .with_variable("rest");

        let expr = TLExpr::and(TLExpr::and(p("A", "x"), p("B", "y")), p("C", "z"));

        assert!(pattern.matches(&expr).is_none());
    }

    #[test]
    fn test_ac_pattern_without_variables_needs_exact_cover() {
        let pattern = ACPattern::new(ACOperator::And)
            .with_fixed(p("A", "x"))
            .with_fixed(p("B", "y"));

        let exact = TLExpr::and(p("B", "y"), p("A", "x"));
        assert!(pattern.matches(&exact).is_some_and(|b| b.is_empty()));

        let extra = TLExpr::and(TLExpr::and(p("A", "x"), p("B", "y")), p("C", "z"));
        assert!(pattern.matches(&extra).is_none());
    }

    #[test]
    fn test_ac_pattern_multiple_variables() {
        let expr = TLExpr::and(TLExpr::and(p("A", "x"), p("B", "y")), p("C", "z"));
        let pattern = ACPattern::new(ACOperator::And)
            .with_variable("first")
            .with_variable("second");

        let bindings = pattern
            .matches(&expr)
            .expect("two variables over three operands should match");
        assert_eq!(bindings["first"], vec![p("A", "x")]);
        assert_eq!(bindings["second"], vec![p("B", "y"), p("C", "z")]);

        let too_many = ACPattern::new(ACOperator::And)
            .with_variable("a")
            .with_variable("b")
            .with_variable("c");
        assert!(too_many
            .matches(&TLExpr::and(p("A", "x"), p("B", "y")))
            .is_none());
    }

    #[test]
    fn test_multiset_operations() {
        let mut ms1 = Multiset::new();
        ms1.insert("A");
        ms1.insert("B");
        ms1.insert("A");

        assert_eq!(ms1.count(&"A"), 2);
        assert_eq!(ms1.count(&"B"), 1);
        assert!(ms1.contains(&"A"));

        let mut ms2 = Multiset::new();
        ms2.insert("A");

        assert!(ms2.is_subset(&ms1));
        assert!(!ms1.is_subset(&ms2));
    }

    #[test]
    fn test_multiset_remove() {
        let mut ms = Multiset::from_vec(vec!["A", "A"]);

        assert!(ms.remove(&"A"));
        assert_eq!(ms.count(&"A"), 1);
        assert!(ms.remove(&"A"));
        assert!(!ms.contains(&"A"));
        assert!(ms.is_empty());
        assert!(!ms.remove(&"A"));
    }

    #[test]
    fn test_multiset_equality() {
        let ms1 = Multiset::from_vec(vec!["A", "B", "A"]);
        let ms2 = Multiset::from_vec(vec!["B", "A", "A"]);
        let ms3 = Multiset::from_vec(vec!["A", "B"]);

        assert_eq!(ms1, ms2);
        assert_ne!(ms1, ms3);
    }
}