1use std::collections::{HashMap, HashSet, VecDeque};
12
13use super::rewriting::RewriteSystem;
14use super::TLExpr;
15
16#[derive(Debug, Clone)]
18pub struct CriticalPair {
19 pub overlap: TLExpr,
21 pub result1: TLExpr,
23 pub result2: TLExpr,
25 pub rule1_name: String,
27 pub rule2_name: String,
28 pub joinable: Option<bool>,
30}
31
32impl CriticalPair {
33 pub fn new(
35 overlap: TLExpr,
36 result1: TLExpr,
37 result2: TLExpr,
38 rule1_name: String,
39 rule2_name: String,
40 ) -> Self {
41 Self {
42 overlap,
43 result1,
44 result2,
45 rule1_name,
46 rule2_name,
47 joinable: None,
48 }
49 }
50
51 pub fn is_trivially_joinable(&self) -> bool {
53 self.result1 == self.result2
54 }
55
56 pub fn has_conflict(&self) -> bool {
58 !self.is_trivially_joinable()
59 }
60}
61
62#[derive(Debug, Clone)]
64pub struct ConfluenceReport {
65 pub critical_pairs: Vec<CriticalPair>,
67 pub joinable_count: usize,
69 pub non_joinable_count: usize,
71 pub is_locally_confluent: bool,
73 pub is_terminating: bool,
75}
76
77impl ConfluenceReport {
78 pub fn new() -> Self {
80 Self {
81 critical_pairs: Vec::new(),
82 joinable_count: 0,
83 non_joinable_count: 0,
84 is_locally_confluent: false,
85 is_terminating: false,
86 }
87 }
88
89 pub fn is_confluent(&self) -> bool {
93 self.is_terminating && self.is_locally_confluent
94 }
95
96 pub fn summary(&self) -> String {
98 format!(
99 "Confluence Report:\n\
100 - Critical pairs: {}\n\
101 - Joinable: {}\n\
102 - Non-joinable: {}\n\
103 - Locally confluent: {}\n\
104 - Terminating: {}\n\
105 - Confluent: {}",
106 self.critical_pairs.len(),
107 self.joinable_count,
108 self.non_joinable_count,
109 self.is_locally_confluent,
110 self.is_terminating,
111 self.is_confluent()
112 )
113 }
114}
115
116impl Default for ConfluenceReport {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
/// Bounded joinability / confluence checker with a result cache.
pub struct ConfluenceChecker {
    /// Maximum number of rewrite steps explored from each side by the
    /// joinability search.
    max_depth: usize,
    /// Expression-size limit set through `with_max_expr_size`.
    /// NOTE(review): no method in this file reads this value yet.
    max_expr_size: usize,
    /// Cached joinability verdicts, keyed by the `Debug` renderings of the two
    /// expressions in the order they were passed.
    joinability_cache: HashMap<(String, String), bool>,
}
131
132impl ConfluenceChecker {
133 pub fn new() -> Self {
135 Self {
136 max_depth: 10,
137 max_expr_size: 1000,
138 joinability_cache: HashMap::new(),
139 }
140 }
141
142 pub fn with_max_depth(mut self, depth: usize) -> Self {
144 self.max_depth = depth;
145 self
146 }
147
148 pub fn with_max_expr_size(mut self, size: usize) -> Self {
150 self.max_expr_size = size;
151 self
152 }
153
154 pub fn check(&mut self, system: &RewriteSystem) -> ConfluenceReport {
156 let mut report = ConfluenceReport::new();
157
158 self.find_critical_pairs_basic(system, &mut report);
160
161 for pair in &mut report.critical_pairs {
163 if pair.is_trivially_joinable() {
164 pair.joinable = Some(true);
165 report.joinable_count += 1;
166 } else {
167 let joinable = self.test_joinability(&pair.result1, &pair.result2, system);
168 pair.joinable = Some(joinable);
169 if joinable {
170 report.joinable_count += 1;
171 } else {
172 report.non_joinable_count += 1;
173 }
174 }
175 }
176
177 report.is_locally_confluent = report.non_joinable_count == 0;
179
180 report.is_terminating = self.check_termination_heuristic(system);
182
183 report
184 }
185
186 fn find_critical_pairs_basic(&self, _system: &RewriteSystem, _report: &mut ConfluenceReport) {
190 }
203
204 pub fn test_joinability(
208 &mut self,
209 expr1: &TLExpr,
210 expr2: &TLExpr,
211 system: &RewriteSystem,
212 ) -> bool {
213 let key = (format!("{:?}", expr1), format!("{:?}", expr2));
215 if let Some(&result) = self.joinability_cache.get(&key) {
216 return result;
217 }
218
219 if expr1 == expr2 {
220 self.joinability_cache.insert(key, true);
221 return true;
222 }
223
224 let mut visited1 = HashSet::new();
226 let mut visited2 = HashSet::new();
227 let mut queue1 = VecDeque::new();
228 let mut queue2 = VecDeque::new();
229
230 queue1.push_back((expr1.clone(), 0));
231 queue2.push_back((expr2.clone(), 0));
232
233 visited1.insert(format!("{:?}", expr1));
234 visited2.insert(format!("{:?}", expr2));
235
236 while !queue1.is_empty() || !queue2.is_empty() {
237 if let Some((current, depth)) = queue1.pop_front() {
239 if depth >= self.max_depth {
240 continue;
241 }
242
243 let current_key = format!("{:?}", ¤t);
245 if visited2.contains(¤t_key) {
246 self.joinability_cache.insert(key, true);
247 return true;
248 }
249
250 for rewrite in self.get_all_rewrites(¤t, system) {
252 let rewrite_key = format!("{:?}", &rewrite);
253 if !visited1.contains(&rewrite_key) {
254 visited1.insert(rewrite_key);
255 queue1.push_back((rewrite, depth + 1));
256 }
257 }
258 }
259
260 if let Some((current, depth)) = queue2.pop_front() {
262 if depth >= self.max_depth {
263 continue;
264 }
265
266 let current_key = format!("{:?}", ¤t);
267 if visited1.contains(¤t_key) {
268 self.joinability_cache.insert(key, true);
269 return true;
270 }
271
272 for rewrite in self.get_all_rewrites(¤t, system) {
273 let rewrite_key = format!("{:?}", &rewrite);
274 if !visited2.contains(&rewrite_key) {
275 visited2.insert(rewrite_key);
276 queue2.push_back((rewrite, depth + 1));
277 }
278 }
279 }
280 }
281
282 self.joinability_cache.insert(key, false);
283 false
284 }
285
286 #[allow(clippy::only_used_in_recursion)]
288 fn get_all_rewrites(&self, expr: &TLExpr, system: &RewriteSystem) -> Vec<TLExpr> {
289 let mut results = Vec::new();
290
291 if let Some(rewritten) = system.apply_once(expr) {
293 results.push(rewritten);
294 }
295
296 match expr {
298 TLExpr::And(l, r) => {
299 for l_rewrite in self.get_all_rewrites(l, system) {
300 results.push(TLExpr::and(l_rewrite, (**r).clone()));
301 }
302 for r_rewrite in self.get_all_rewrites(r, system) {
303 results.push(TLExpr::and((**l).clone(), r_rewrite));
304 }
305 }
306 TLExpr::Or(l, r) => {
307 for l_rewrite in self.get_all_rewrites(l, system) {
308 results.push(TLExpr::or(l_rewrite, (**r).clone()));
309 }
310 for r_rewrite in self.get_all_rewrites(r, system) {
311 results.push(TLExpr::or((**l).clone(), r_rewrite));
312 }
313 }
314 TLExpr::Not(e) => {
315 for e_rewrite in self.get_all_rewrites(e, system) {
316 results.push(TLExpr::negate(e_rewrite));
317 }
318 }
319 _ => {}
320 }
321
322 results
323 }
324
325 fn check_termination_heuristic(&self, _system: &RewriteSystem) -> bool {
329 true
336 }
337}
338
339impl Default for ConfluenceChecker {
340 fn default() -> Self {
341 Self::new()
342 }
343}
344
345pub fn are_joinable(expr1: &TLExpr, expr2: &TLExpr, system: &RewriteSystem) -> bool {
349 let mut checker = ConfluenceChecker::new();
350 checker.test_joinability(expr1, expr2, system)
351}
352
353pub fn normalize(expr: &TLExpr, system: &RewriteSystem, max_steps: usize) -> Option<TLExpr> {
357 let mut current = expr.clone();
358 let mut steps = 0;
359
360 while steps < max_steps {
361 if let Some(next) = system.apply_once(¤t) {
362 current = next;
363 steps += 1;
364 } else {
365 return Some(current); }
367 }
368
369 None }
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Pattern, RewriteRule, Term};

    /// Fixture: the predicate `P(x)`.
    fn p_of_x() -> TLExpr {
        TLExpr::pred("P", vec![Term::var("x")])
    }

    /// Fixture: `P(x)` wrapped in two negations.
    fn double_negated_p() -> TLExpr {
        TLExpr::negate(TLExpr::negate(p_of_x()))
    }

    /// Fixture: a system with the single rule `not(not(A)) -> A`.
    fn double_negation_elimination() -> RewriteSystem {
        RewriteSystem::new().add_rule(RewriteRule::new(
            Pattern::negation(Pattern::negation(Pattern::var("A"))),
            |bindings| bindings.get("A").unwrap().clone(),
        ))
    }

    // Identical results: trivially joinable, no conflict.
    #[test]
    fn test_critical_pair_trivial_joinable() {
        let shared = TLExpr::pred("Q", vec![Term::var("x")]);
        let pair = CriticalPair::new(
            p_of_x(),
            shared.clone(),
            shared,
            "rule1".to_string(),
            "rule2".to_string(),
        );

        assert!(pair.is_trivially_joinable());
        assert!(!pair.has_conflict());
    }

    // Different results: a conflict, not trivially joinable.
    #[test]
    fn test_critical_pair_with_conflict() {
        let pair = CriticalPair::new(
            p_of_x(),
            TLExpr::pred("Q", vec![Term::var("x")]),
            TLExpr::pred("R", vec![Term::var("x")]),
            "rule1".to_string(),
            "rule2".to_string(),
        );

        assert!(!pair.is_trivially_joinable());
        assert!(pair.has_conflict());
    }

    // An expression is joinable with itself even in an empty system.
    #[test]
    fn test_joinability_identical() {
        let empty_system = RewriteSystem::new();
        let subject = p_of_x();

        let mut checker = ConfluenceChecker::new();
        assert!(checker.test_joinability(&subject, &subject, &empty_system));
    }

    // not(not(P(x))) and P(x) meet after one double-negation elimination.
    #[test]
    fn test_joinability_via_rewriting() {
        let system = double_negation_elimination();
        let wrapped = double_negated_p();
        let bare = p_of_x();

        let mut checker = ConfluenceChecker::new();
        assert!(checker.test_joinability(&wrapped, &bare, &system));
    }

    // Normalising not(not(P(x))) ends in a bare predicate.
    #[test]
    fn test_normalize_to_normal_form() {
        let system = double_negation_elimination();
        let normal_form = normalize(&double_negated_p(), &system, 100).unwrap();

        assert!(matches!(normal_form, TLExpr::Pred { .. }));
    }

    // The summary text reflects the counters and the derived confluence flag.
    #[test]
    fn test_confluence_report_summary() {
        let report = ConfluenceReport {
            joinable_count: 5,
            non_joinable_count: 2,
            is_locally_confluent: false,
            is_terminating: true,
            ..ConfluenceReport::new()
        };

        let summary = report.summary();
        for expected in &["Joinable: 5", "Non-joinable: 2", "Confluent: false"] {
            assert!(summary.contains(expected));
        }
    }

    // Confluence requires both termination and local confluence.
    #[test]
    fn test_confluence_via_newmans_lemma() {
        let mut report = ConfluenceReport::new();

        // (terminating, locally confluent, expected confluent)
        let cases = [
            (true, true, true),
            (false, true, false),
            (true, false, false),
        ];
        for &(terminating, locally_confluent, expected) in &cases {
            report.is_terminating = terminating;
            report.is_locally_confluent = locally_confluent;
            assert_eq!(report.is_confluent(), expected);
        }
    }
}