Skip to main content

tensorlogic_ir/expr/
ac_matching.rs

1//! Associative-Commutative (AC) pattern matching for logical expressions.
2//!
3//! This module provides AC-matching capabilities that recognize equivalent expressions
4//! under associativity and commutativity, such as:
5//! - `A ∧ B ≡ B ∧ A` (commutativity)
6//! - `(A ∧ B) ∧ C ≡ A ∧ (B ∧ C)` (associativity)
7//!
8//! AC-matching is crucial for advanced rewriting systems where the order and
9//! nesting of operators should not affect pattern matching.
10
11use std::collections::HashMap;
12
13use super::TLExpr;
14
/// Binary operators that are both associative and commutative.
///
/// Each variant corresponds to the two-operand `TLExpr` constructor of the same
/// name (`TLExpr::And`, `TLExpr::Or`, ...); see [`ACOperator::matches_expr`] and
/// [`ACOperator::extract_operands`]. The flattening, normalization and pattern
/// matching helpers in this module treat nested applications of a single variant
/// as one n-ary operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ACOperator {
    /// Logical AND (∧); pairs with `TLExpr::And`.
    And,
    /// Logical OR (∨); pairs with `TLExpr::Or`.
    Or,
    /// Addition (+); pairs with `TLExpr::Add`.
    Add,
    /// Multiplication (*); pairs with `TLExpr::Mul`.
    Mul,
    /// Min operation; pairs with `TLExpr::Min`.
    Min,
    /// Max operation; pairs with `TLExpr::Max`.
    Max,
}
31
32impl ACOperator {
33    /// Check if an expression uses this AC operator.
34    pub fn matches_expr(&self, expr: &TLExpr) -> bool {
35        matches!(
36            (self, expr),
37            (ACOperator::And, TLExpr::And(_, _))
38                | (ACOperator::Or, TLExpr::Or(_, _))
39                | (ACOperator::Add, TLExpr::Add(_, _))
40                | (ACOperator::Mul, TLExpr::Mul(_, _))
41                | (ACOperator::Min, TLExpr::Min(_, _))
42                | (ACOperator::Max, TLExpr::Max(_, _))
43        )
44    }
45
46    /// Extract operands from an AC expression.
47    pub fn extract_operands<'a>(&self, expr: &'a TLExpr) -> Option<(&'a TLExpr, &'a TLExpr)> {
48        match (self, expr) {
49            (ACOperator::And, TLExpr::And(l, r)) => Some((l, r)),
50            (ACOperator::Or, TLExpr::Or(l, r)) => Some((l, r)),
51            (ACOperator::Add, TLExpr::Add(l, r)) => Some((l, r)),
52            (ACOperator::Mul, TLExpr::Mul(l, r)) => Some((l, r)),
53            (ACOperator::Min, TLExpr::Min(l, r)) => Some((l, r)),
54            (ACOperator::Max, TLExpr::Max(l, r)) => Some((l, r)),
55            _ => None,
56        }
57    }
58}
59
60/// Flatten an AC expression into a list of operands.
61///
62/// For example, `(A ∧ B) ∧ (C ∧ D)` becomes `[A, B, C, D]`.
63pub fn flatten_ac(expr: &TLExpr, op: ACOperator) -> Vec<TLExpr> {
64    let mut result = Vec::new();
65    flatten_ac_recursive(expr, op, &mut result);
66    result
67}
68
69fn flatten_ac_recursive(expr: &TLExpr, op: ACOperator, acc: &mut Vec<TLExpr>) {
70    if let Some((left, right)) = op.extract_operands(expr) {
71        flatten_ac_recursive(left, op, acc);
72        flatten_ac_recursive(right, op, acc);
73    } else {
74        acc.push(expr.clone());
75    }
76}
77
78/// Normalize an AC expression by sorting operands.
79///
80/// This creates a canonical form for AC expressions, making them easier to compare.
81pub fn normalize_ac(expr: &TLExpr, op: ACOperator) -> TLExpr {
82    if !op.matches_expr(expr) {
83        return expr.clone();
84    }
85
86    let mut operands = flatten_ac(expr, op);
87
88    // Sort operands by their debug representation (simple but effective)
89    operands.sort_by_cached_key(|e| format!("{:?}", e));
90
91    // Rebuild the expression
92    if operands.is_empty() {
93        return expr.clone();
94    }
95
96    let mut result = operands
97        .pop()
98        .expect("operands must be non-empty after validation");
99    while let Some(operand) = operands.pop() {
100        result = match op {
101            ACOperator::And => TLExpr::and(operand, result),
102            ACOperator::Or => TLExpr::or(operand, result),
103            ACOperator::Add => TLExpr::add(operand, result),
104            ACOperator::Mul => TLExpr::mul(operand, result),
105            ACOperator::Min => TLExpr::min(operand, result),
106            ACOperator::Max => TLExpr::max(operand, result),
107        };
108    }
109
110    result
111}
112
113/// Check if two expressions are AC-equivalent.
114///
115/// This recursively normalizes both expressions and compares them.
116pub fn ac_equivalent(expr1: &TLExpr, expr2: &TLExpr) -> bool {
117    // Try each AC operator
118    for op in &[
119        ACOperator::And,
120        ACOperator::Or,
121        ACOperator::Add,
122        ACOperator::Mul,
123        ACOperator::Min,
124        ACOperator::Max,
125    ] {
126        if op.matches_expr(expr1) || op.matches_expr(expr2) {
127            let norm1 = normalize_ac(expr1, *op);
128            let norm2 = normalize_ac(expr2, *op);
129            return norm1 == norm2;
130        }
131    }
132
133    // If neither is an AC operator, just compare directly
134    expr1 == expr2
135}
136
/// AC pattern matching with variable bindings.
///
/// A pattern describes the operands of a flattened `operator` application.
/// Some operands must equal given expressions exactly (`fixed_operands`). The
/// rest are captured by named variables (`variable_operands`), each of which
/// binds a list of operands. Build one with [`ACPattern::new`] and the
/// `with_fixed` / `with_variable` builder methods, then call [`ACPattern::matches`].
///
/// This is more sophisticated than simple AC-equivalence checking, as it allows
/// pattern variables to match subsets of operands.
#[derive(Debug, Clone)]
pub struct ACPattern {
    /// The AC operator whose nested applications are flattened before matching
    pub operator: ACOperator,
    /// Fixed operands; each must be found (by `==`) among the expression's operands
    pub fixed_operands: Vec<TLExpr>,
    /// Names of variable operands; each binds a list of the operands left over
    /// after the fixed ones have been removed
    pub variable_operands: Vec<String>,
}
150
151impl ACPattern {
152    /// Create a new AC pattern.
153    pub fn new(operator: ACOperator) -> Self {
154        Self {
155            operator,
156            fixed_operands: Vec::new(),
157            variable_operands: Vec::new(),
158        }
159    }
160
161    /// Add a fixed operand to the pattern.
162    pub fn with_fixed(mut self, operand: TLExpr) -> Self {
163        self.fixed_operands.push(operand);
164        self
165    }
166
167    /// Add a variable operand to the pattern.
168    pub fn with_variable(mut self, var: impl Into<String>) -> Self {
169        self.variable_operands.push(var.into());
170        self
171    }
172
173    /// Try to match this pattern against an expression.
174    ///
175    /// Returns bindings for variable operands if successful.
176    pub fn matches(&self, expr: &TLExpr) -> Option<HashMap<String, Vec<TLExpr>>> {
177        // Extract operands from expression
178        let expr_operands = flatten_ac(expr, self.operator);
179
180        // Check if all fixed operands are present
181        let mut remaining = expr_operands.clone();
182        for fixed in &self.fixed_operands {
183            if let Some(pos) = remaining.iter().position(|e| e == fixed) {
184                remaining.remove(pos);
185            } else {
186                return None; // Fixed operand not found
187            }
188        }
189
190        // If we have no variable operands, remaining should be empty
191        if self.variable_operands.is_empty() {
192            if remaining.is_empty() {
193                return Some(HashMap::new());
194            } else {
195                return None;
196            }
197        }
198
199        // For single variable operand, it matches all remaining
200        if self.variable_operands.len() == 1 {
201            let mut bindings = HashMap::new();
202            bindings.insert(self.variable_operands[0].clone(), remaining);
203            return Some(bindings);
204        }
205
206        // For multiple variable operands, we need to find all partitions
207        // This is NP-complete in general, so we use a simple heuristic:
208        // distribute remaining operands evenly
209        if remaining.len() < self.variable_operands.len() {
210            return None; // Not enough operands
211        }
212
213        let mut bindings = HashMap::new();
214        let chunk_size = remaining.len() / self.variable_operands.len();
215        let mut start = 0;
216
217        for (i, var) in self.variable_operands.iter().enumerate() {
218            let end = if i == self.variable_operands.len() - 1 {
219                remaining.len() // Last variable gets all remaining
220            } else {
221                start + chunk_size
222            };
223
224            let chunk = remaining[start..end].to_vec();
225            bindings.insert(var.clone(), chunk);
226            start = end;
227        }
228
229        Some(bindings)
230    }
231}
232
/// Multiset for AC matching.
///
/// An unordered collection that records how many times each distinct element
/// occurs: order doesn't matter, multiplicity does. Only elements with a positive
/// count are stored. An element removed as often as it was inserted therefore
/// leaves no entry behind.
#[derive(Debug, Clone)]
pub struct Multiset<T> {
    // Occurrence count per distinct element; every stored count is at least 1.
    elements: HashMap<T, usize>,
}

impl<T: Eq + std::hash::Hash + Clone> Multiset<T> {
    /// Create an empty multiset.
    pub fn new() -> Self {
        Multiset {
            elements: HashMap::new(),
        }
    }

    /// Build a multiset holding every item of `vec`, duplicates included.
    pub fn from_vec(vec: Vec<T>) -> Self {
        vec.into_iter().fold(Self::new(), |mut acc, item| {
            acc.insert(item);
            acc
        })
    }

    /// Add one occurrence of `elem`.
    pub fn insert(&mut self, elem: T) {
        let slot = self.elements.entry(elem).or_insert(0);
        *slot += 1;
    }

    /// Remove one occurrence of `elem`; returns whether an occurrence was present.
    pub fn remove(&mut self, elem: &T) -> bool {
        let Some(count) = self.elements.get_mut(elem) else {
            return false;
        };
        match *count {
            // Defensive: stored counts are never zero, but treat one as absent.
            0 => false,
            // Last occurrence: drop the key so emptiness and equality see no trace.
            1 => {
                self.elements.remove(elem);
                true
            }
            _ => {
                *count -= 1;
                true
            }
        }
    }

    /// Whether at least one occurrence of `elem` is present.
    pub fn contains(&self, elem: &T) -> bool {
        self.count(elem) > 0
    }

    /// Whether the multiset holds no elements at all.
    pub fn is_empty(&self) -> bool {
        self.elements.is_empty()
    }

    /// Number of occurrences of `elem` (zero when absent).
    pub fn count(&self, elem: &T) -> usize {
        self.elements.get(elem).map_or(0, |&n| n)
    }

    /// Whether every element occurs in `other` at least as often as in `self`.
    pub fn is_subset(&self, other: &Multiset<T>) -> bool {
        self.elements
            .iter()
            .all(|(item, &needed)| other.count(item) >= needed)
    }
}

impl<T: Eq + std::hash::Hash + Clone> Default for Multiset<T> {
    /// An empty multiset, identical to [`Multiset::new`].
    fn default() -> Self {
        Multiset {
            elements: HashMap::new(),
        }
    }
}

impl<T: Eq + std::hash::Hash> PartialEq for Multiset<T> {
    /// Two multisets are equal when every element has the same count in both.
    fn eq(&self, other: &Self) -> bool {
        self.elements.eq(&other.elements)
    }
}

impl<T: Eq + std::hash::Hash> Eq for Multiset<T> {}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;

    #[test]
    fn test_flatten_ac_and() {
        // (A ∧ B) ∧ C should flatten to [A, B, C], in left-to-right order
        let expr = TLExpr::and(
            TLExpr::and(
                TLExpr::pred("A", vec![Term::var("x")]),
                TLExpr::pred("B", vec![Term::var("y")]),
            ),
            TLExpr::pred("C", vec![Term::var("z")]),
        );

        let operands = flatten_ac(&expr, ACOperator::And);
        assert_eq!(operands.len(), 3);
        assert_eq!(
            operands,
            vec![
                TLExpr::pred("A", vec![Term::var("x")]),
                TLExpr::pred("B", vec![Term::var("y")]),
                TLExpr::pred("C", vec![Term::var("z")]),
            ]
        );
    }

    #[test]
    fn test_flatten_ac_other_operator_is_single_operand() {
        // Flattening an AND expression with respect to OR yields it as one operand
        let expr = TLExpr::and(
            TLExpr::pred("A", vec![Term::var("x")]),
            TLExpr::pred("B", vec![Term::var("y")]),
        );

        assert_eq!(flatten_ac(&expr, ACOperator::Or), vec![expr.clone()]);
    }

    #[test]
    fn test_normalize_ac() {
        // B ∧ A should normalize to A ∧ B
        let expr1 = TLExpr::and(
            TLExpr::pred("B", vec![Term::var("y")]),
            TLExpr::pred("A", vec![Term::var("x")]),
        );

        let expr2 = TLExpr::and(
            TLExpr::pred("A", vec![Term::var("x")]),
            TLExpr::pred("B", vec![Term::var("y")]),
        );

        let norm1 = normalize_ac(&expr1, ACOperator::And);
        let norm2 = normalize_ac(&expr2, ACOperator::And);

        assert_eq!(norm1, norm2);
    }

    #[test]
    fn test_normalize_ac_non_matching_is_unchanged() {
        // An OR expression is left untouched when normalizing with respect to AND
        let expr = TLExpr::or(
            TLExpr::pred("B", vec![Term::var("y")]),
            TLExpr::pred("A", vec![Term::var("x")]),
        );

        assert_eq!(normalize_ac(&expr, ACOperator::And), expr);
    }

    #[test]
    fn test_ac_equivalent() {
        // (A ∧ B) ∧ C ≡ C ∧ (B ∧ A)
        let expr1 = TLExpr::and(
            TLExpr::and(
                TLExpr::pred("A", vec![Term::var("x")]),
                TLExpr::pred("B", vec![Term::var("y")]),
            ),
            TLExpr::pred("C", vec![Term::var("z")]),
        );

        let expr2 = TLExpr::and(
            TLExpr::pred("C", vec![Term::var("z")]),
            TLExpr::and(
                TLExpr::pred("B", vec![Term::var("y")]),
                TLExpr::pred("A", vec![Term::var("x")]),
            ),
        );

        assert!(ac_equivalent(&expr1, &expr2));
    }

    #[test]
    fn test_ac_not_equivalent() {
        // A ∧ B is neither A ∧ C nor A ∨ B, but is equivalent to itself
        let a_and_b = TLExpr::and(
            TLExpr::pred("A", vec![Term::var("x")]),
            TLExpr::pred("B", vec![Term::var("y")]),
        );
        let a_and_c = TLExpr::and(
            TLExpr::pred("A", vec![Term::var("x")]),
            TLExpr::pred("C", vec![Term::var("z")]),
        );
        let a_or_b = TLExpr::or(
            TLExpr::pred("A", vec![Term::var("x")]),
            TLExpr::pred("B", vec![Term::var("y")]),
        );

        assert!(!ac_equivalent(&a_and_b, &a_and_c));
        assert!(!ac_equivalent(&a_and_b, &a_or_b));
        assert!(ac_equivalent(&a_and_b, &a_and_b));
    }

    #[test]
    fn test_ac_pattern_simple() {
        // Pattern: A ∧ <var>
        let pattern = ACPattern::new(ACOperator::And)
            .with_fixed(TLExpr::pred("A", vec![Term::var("x")]))
            .with_variable("rest");

        // Expression: A ∧ B ∧ C
        let expr = TLExpr::and(
            TLExpr::and(
                TLExpr::pred("A", vec![Term::var("x")]),
                TLExpr::pred("B", vec![Term::var("y")]),
            ),
            TLExpr::pred("C", vec![Term::var("z")]),
        );

        let bindings = pattern.matches(&expr).expect("unwrap");
        assert!(bindings.contains_key("rest"));
        assert_eq!(
            bindings.get("rest").expect("unwrap"),
            &vec![
                TLExpr::pred("B", vec![Term::var("y")]),
                TLExpr::pred("C", vec![Term::var("z")]),
            ]
        );
    }

    #[test]
    fn test_ac_pattern_missing_fixed_operand() {
        // Pattern D ∧ <rest> cannot match A ∧ B
        let pattern = ACPattern::new(ACOperator::And)
            .with_fixed(TLExpr::pred("D", vec![Term::var("w")]))
            .with_variable("rest");

        let expr = TLExpr::and(
            TLExpr::pred("A", vec![Term::var("x")]),
            TLExpr::pred("B", vec![Term::var("y")]),
        );

        assert!(pattern.matches(&expr).is_none());
    }

    #[test]
    fn test_ac_pattern_fixed_only() {
        // Pattern A ∧ B (no variables) matches B ∧ A but not B ∧ A ∧ C
        let pattern = ACPattern::new(ACOperator::And)
            .with_fixed(TLExpr::pred("A", vec![Term::var("x")]))
            .with_fixed(TLExpr::pred("B", vec![Term::var("y")]));

        let exact = TLExpr::and(
            TLExpr::pred("B", vec![Term::var("y")]),
            TLExpr::pred("A", vec![Term::var("x")]),
        );
        let extra = TLExpr::and(exact.clone(), TLExpr::pred("C", vec![Term::var("z")]));

        let bindings = pattern.matches(&exact).expect("unwrap");
        assert!(bindings.is_empty());
        assert!(pattern.matches(&extra).is_none());
    }

    #[test]
    fn test_multiset_operations() {
        let mut ms1 = Multiset::new();
        ms1.insert("A");
        ms1.insert("B");
        ms1.insert("A"); // A appears twice

        assert_eq!(ms1.count(&"A"), 2);
        assert_eq!(ms1.count(&"B"), 1);
        assert!(ms1.contains(&"A"));

        let mut ms2 = Multiset::new();
        ms2.insert("A");

        assert!(ms2.is_subset(&ms1));
        assert!(!ms1.is_subset(&ms2));
    }

    #[test]
    fn test_multiset_remove() {
        let mut ms = Multiset::from_vec(vec!["A", "A"]);

        assert!(ms.remove(&"A"));
        assert_eq!(ms.count(&"A"), 1);
        assert!(ms.remove(&"A"));
        assert!(ms.is_empty());
        assert!(!ms.remove(&"A")); // nothing left to remove
    }

    #[test]
    fn test_multiset_equality() {
        let ms1 = Multiset::from_vec(vec!["A", "B", "A"]);
        let ms2 = Multiset::from_vec(vec!["B", "A", "A"]);
        let ms3 = Multiset::from_vec(vec!["A", "B"]);

        assert_eq!(ms1, ms2); // Order doesn't matter
        assert_ne!(ms1, ms3); // Multiplicity matters
    }
}