Skip to main content

tensorlogic_ir/expr/
confluence.rs

1//! Confluence checking and critical pair analysis for term rewriting systems.
2//!
3//! This module provides tools for analyzing the confluence properties of rewrite systems:
//! - **Confluence**: Whenever an expression can be rewritten in two different ways, both results can be rewritten further to a common expression
5//! - **Critical pairs**: Overlapping rule applications that may cause conflicts
6//! - **Joinability**: Whether two expressions can be rewritten to a common form
7//!
//! In a confluent rewrite system the order of rule application cannot produce diverging
//! results; if the system also terminates, every expression has a unique normal form,
//! which is crucial for correctness and determinism.
10
11use std::collections::{HashMap, HashSet, VecDeque};
12
13use super::rewriting::RewriteSystem;
14use super::TLExpr;
15
/// A critical pair representing a potential conflict between two rules.
///
/// `overlap` is rewritten once by each rule, giving `result1` and `result2`.
/// If the two results cannot be rewritten to a common form, the pair counts
/// against local confluence.
#[derive(Debug, Clone)]
pub struct CriticalPair {
    /// The expression where the overlap occurs
    pub overlap: TLExpr,
    /// Result of applying the first rule
    pub result1: TLExpr,
    /// Result of applying the second rule
    pub result2: TLExpr,
    /// Name of the first rule involved (the one that produced `result1`)
    pub rule1_name: String,
    /// Name of the second rule involved (the one that produced `result2`)
    pub rule2_name: String,
    /// Whether this critical pair is joinable; `None` until it has been tested
    pub joinable: Option<bool>,
}
31
32impl CriticalPair {
33    /// Create a new critical pair.
34    pub fn new(
35        overlap: TLExpr,
36        result1: TLExpr,
37        result2: TLExpr,
38        rule1_name: String,
39        rule2_name: String,
40    ) -> Self {
41        Self {
42            overlap,
43            result1,
44            result2,
45            rule1_name,
46            rule2_name,
47            joinable: None,
48        }
49    }
50
51    /// Check if this critical pair is trivially joinable (results are equal).
52    pub fn is_trivially_joinable(&self) -> bool {
53        self.result1 == self.result2
54    }
55
56    /// Check if the results are syntactically different.
57    pub fn has_conflict(&self) -> bool {
58        !self.is_trivially_joinable()
59    }
60}
61
/// Result of confluence analysis.
///
/// Produced by [`ConfluenceChecker::check`]. Use [`ConfluenceReport::is_confluent`]
/// for the combined verdict.
#[derive(Debug, Clone)]
pub struct ConfluenceReport {
    /// All critical pairs found
    pub critical_pairs: Vec<CriticalPair>,
    /// Number of critical pairs found to be joinable
    pub joinable_count: usize,
    /// Number of critical pairs found not to be joinable
    pub non_joinable_count: usize,
    /// Whether the system is locally confluent (every critical pair is joinable)
    pub is_locally_confluent: bool,
    /// Whether termination was verified (currently decided by a heuristic)
    pub is_terminating: bool,
}
76
77impl ConfluenceReport {
78    /// Create a new confluence report.
79    pub fn new() -> Self {
80        Self {
81            critical_pairs: Vec::new(),
82            joinable_count: 0,
83            non_joinable_count: 0,
84            is_locally_confluent: false,
85            is_terminating: false,
86        }
87    }
88
89    /// Check if the system is globally confluent (by Newman's lemma).
90    ///
91    /// Newman's lemma: A terminating system is confluent iff it is locally confluent.
92    pub fn is_confluent(&self) -> bool {
93        self.is_terminating && self.is_locally_confluent
94    }
95
96    /// Get a summary string.
97    pub fn summary(&self) -> String {
98        format!(
99            "Confluence Report:\n\
100             - Critical pairs: {}\n\
101             - Joinable: {}\n\
102             - Non-joinable: {}\n\
103             - Locally confluent: {}\n\
104             - Terminating: {}\n\
105             - Confluent: {}",
106            self.critical_pairs.len(),
107            self.joinable_count,
108            self.non_joinable_count,
109            self.is_locally_confluent,
110            self.is_terminating,
111            self.is_confluent()
112        )
113    }
114}
115
116impl Default for ConfluenceReport {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
/// Confluence checker for rewrite systems.
///
/// Holds the search limits and a memo of joinability verdicts. Configure it with
/// the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ConfluenceChecker {
    /// Maximum depth for joinability testing
    max_depth: usize,
    /// Maximum expression size for analysis
    max_expr_size: usize,
    /// Cache of expression pairs and their joinability.
    ///
    /// The key is built from the `Debug` renderings of the two expressions. The
    /// rewrite system is not part of the key.
    joinability_cache: HashMap<(String, String), bool>,
}
131
132impl ConfluenceChecker {
133    /// Create a new confluence checker.
134    pub fn new() -> Self {
135        Self {
136            max_depth: 10,
137            max_expr_size: 1000,
138            joinability_cache: HashMap::new(),
139        }
140    }
141
142    /// Set maximum depth for joinability testing.
143    pub fn with_max_depth(mut self, depth: usize) -> Self {
144        self.max_depth = depth;
145        self
146    }
147
148    /// Set maximum expression size.
149    pub fn with_max_expr_size(mut self, size: usize) -> Self {
150        self.max_expr_size = size;
151        self
152    }
153
154    /// Check confluence of a rewrite system.
155    pub fn check(&mut self, system: &RewriteSystem) -> ConfluenceReport {
156        let mut report = ConfluenceReport::new();
157
158        // Find all critical pairs
159        self.find_critical_pairs_basic(system, &mut report);
160
161        // Test joinability for each critical pair
162        for pair in &mut report.critical_pairs {
163            if pair.is_trivially_joinable() {
164                pair.joinable = Some(true);
165                report.joinable_count += 1;
166            } else {
167                let joinable = self.test_joinability(&pair.result1, &pair.result2, system);
168                pair.joinable = Some(joinable);
169                if joinable {
170                    report.joinable_count += 1;
171                } else {
172                    report.non_joinable_count += 1;
173                }
174            }
175        }
176
177        // System is locally confluent if all critical pairs are joinable
178        report.is_locally_confluent = report.non_joinable_count == 0;
179
180        // For termination, we use a simple heuristic: no rule increases expression size
181        report.is_terminating = self.check_termination_heuristic(system);
182
183        report
184    }
185
186    /// Find critical pairs in the system (basic version).
187    ///
188    /// This is a simplified implementation that checks for overlaps at the top level.
189    fn find_critical_pairs_basic(&self, _system: &RewriteSystem, _report: &mut ConfluenceReport) {
190        // In a full implementation, we would:
191        // 1. For each pair of rules (r1, r2)
192        // 2. Find all ways their patterns can overlap
193        // 3. Apply both rules and record the results
194        //
195        // This is complex because it requires:
196        // - Unification of patterns
197        // - Finding all overlap positions
198        // - Handling variable bindings correctly
199        //
200        // For now, we provide the infrastructure without the full implementation.
201        // A production system would use a sophisticated unification algorithm.
202    }
203
204    /// Test if two expressions are joinable (can be rewritten to a common form).
205    ///
206    /// Uses breadth-first search to explore possible rewrites.
207    pub fn test_joinability(
208        &mut self,
209        expr1: &TLExpr,
210        expr2: &TLExpr,
211        system: &RewriteSystem,
212    ) -> bool {
213        // Check cache first
214        let key = (format!("{:?}", expr1), format!("{:?}", expr2));
215        if let Some(&result) = self.joinability_cache.get(&key) {
216            return result;
217        }
218
219        if expr1 == expr2 {
220            self.joinability_cache.insert(key, true);
221            return true;
222        }
223
224        // BFS from both expressions
225        let mut visited1 = HashSet::new();
226        let mut visited2 = HashSet::new();
227        let mut queue1 = VecDeque::new();
228        let mut queue2 = VecDeque::new();
229
230        queue1.push_back((expr1.clone(), 0));
231        queue2.push_back((expr2.clone(), 0));
232
233        visited1.insert(format!("{:?}", expr1));
234        visited2.insert(format!("{:?}", expr2));
235
236        while !queue1.is_empty() || !queue2.is_empty() {
237            // Expand from expr1
238            if let Some((current, depth)) = queue1.pop_front() {
239                if depth >= self.max_depth {
240                    continue;
241                }
242
243                // Check if we've reached a form that expr2 can reach
244                let current_key = format!("{:?}", &current);
245                if visited2.contains(&current_key) {
246                    self.joinability_cache.insert(key, true);
247                    return true;
248                }
249
250                // Apply all possible rewrites
251                for rewrite in self.get_all_rewrites(&current, system) {
252                    let rewrite_key = format!("{:?}", &rewrite);
253                    if !visited1.contains(&rewrite_key) {
254                        visited1.insert(rewrite_key);
255                        queue1.push_back((rewrite, depth + 1));
256                    }
257                }
258            }
259
260            // Expand from expr2
261            if let Some((current, depth)) = queue2.pop_front() {
262                if depth >= self.max_depth {
263                    continue;
264                }
265
266                let current_key = format!("{:?}", &current);
267                if visited1.contains(&current_key) {
268                    self.joinability_cache.insert(key, true);
269                    return true;
270                }
271
272                for rewrite in self.get_all_rewrites(&current, system) {
273                    let rewrite_key = format!("{:?}", &rewrite);
274                    if !visited2.contains(&rewrite_key) {
275                        visited2.insert(rewrite_key);
276                        queue2.push_back((rewrite, depth + 1));
277                    }
278                }
279            }
280        }
281
282        self.joinability_cache.insert(key, false);
283        false
284    }
285
286    /// Get all possible one-step rewrites of an expression.
287    #[allow(clippy::only_used_in_recursion)]
288    fn get_all_rewrites(&self, expr: &TLExpr, system: &RewriteSystem) -> Vec<TLExpr> {
289        let mut results = Vec::new();
290
291        // Try each rule
292        if let Some(rewritten) = system.apply_once(expr) {
293            results.push(rewritten);
294        }
295
296        // Recursively try rewrites on subexpressions
297        match expr {
298            TLExpr::And(l, r) => {
299                for l_rewrite in self.get_all_rewrites(l, system) {
300                    results.push(TLExpr::and(l_rewrite, (**r).clone()));
301                }
302                for r_rewrite in self.get_all_rewrites(r, system) {
303                    results.push(TLExpr::and((**l).clone(), r_rewrite));
304                }
305            }
306            TLExpr::Or(l, r) => {
307                for l_rewrite in self.get_all_rewrites(l, system) {
308                    results.push(TLExpr::or(l_rewrite, (**r).clone()));
309                }
310                for r_rewrite in self.get_all_rewrites(r, system) {
311                    results.push(TLExpr::or((**l).clone(), r_rewrite));
312                }
313            }
314            TLExpr::Not(e) => {
315                for e_rewrite in self.get_all_rewrites(e, system) {
316                    results.push(TLExpr::negate(e_rewrite));
317                }
318            }
319            _ => {}
320        }
321
322        results
323    }
324
325    /// Check termination using a simple heuristic.
326    ///
327    /// Returns true if no rule seems to increase expression size indefinitely.
328    fn check_termination_heuristic(&self, _system: &RewriteSystem) -> bool {
329        // A proper termination checker would use techniques like:
330        // - Polynomial interpretations
331        // - Lexicographic path ordering
332        // - Dependency pairs
333        //
334        // For now, we assume termination (conservative)
335        true
336    }
337}
338
339impl Default for ConfluenceChecker {
340    fn default() -> Self {
341        Self::new()
342    }
343}
344
345/// Check if two expressions can be joined under a rewrite system.
346///
347/// Convenience function that creates a checker and tests joinability.
348pub fn are_joinable(expr1: &TLExpr, expr2: &TLExpr, system: &RewriteSystem) -> bool {
349    let mut checker = ConfluenceChecker::new();
350    checker.test_joinability(expr1, expr2, system)
351}
352
353/// Compute normal form of an expression (if it exists).
354///
355/// Returns None if no normal form is reached within max_steps.
356pub fn normalize(expr: &TLExpr, system: &RewriteSystem, max_steps: usize) -> Option<TLExpr> {
357    let mut current = expr.clone();
358    let mut steps = 0;
359
360    while steps < max_steps {
361        if let Some(next) = system.apply_once(&current) {
362            current = next;
363            steps += 1;
364        } else {
365            return Some(current); // Normal form reached
366        }
367    }
368
369    None // Max steps exceeded
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Pattern, RewriteRule, Term};

    /// The atom `P(x)` used throughout these tests.
    fn p_of_x() -> TLExpr {
        TLExpr::pred("P", vec![Term::var("x")])
    }

    /// `¬¬P(x)`, which the double-negation rule reduces to `P(x)`.
    fn double_negated_p() -> TLExpr {
        TLExpr::negate(TLExpr::negate(p_of_x()))
    }

    /// A system containing the single rule `¬¬A → A`.
    fn double_negation_system() -> RewriteSystem {
        RewriteSystem::new().add_rule(RewriteRule::new(
            Pattern::negation(Pattern::negation(Pattern::var("A"))),
            |bindings| bindings.get("A").unwrap().clone(),
        ))
    }

    #[test]
    fn test_critical_pair_trivial_joinable() {
        let shared = TLExpr::pred("Q", vec![Term::var("x")]);
        let pair = CriticalPair::new(
            p_of_x(),
            shared.clone(),
            shared,
            String::from("rule1"),
            String::from("rule2"),
        );

        assert!(pair.is_trivially_joinable());
        assert!(!pair.has_conflict());
    }

    #[test]
    fn test_critical_pair_with_conflict() {
        let pair = CriticalPair::new(
            p_of_x(),
            TLExpr::pred("Q", vec![Term::var("x")]),
            TLExpr::pred("R", vec![Term::var("x")]),
            String::from("rule1"),
            String::from("rule2"),
        );

        assert!(!pair.is_trivially_joinable());
        assert!(pair.has_conflict());
    }

    #[test]
    fn test_joinability_identical() {
        let empty_system = RewriteSystem::new();
        let atom = p_of_x();

        assert!(ConfluenceChecker::new().test_joinability(&atom, &atom, &empty_system));
    }

    #[test]
    fn test_joinability_via_rewriting() {
        let system = double_negation_system();
        let mut checker = ConfluenceChecker::new();

        assert!(checker.test_joinability(&double_negated_p(), &p_of_x(), &system));
    }

    #[test]
    fn test_normalize_to_normal_form() {
        let system = double_negation_system();
        let normal_form = normalize(&double_negated_p(), &system, 100)
            .expect("double negation should normalize within 100 steps");

        assert!(matches!(normal_form, TLExpr::Pred { .. }));
    }

    #[test]
    fn test_confluence_report_summary() {
        let report = ConfluenceReport {
            joinable_count: 5,
            non_joinable_count: 2,
            is_locally_confluent: false,
            is_terminating: true,
            ..ConfluenceReport::new()
        };

        let summary = report.summary();
        for &expected in ["Joinable: 5", "Non-joinable: 2", "Confluent: false"].iter() {
            assert!(summary.contains(expected), "summary missing {:?}", expected);
        }
    }

    #[test]
    fn test_confluence_via_newmans_lemma() {
        // (terminating, locally confluent, expected confluence)
        let cases = [
            (true, true, true),   // terminating and locally confluent => confluent
            (false, true, false), // not terminating => can't deduce confluence
            (true, false, false), // not locally confluent => not confluent
        ];

        let mut report = ConfluenceReport::new();
        for &(terminating, locally_confluent, expected) in cases.iter() {
            report.is_terminating = terminating;
            report.is_locally_confluent = locally_confluent;
            assert_eq!(report.is_confluent(), expected);
        }
    }
}