Skip to main content

tensorlogic_ir/expr/
ac_matching.rs

1//! Associative-Commutative (AC) pattern matching for logical expressions.
2//!
3//! This module provides AC-matching capabilities that recognize equivalent expressions
4//! under associativity and commutativity, such as:
5//! - `A ∧ B ≡ B ∧ A` (commutativity)
6//! - `(A ∧ B) ∧ C ≡ A ∧ (B ∧ C)` (associativity)
7//!
8//! AC-matching is crucial for advanced rewriting systems where the order and
9//! nesting of operators should not affect pattern matching.
10
11use std::collections::HashMap;
12
13use super::TLExpr;
14
15/// Operators that are associative and commutative.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
17pub enum ACOperator {
18    /// Logical AND (∧)
19    And,
20    /// Logical OR (∨)
21    Or,
22    /// Addition (+)
23    Add,
24    /// Multiplication (*)
25    Mul,
26    /// Min operation
27    Min,
28    /// Max operation
29    Max,
30}
31
32impl ACOperator {
33    /// Check if an expression uses this AC operator.
34    pub fn matches_expr(&self, expr: &TLExpr) -> bool {
35        matches!(
36            (self, expr),
37            (ACOperator::And, TLExpr::And(_, _))
38                | (ACOperator::Or, TLExpr::Or(_, _))
39                | (ACOperator::Add, TLExpr::Add(_, _))
40                | (ACOperator::Mul, TLExpr::Mul(_, _))
41                | (ACOperator::Min, TLExpr::Min(_, _))
42                | (ACOperator::Max, TLExpr::Max(_, _))
43        )
44    }
45
46    /// Extract operands from an AC expression.
47    pub fn extract_operands<'a>(&self, expr: &'a TLExpr) -> Option<(&'a TLExpr, &'a TLExpr)> {
48        match (self, expr) {
49            (ACOperator::And, TLExpr::And(l, r)) => Some((l, r)),
50            (ACOperator::Or, TLExpr::Or(l, r)) => Some((l, r)),
51            (ACOperator::Add, TLExpr::Add(l, r)) => Some((l, r)),
52            (ACOperator::Mul, TLExpr::Mul(l, r)) => Some((l, r)),
53            (ACOperator::Min, TLExpr::Min(l, r)) => Some((l, r)),
54            (ACOperator::Max, TLExpr::Max(l, r)) => Some((l, r)),
55            _ => None,
56        }
57    }
58}
59
60/// Flatten an AC expression into a list of operands.
61///
62/// For example, `(A ∧ B) ∧ (C ∧ D)` becomes `[A, B, C, D]`.
63pub fn flatten_ac(expr: &TLExpr, op: ACOperator) -> Vec<TLExpr> {
64    let mut result = Vec::new();
65    flatten_ac_recursive(expr, op, &mut result);
66    result
67}
68
69fn flatten_ac_recursive(expr: &TLExpr, op: ACOperator, acc: &mut Vec<TLExpr>) {
70    if let Some((left, right)) = op.extract_operands(expr) {
71        flatten_ac_recursive(left, op, acc);
72        flatten_ac_recursive(right, op, acc);
73    } else {
74        acc.push(expr.clone());
75    }
76}
77
78/// Normalize an AC expression by sorting operands.
79///
80/// This creates a canonical form for AC expressions, making them easier to compare.
81pub fn normalize_ac(expr: &TLExpr, op: ACOperator) -> TLExpr {
82    if !op.matches_expr(expr) {
83        return expr.clone();
84    }
85
86    let mut operands = flatten_ac(expr, op);
87
88    // Sort operands by their debug representation (simple but effective)
89    operands.sort_by_cached_key(|e| format!("{:?}", e));
90
91    // Rebuild the expression
92    if operands.is_empty() {
93        return expr.clone();
94    }
95
96    let mut result = operands.pop().unwrap();
97    while let Some(operand) = operands.pop() {
98        result = match op {
99            ACOperator::And => TLExpr::and(operand, result),
100            ACOperator::Or => TLExpr::or(operand, result),
101            ACOperator::Add => TLExpr::add(operand, result),
102            ACOperator::Mul => TLExpr::mul(operand, result),
103            ACOperator::Min => TLExpr::min(operand, result),
104            ACOperator::Max => TLExpr::max(operand, result),
105        };
106    }
107
108    result
109}
110
111/// Check if two expressions are AC-equivalent.
112///
113/// This recursively normalizes both expressions and compares them.
114pub fn ac_equivalent(expr1: &TLExpr, expr2: &TLExpr) -> bool {
115    // Try each AC operator
116    for op in &[
117        ACOperator::And,
118        ACOperator::Or,
119        ACOperator::Add,
120        ACOperator::Mul,
121        ACOperator::Min,
122        ACOperator::Max,
123    ] {
124        if op.matches_expr(expr1) || op.matches_expr(expr2) {
125            let norm1 = normalize_ac(expr1, *op);
126            let norm2 = normalize_ac(expr2, *op);
127            return norm1 == norm2;
128        }
129    }
130
131    // If neither is an AC operator, just compare directly
132    expr1 == expr2
133}
134
135/// AC pattern matching with variable bindings.
136///
137/// This is more sophisticated than simple AC-equivalence checking, as it allows
138/// pattern variables to match subsets of operands.
139#[derive(Debug, Clone)]
140pub struct ACPattern {
141    /// The AC operator for this pattern
142    pub operator: ACOperator,
143    /// Fixed operands that must match exactly
144    pub fixed_operands: Vec<TLExpr>,
145    /// Variable operands that can match multiple elements
146    pub variable_operands: Vec<String>,
147}
148
149impl ACPattern {
150    /// Create a new AC pattern.
151    pub fn new(operator: ACOperator) -> Self {
152        Self {
153            operator,
154            fixed_operands: Vec::new(),
155            variable_operands: Vec::new(),
156        }
157    }
158
159    /// Add a fixed operand to the pattern.
160    pub fn with_fixed(mut self, operand: TLExpr) -> Self {
161        self.fixed_operands.push(operand);
162        self
163    }
164
165    /// Add a variable operand to the pattern.
166    pub fn with_variable(mut self, var: impl Into<String>) -> Self {
167        self.variable_operands.push(var.into());
168        self
169    }
170
171    /// Try to match this pattern against an expression.
172    ///
173    /// Returns bindings for variable operands if successful.
174    pub fn matches(&self, expr: &TLExpr) -> Option<HashMap<String, Vec<TLExpr>>> {
175        // Extract operands from expression
176        let expr_operands = flatten_ac(expr, self.operator);
177
178        // Check if all fixed operands are present
179        let mut remaining = expr_operands.clone();
180        for fixed in &self.fixed_operands {
181            if let Some(pos) = remaining.iter().position(|e| e == fixed) {
182                remaining.remove(pos);
183            } else {
184                return None; // Fixed operand not found
185            }
186        }
187
188        // If we have no variable operands, remaining should be empty
189        if self.variable_operands.is_empty() {
190            if remaining.is_empty() {
191                return Some(HashMap::new());
192            } else {
193                return None;
194            }
195        }
196
197        // For single variable operand, it matches all remaining
198        if self.variable_operands.len() == 1 {
199            let mut bindings = HashMap::new();
200            bindings.insert(self.variable_operands[0].clone(), remaining);
201            return Some(bindings);
202        }
203
204        // For multiple variable operands, we need to find all partitions
205        // This is NP-complete in general, so we use a simple heuristic:
206        // distribute remaining operands evenly
207        if remaining.len() < self.variable_operands.len() {
208            return None; // Not enough operands
209        }
210
211        let mut bindings = HashMap::new();
212        let chunk_size = remaining.len() / self.variable_operands.len();
213        let mut start = 0;
214
215        for (i, var) in self.variable_operands.iter().enumerate() {
216            let end = if i == self.variable_operands.len() - 1 {
217                remaining.len() // Last variable gets all remaining
218            } else {
219                start + chunk_size
220            };
221
222            let chunk = remaining[start..end].to_vec();
223            bindings.insert(var.clone(), chunk);
224            start = end;
225        }
226
227        Some(bindings)
228    }
229}
230
/// Multiset for AC matching.
///
/// Represents a collection of elements where order doesn't matter but multiplicity does.
/// Internally every stored element maps to a count of at least one; elements
/// whose count drops to zero are removed from the map.
#[derive(Debug, Clone)]
pub struct Multiset<T> {
    elements: HashMap<T, usize>,
}

// Only hashing and equality are needed by the operations below; elements are
// never cloned, so `T: Clone` is not required.
impl<T: Eq + std::hash::Hash> Multiset<T> {
    /// Create an empty multiset.
    pub fn new() -> Self {
        Self {
            elements: HashMap::new(),
        }
    }

    /// Create a multiset from a vector, counting repeated elements.
    pub fn from_vec(vec: Vec<T>) -> Self {
        let mut multiset = Self::new();
        for elem in vec {
            multiset.insert(elem);
        }
        multiset
    }

    /// Insert one occurrence of `elem` into the multiset.
    pub fn insert(&mut self, elem: T) {
        *self.elements.entry(elem).or_insert(0) += 1;
    }

    /// Remove one occurrence of `elem` from the multiset.
    ///
    /// Returns `true` if an occurrence was removed and `false` if the element
    /// was not present.
    pub fn remove(&mut self, elem: &T) -> bool {
        let Some(count) = self.elements.get_mut(elem) else {
            return false;
        };
        if *count == 0 {
            return false;
        }

        *count -= 1;
        if *count == 0 {
            // Drop exhausted entries so `is_empty` and `==` stay accurate.
            self.elements.remove(elem);
        }
        true
    }

    /// Check if the multiset contains at least one occurrence of `elem`.
    pub fn contains(&self, elem: &T) -> bool {
        self.count(elem) > 0
    }

    /// Check if the multiset is empty.
    pub fn is_empty(&self) -> bool {
        self.elements.is_empty()
    }

    /// Get the number of occurrences of an element (zero if absent).
    pub fn count(&self, elem: &T) -> usize {
        self.elements.get(elem).copied().unwrap_or(0)
    }

    /// Check if this is a subset of another multiset.
    ///
    /// Holds when every element occurs in `other` at least as often as in `self`.
    pub fn is_subset(&self, other: &Multiset<T>) -> bool {
        self.elements
            .iter()
            .all(|(elem, &count)| other.count(elem) >= count)
    }
}

impl<T: Eq + std::hash::Hash> Default for Multiset<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Eq + std::hash::Hash> PartialEq for Multiset<T> {
    fn eq(&self, other: &Self) -> bool {
        self.elements == other.elements
    }
}

impl<T: Eq + std::hash::Hash> Eq for Multiset<T> {}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;

    /// Builds the unary predicate `name(var)` used as an atomic operand.
    fn atom(name: &'static str, var: &'static str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var(var)])
    }

    #[test]
    fn test_flatten_ac_and() {
        // (A ∧ B) ∧ C flattens to the three operands A, B, C.
        let nested = TLExpr::and(
            TLExpr::and(atom("A", "x"), atom("B", "y")),
            atom("C", "z"),
        );

        let flat = flatten_ac(&nested, ACOperator::And);
        assert_eq!(flat.len(), 3);
    }

    #[test]
    fn test_normalize_ac() {
        // B ∧ A and A ∧ B share one canonical form.
        let swapped = TLExpr::and(atom("B", "y"), atom("A", "x"));
        let ordered = TLExpr::and(atom("A", "x"), atom("B", "y"));

        assert_eq!(
            normalize_ac(&swapped, ACOperator::And),
            normalize_ac(&ordered, ACOperator::And)
        );
    }

    #[test]
    fn test_ac_equivalent() {
        // (A ∧ B) ∧ C ≡ C ∧ (B ∧ A)
        let left_nested = TLExpr::and(
            TLExpr::and(atom("A", "x"), atom("B", "y")),
            atom("C", "z"),
        );
        let right_nested = TLExpr::and(
            atom("C", "z"),
            TLExpr::and(atom("B", "y"), atom("A", "x")),
        );

        assert!(ac_equivalent(&left_nested, &right_nested));
    }

    #[test]
    fn test_ac_pattern_simple() {
        // Pattern: A ∧ <rest>
        let pattern = ACPattern::new(ACOperator::And)
            .with_fixed(atom("A", "x"))
            .with_variable("rest");

        // Subject: (A ∧ B) ∧ C
        let subject = TLExpr::and(
            TLExpr::and(atom("A", "x"), atom("B", "y")),
            atom("C", "z"),
        );

        let bindings = pattern.matches(&subject).unwrap();
        assert!(bindings.contains_key("rest"));
        // `rest` captures B and C.
        assert_eq!(bindings.get("rest").unwrap().len(), 2);
    }

    #[test]
    fn test_multiset_operations() {
        let mut larger = Multiset::new();
        larger.insert("A");
        larger.insert("B");
        larger.insert("A"); // second occurrence of A

        assert_eq!(larger.count(&"A"), 2);
        assert_eq!(larger.count(&"B"), 1);
        assert!(larger.contains(&"A"));

        let mut smaller = Multiset::new();
        smaller.insert("A");

        assert!(smaller.is_subset(&larger));
        assert!(!larger.is_subset(&smaller));
    }

    #[test]
    fn test_multiset_equality() {
        let reference = Multiset::from_vec(vec!["A", "B", "A"]);
        let reordered = Multiset::from_vec(vec!["B", "A", "A"]);
        let fewer = Multiset::from_vec(vec!["A", "B"]);

        assert_eq!(reference, reordered); // insertion order is irrelevant
        assert_ne!(reference, fewer); // multiplicity is significant
    }
}