1use std::collections::HashMap;
12
13use super::TLExpr;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
17pub enum ACOperator {
18 And,
20 Or,
22 Add,
24 Mul,
26 Min,
28 Max,
30}
31
32impl ACOperator {
33 pub fn matches_expr(&self, expr: &TLExpr) -> bool {
35 matches!(
36 (self, expr),
37 (ACOperator::And, TLExpr::And(_, _))
38 | (ACOperator::Or, TLExpr::Or(_, _))
39 | (ACOperator::Add, TLExpr::Add(_, _))
40 | (ACOperator::Mul, TLExpr::Mul(_, _))
41 | (ACOperator::Min, TLExpr::Min(_, _))
42 | (ACOperator::Max, TLExpr::Max(_, _))
43 )
44 }
45
46 pub fn extract_operands<'a>(&self, expr: &'a TLExpr) -> Option<(&'a TLExpr, &'a TLExpr)> {
48 match (self, expr) {
49 (ACOperator::And, TLExpr::And(l, r)) => Some((l, r)),
50 (ACOperator::Or, TLExpr::Or(l, r)) => Some((l, r)),
51 (ACOperator::Add, TLExpr::Add(l, r)) => Some((l, r)),
52 (ACOperator::Mul, TLExpr::Mul(l, r)) => Some((l, r)),
53 (ACOperator::Min, TLExpr::Min(l, r)) => Some((l, r)),
54 (ACOperator::Max, TLExpr::Max(l, r)) => Some((l, r)),
55 _ => None,
56 }
57 }
58}
59
60pub fn flatten_ac(expr: &TLExpr, op: ACOperator) -> Vec<TLExpr> {
64 let mut result = Vec::new();
65 flatten_ac_recursive(expr, op, &mut result);
66 result
67}
68
69fn flatten_ac_recursive(expr: &TLExpr, op: ACOperator, acc: &mut Vec<TLExpr>) {
70 if let Some((left, right)) = op.extract_operands(expr) {
71 flatten_ac_recursive(left, op, acc);
72 flatten_ac_recursive(right, op, acc);
73 } else {
74 acc.push(expr.clone());
75 }
76}
77
78pub fn normalize_ac(expr: &TLExpr, op: ACOperator) -> TLExpr {
82 if !op.matches_expr(expr) {
83 return expr.clone();
84 }
85
86 let mut operands = flatten_ac(expr, op);
87
88 operands.sort_by_cached_key(|e| format!("{:?}", e));
90
91 if operands.is_empty() {
93 return expr.clone();
94 }
95
96 let mut result = operands.pop().unwrap();
97 while let Some(operand) = operands.pop() {
98 result = match op {
99 ACOperator::And => TLExpr::and(operand, result),
100 ACOperator::Or => TLExpr::or(operand, result),
101 ACOperator::Add => TLExpr::add(operand, result),
102 ACOperator::Mul => TLExpr::mul(operand, result),
103 ACOperator::Min => TLExpr::min(operand, result),
104 ACOperator::Max => TLExpr::max(operand, result),
105 };
106 }
107
108 result
109}
110
111pub fn ac_equivalent(expr1: &TLExpr, expr2: &TLExpr) -> bool {
115 for op in &[
117 ACOperator::And,
118 ACOperator::Or,
119 ACOperator::Add,
120 ACOperator::Mul,
121 ACOperator::Min,
122 ACOperator::Max,
123 ] {
124 if op.matches_expr(expr1) || op.matches_expr(expr2) {
125 let norm1 = normalize_ac(expr1, *op);
126 let norm2 = normalize_ac(expr2, *op);
127 return norm1 == norm2;
128 }
129 }
130
131 expr1 == expr2
133}
134
135#[derive(Debug, Clone)]
140pub struct ACPattern {
141 pub operator: ACOperator,
143 pub fixed_operands: Vec<TLExpr>,
145 pub variable_operands: Vec<String>,
147}
148
149impl ACPattern {
150 pub fn new(operator: ACOperator) -> Self {
152 Self {
153 operator,
154 fixed_operands: Vec::new(),
155 variable_operands: Vec::new(),
156 }
157 }
158
159 pub fn with_fixed(mut self, operand: TLExpr) -> Self {
161 self.fixed_operands.push(operand);
162 self
163 }
164
165 pub fn with_variable(mut self, var: impl Into<String>) -> Self {
167 self.variable_operands.push(var.into());
168 self
169 }
170
171 pub fn matches(&self, expr: &TLExpr) -> Option<HashMap<String, Vec<TLExpr>>> {
175 let expr_operands = flatten_ac(expr, self.operator);
177
178 let mut remaining = expr_operands.clone();
180 for fixed in &self.fixed_operands {
181 if let Some(pos) = remaining.iter().position(|e| e == fixed) {
182 remaining.remove(pos);
183 } else {
184 return None; }
186 }
187
188 if self.variable_operands.is_empty() {
190 if remaining.is_empty() {
191 return Some(HashMap::new());
192 } else {
193 return None;
194 }
195 }
196
197 if self.variable_operands.len() == 1 {
199 let mut bindings = HashMap::new();
200 bindings.insert(self.variable_operands[0].clone(), remaining);
201 return Some(bindings);
202 }
203
204 if remaining.len() < self.variable_operands.len() {
208 return None; }
210
211 let mut bindings = HashMap::new();
212 let chunk_size = remaining.len() / self.variable_operands.len();
213 let mut start = 0;
214
215 for (i, var) in self.variable_operands.iter().enumerate() {
216 let end = if i == self.variable_operands.len() - 1 {
217 remaining.len() } else {
219 start + chunk_size
220 };
221
222 let chunk = remaining[start..end].to_vec();
223 bindings.insert(var.clone(), chunk);
224 start = end;
225 }
226
227 Some(bindings)
228 }
229}
230
/// A multiset (bag): an unordered collection that records how many times each
/// element occurs.
///
/// Invariant: every count stored in `elements` is at least 1. `insert` only
/// ever increments and `remove` deletes an entry once its count reaches zero,
/// so two multisets with the same contents always compare equal.
#[derive(Debug, Clone)]
pub struct Multiset<T> {
    elements: HashMap<T, usize>,
}

// Only hashing and equality are needed; elements are never cloned, so no
// `Clone` bound is required here.
impl<T: Eq + std::hash::Hash> Multiset<T> {
    /// Creates an empty multiset.
    pub fn new() -> Self {
        Self {
            elements: HashMap::new(),
        }
    }

    /// Builds a multiset containing every element of `vec`, duplicates included.
    pub fn from_vec(vec: Vec<T>) -> Self {
        let mut multiset = Self::new();
        for elem in vec {
            multiset.insert(elem);
        }
        multiset
    }

    /// Adds one occurrence of `elem`.
    pub fn insert(&mut self, elem: T) {
        *self.elements.entry(elem).or_insert(0) += 1;
    }

    /// Removes one occurrence of `elem`.
    ///
    /// Returns `true` if an occurrence was removed and `false` if `elem` was
    /// not present.
    pub fn remove(&mut self, elem: &T) -> bool {
        let Some(count) = self.elements.get_mut(elem) else {
            return false;
        };
        // Stored counts are always >= 1, so this cannot underflow.
        *count -= 1;
        if *count == 0 {
            // Drop exhausted entries so emptiness and equality stay exact.
            self.elements.remove(elem);
        }
        true
    }

    /// Returns `true` if `elem` occurs at least once.
    pub fn contains(&self, elem: &T) -> bool {
        self.count(elem) > 0
    }

    /// Returns `true` if the multiset holds no elements.
    pub fn is_empty(&self) -> bool {
        self.elements.is_empty()
    }

    /// Returns the number of occurrences of `elem` (0 when absent).
    pub fn count(&self, elem: &T) -> usize {
        self.elements.get(elem).copied().unwrap_or(0)
    }

    /// Returns `true` if every element of `self` occurs in `other` at least as
    /// many times as it does in `self`.
    pub fn is_subset(&self, other: &Multiset<T>) -> bool {
        self.elements
            .iter()
            .all(|(elem, &needed)| other.count(elem) >= needed)
    }
}

impl<T: Eq + std::hash::Hash> Default for Multiset<T> {
    /// Equivalent to [`Multiset::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Eq + std::hash::Hash> PartialEq for Multiset<T> {
    /// Two multisets are equal when they hold the same elements with the same
    /// multiplicities.
    fn eq(&self, other: &Self) -> bool {
        self.elements == other.elements
    }
}

impl<T: Eq + std::hash::Hash> Eq for Multiset<T> {}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;

    /// Builds the single-argument predicate `name(var)` used as a leaf in the
    /// expressions below.
    fn atom(name: &'static str, var: &'static str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var(var)])
    }

    /// Flattening `(A ∧ B) ∧ C` yields its three leaves.
    #[test]
    fn test_flatten_ac_and() {
        let left_nested = TLExpr::and(
            TLExpr::and(atom("A", "x"), atom("B", "y")),
            atom("C", "z"),
        );

        let leaves = flatten_ac(&left_nested, ACOperator::And);
        assert_eq!(leaves.len(), 3);
    }

    /// `B ∧ A` and `A ∧ B` normalise to the same expression.
    #[test]
    fn test_normalize_ac() {
        let swapped = TLExpr::and(atom("B", "y"), atom("A", "x"));
        let ordered = TLExpr::and(atom("A", "x"), atom("B", "y"));

        let norm_swapped = normalize_ac(&swapped, ACOperator::And);
        let norm_ordered = normalize_ac(&ordered, ACOperator::And);

        assert_eq!(norm_swapped, norm_ordered);
    }

    /// `(A ∧ B) ∧ C` is AC-equivalent to `C ∧ (B ∧ A)`.
    #[test]
    fn test_ac_equivalent() {
        let left_nested = TLExpr::and(
            TLExpr::and(atom("A", "x"), atom("B", "y")),
            atom("C", "z"),
        );
        let right_nested = TLExpr::and(
            atom("C", "z"),
            TLExpr::and(atom("B", "y"), atom("A", "x")),
        );

        assert!(ac_equivalent(&left_nested, &right_nested));
    }

    /// A pattern with fixed operand `A` and one variable captures the other
    /// two operands of `(A ∧ B) ∧ C`.
    #[test]
    fn test_ac_pattern_simple() {
        let pattern = ACPattern::new(ACOperator::And)
            .with_fixed(atom("A", "x"))
            .with_variable("rest");

        let subject = TLExpr::and(
            TLExpr::and(atom("A", "x"), atom("B", "y")),
            atom("C", "z"),
        );

        let bindings = pattern.matches(&subject).unwrap();
        assert!(bindings.contains_key("rest"));
        assert_eq!(bindings.get("rest").unwrap().len(), 2);
    }

    /// Counting, membership and the subset relation on small multisets.
    #[test]
    fn test_multiset_operations() {
        let mut larger = Multiset::new();
        for label in ["A", "B", "A"] {
            larger.insert(label);
        }
        assert_eq!(larger.count(&"A"), 2);
        assert_eq!(larger.count(&"B"), 1);
        assert!(larger.contains(&"A"));

        let mut smaller = Multiset::new();
        smaller.insert("A");

        assert!(smaller.is_subset(&larger));
        assert!(!larger.is_subset(&smaller));
    }

    /// Equality ignores insertion order but respects multiplicity.
    #[test]
    fn test_multiset_equality() {
        let first = Multiset::from_vec(vec!["A", "B", "A"]);
        let reordered = Multiset::from_vec(vec!["B", "A", "A"]);
        let fewer = Multiset::from_vec(vec!["A", "B"]);

        assert_eq!(first, reordered);
        assert_ne!(first, fewer);
    }
}