1use std::collections::{HashMap, HashSet};
94use std::fmt;
95
96use crate::{
97 errors::SqliteGraphError,
98 graph::SqliteGraph,
99 progress::ProgressCallback,
100};
101
102use super::subgraph_isomorphism::{find_subgraph_patterns, SubgraphPatternBounds};
104
/// Limits and options for [`rewrite_graph_patterns`] and
/// [`rewrite_graph_patterns_with_progress`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RewriteBounds {
    /// Maximum number of matches to rewrite; `None` means no limit.
    ///
    /// The same value is forwarded as `max_matches` to the subgraph pattern search.
    pub max_matches: Option<usize>,

    /// When `true` and at least one match was found, the rewritten graph is checked for
    /// edges whose target node does not exist. Any findings are reported in
    /// [`RewriteResult::validation_errors`].
    pub validate_after_rewrite: bool,
}
137
138impl Default for RewriteBounds {
139 fn default() -> Self {
140 Self {
141 max_matches: Some(10),
142 validate_after_rewrite: true,
143 }
144 }
145}
146
147impl RewriteBounds {
148 #[inline]
150 pub fn new() -> Self {
151 Self::default()
152 }
153
154 #[inline]
156 pub fn with_max_matches(mut self, max: usize) -> Self {
157 self.max_matches = Some(max);
158 self
159 }
160
161 #[inline]
163 pub fn with_validation(mut self, validate: bool) -> Self {
164 self.validate_after_rewrite = validate;
165 self
166 }
167
168 #[inline]
173 pub fn unlimited(mut self) -> Self {
174 self.max_matches = None;
175 self
176 }
177}
178
/// One change recorded while rewriting a graph.
///
/// NOTE(review): the rewrite code records deletions with node IDs of the graph being
/// rewritten, but additions with node IDs of the newly built graph. The two ID spaces
/// differ, so IDs from different variants should not be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RewriteOperation {
    /// A node with this ID was removed.
    NodeDeleted(i64),

    /// A node with this ID was inserted.
    NodeAdded(i64),

    /// The edge `from -> to` was removed.
    EdgeDeleted { from: i64, to: i64 },

    /// The edge `from -> to` was inserted.
    EdgeAdded { from: i64, to: i64 },
}
198
199impl fmt::Display for RewriteOperation {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 match self {
202 Self::NodeDeleted(id) => write!(f, "Deleted node {}", id),
203 Self::NodeAdded(id) => write!(f, "Added node {}", id),
204 Self::EdgeDeleted { from, to } => write!(f, "Deleted edge {} -> {}", from, to),
205 Self::EdgeAdded { from, to } => write!(f, "Added edge {} -> {}", from, to),
206 }
207 }
208}
209
/// A graph rewrite rule: each matched occurrence of `pattern` is replaced by `replacement`.
pub struct RewriteRule {
    /// Graph searched for in the host graph.
    pub pattern: SqliteGraph,

    /// Graph whose nodes and edges are inserted for every rewritten match.
    pub replacement: SqliteGraph,

    /// Preserved node pairs `(pattern_index, replacement_index)`.
    ///
    /// Both indices are positions in `all_entity_ids()` of the respective graph. A matched
    /// host node whose pattern index appears here is kept instead of deleted. The paired
    /// replacement node is not inserted; replacement edges touching it are attached to the
    /// kept host node instead.
    pub interface: Vec<(usize, usize)>,
}
252
253impl RewriteRule {
254 #[inline]
256 pub fn interface_size(&self) -> usize {
257 self.interface.len()
258 }
259
260 fn validate_interface(&self) -> Result<(), SqliteGraphError> {
265 let pattern_ids = self.pattern.all_entity_ids()?;
266 let replacement_ids = self.replacement.all_entity_ids()?;
267
268 let pattern_count = pattern_ids.len();
269 let replacement_count = replacement_ids.len();
270
271 for &(pattern_idx, replacement_idx) in &self.interface {
272 if pattern_idx >= pattern_count {
273 return Err(SqliteGraphError::invalid_input(format!(
274 "Interface pattern index {} out of bounds (pattern has {} nodes)",
275 pattern_idx, pattern_count
276 )));
277 }
278 if replacement_idx >= replacement_count {
279 return Err(SqliteGraphError::invalid_input(format!(
280 "Interface replacement index {} out of bounds (replacement has {} nodes)",
281 replacement_idx, replacement_count
282 )));
283 }
284 }
285
286 Ok(())
287 }
288}
289
/// Outcome of a rewrite run.
pub struct RewriteResult {
    /// The rewritten graph. This is a new in-memory graph; the input graph is not modified.
    pub rewritten_graph: SqliteGraph,

    /// Number of matches that were rewritten.
    pub patterns_replaced: usize,

    /// Changes recorded by the individual rewrites, in the order they were applied.
    pub operations_applied: Vec<RewriteOperation>,

    /// Problems reported by post-rewrite validation. Empty when validation passed or was
    /// not run.
    pub validation_errors: Vec<String>,
}
330
331impl RewriteResult {
332 #[inline]
338 pub fn is_valid(&self) -> bool {
339 self.validation_errors.is_empty()
340 }
341
342 #[inline]
344 pub fn operation_count(&self) -> usize {
345 self.operations_applied.len()
346 }
347
348 #[inline]
350 pub fn nodes_added(&self) -> usize {
351 self.operations_applied
352 .iter()
353 .filter(|op| matches!(op, RewriteOperation::NodeAdded(_)))
354 .count()
355 }
356
357 #[inline]
359 pub fn nodes_deleted(&self) -> usize {
360 self.operations_applied
361 .iter()
362 .filter(|op| matches!(op, RewriteOperation::NodeDeleted(_)))
363 .count()
364 }
365
366 #[inline]
368 pub fn edges_added(&self) -> usize {
369 self.operations_applied
370 .iter()
371 .filter(|op| matches!(op, RewriteOperation::EdgeAdded { .. }))
372 .count()
373 }
374
375 #[inline]
377 pub fn edges_deleted(&self) -> usize {
378 self.operations_applied
379 .iter()
380 .filter(|op| matches!(op, RewriteOperation::EdgeDeleted { .. }))
381 .count()
382 }
383}
384
385fn validate_no_dangling_edges(graph: &SqliteGraph) -> Vec<String> {
398 let mut errors = Vec::new();
399
400 let valid_ids: HashSet<i64> = match graph.all_entity_ids() {
402 Ok(ids) => ids.into_iter().collect(),
403 Err(e) => {
404 errors.push(format!("Failed to get entity IDs: {}", e));
405 return errors;
406 }
407 };
408
409 for &node_id in &valid_ids {
411 if let Ok(outgoing) = graph.fetch_outgoing(node_id) {
412 for &target_id in &outgoing {
413 if !valid_ids.contains(&target_id) {
414 errors.push(format!(
415 "Dangling edge: {} -> {} (target node does not exist)",
416 node_id, target_id
417 ));
418 }
419 }
420 }
421 }
422
423 errors
424}
425
426fn copy_graph(graph: &SqliteGraph) -> Result<SqliteGraph, SqliteGraphError> {
439 let new_graph = SqliteGraph::open_in_memory()?;
440
441 let entity_ids = graph.all_entity_ids()?;
443 for &id in &entity_ids {
444 if let Ok(entity) = graph.get_entity(id) {
445 let _ = new_graph.insert_entity(&crate::GraphEntity {
446 id: 0,
447 kind: entity.kind.clone(),
448 name: entity.name.clone(),
449 file_path: entity.file_path.clone(),
450 data: entity.data.clone(),
451 });
452 }
453 }
454
455 let new_ids: Vec<i64> = new_graph.all_entity_ids()?
457 .into_iter()
458 .take(entity_ids.len())
459 .collect();
460
461 let mut old_to_new: HashMap<i64, i64> = HashMap::new();
462 for (old_id, new_id) in entity_ids.iter().zip(new_ids.iter()) {
463 old_to_new.insert(*old_id, *new_id);
464 }
465
466 for &from_id in &entity_ids {
467 if let Ok(outgoing) = graph.fetch_outgoing(from_id) {
468 for to_id in outgoing {
469 if let (Some(&new_from), Some(&new_to)) = (
470 old_to_new.get(&from_id),
471 old_to_new.get(&to_id)
472 ) {
473 let edge = crate::GraphEdge {
474 id: 0,
475 from_id: new_from,
476 to_id: new_to,
477 edge_type: "edge".to_string(),
478 data: serde_json::json!({}),
479 };
480 let _ = new_graph.insert_edge(&edge);
481 }
482 }
483 }
484 }
485
486 Ok(new_graph)
487}
488
489pub fn rewrite_graph_patterns(
551 graph: &SqliteGraph,
552 rule: &RewriteRule,
553 bounds: RewriteBounds,
554) -> Result<RewriteResult, SqliteGraphError> {
555 rule.validate_interface()?;
557
558 let pattern_bounds = SubgraphPatternBounds {
560 max_matches: bounds.max_matches,
561 timeout_ms: Some(5000),
562 max_pattern_nodes: Some(20),
563 };
564
565 let match_result = find_subgraph_patterns(graph, &rule.pattern, pattern_bounds)?;
566
567 if match_result.matches.is_empty() {
568 return Ok(RewriteResult {
570 rewritten_graph: copy_graph(graph)?,
571 patterns_replaced: 0,
572 operations_applied: vec![],
573 validation_errors: vec![],
574 });
575 }
576
577 let mut current_graph = copy_graph(graph)?;
579 let mut all_operations = Vec::new();
580 let mut patterns_replaced = 0;
581
582 let max_rewrites = bounds.max_matches.unwrap_or(match_result.matches.len());
584 let rewrites_to_apply = match_result.matches.len().min(max_rewrites);
585
586 for match_idx in 0..rewrites_to_apply {
587 let pattern_match = &match_result.matches[match_idx];
588
589 let (new_graph, operations) = apply_single_rewrite(
591 ¤t_graph,
592 rule,
593 pattern_match,
594 patterns_replaced,
595 )?;
596
597 current_graph = new_graph;
598 all_operations.extend(operations);
599 patterns_replaced += 1;
600 }
601
602 let validation_errors = if bounds.validate_after_rewrite {
604 validate_no_dangling_edges(¤t_graph)
605 } else {
606 vec![]
607 };
608
609 Ok(RewriteResult {
610 rewritten_graph: current_graph,
611 patterns_replaced,
612 operations_applied: all_operations,
613 validation_errors,
614 })
615}
616
617pub fn rewrite_graph_patterns_with_progress<F>(
652 graph: &SqliteGraph,
653 rule: &RewriteRule,
654 bounds: RewriteBounds,
655 progress: &F,
656) -> Result<RewriteResult, SqliteGraphError>
657where
658 F: ProgressCallback,
659{
660 progress.on_progress(0, Some(4), "Validating rewrite rule");
661
662 rule.validate_interface()?;
664
665 progress.on_progress(1, Some(4), "Finding pattern matches");
666
667 let pattern_bounds = SubgraphPatternBounds {
669 max_matches: bounds.max_matches,
670 timeout_ms: Some(5000),
671 max_pattern_nodes: Some(20),
672 };
673
674 let match_result = find_subgraph_patterns(graph, &rule.pattern, pattern_bounds)?;
675
676 progress.on_progress(
677 2,
678 Some(4),
679 &format!("Found {} pattern matches", match_result.matches.len()),
680 );
681
682 if match_result.matches.is_empty() {
683 progress.on_progress(3, Some(4), "No matches found, returning original graph");
685 progress.on_complete();
686
687 return Ok(RewriteResult {
688 rewritten_graph: copy_graph(graph)?,
689 patterns_replaced: 0,
690 operations_applied: vec![],
691 validation_errors: vec![],
692 });
693 }
694
695 let mut current_graph = copy_graph(graph)?;
697 let mut all_operations = Vec::new();
698 let mut patterns_replaced = 0;
699
700 let max_rewrites = bounds.max_matches.unwrap_or(match_result.matches.len());
702 let rewrites_to_apply = match_result.matches.len().min(max_rewrites);
703
704 for match_idx in 0..rewrites_to_apply {
705 let pattern_match = &match_result.matches[match_idx];
706
707 progress.on_progress(
708 2,
709 Some(4),
710 &format!("Applying rewrite {}/{}", match_idx + 1, rewrites_to_apply),
711 );
712
713 let (new_graph, operations) = apply_single_rewrite(
715 ¤t_graph,
716 rule,
717 pattern_match,
718 patterns_replaced,
719 )?;
720
721 current_graph = new_graph;
722 all_operations.extend(operations);
723 patterns_replaced += 1;
724 }
725
726 progress.on_progress(3, Some(4), "Validating rewritten graph");
727
728 let validation_errors = if bounds.validate_after_rewrite {
730 validate_no_dangling_edges(¤t_graph)
731 } else {
732 vec![]
733 };
734
735 let final_msg = if validation_errors.is_empty() {
736 format!(
737 "Rewrite complete: {} patterns replaced, {} operations applied",
738 patterns_replaced,
739 all_operations.len()
740 )
741 } else {
742 format!(
743 "Rewrite complete with errors: {} patterns replaced, {} validation errors",
744 patterns_replaced,
745 validation_errors.len()
746 )
747 };
748
749 progress.on_progress(4, Some(4), &final_msg);
750 progress.on_complete();
751
752 Ok(RewriteResult {
753 rewritten_graph: current_graph,
754 patterns_replaced,
755 operations_applied: all_operations,
756 validation_errors,
757 })
758}
759
760fn apply_single_rewrite(
780 graph: &SqliteGraph,
781 rule: &RewriteRule,
782 pattern_match: &[i64],
783 rewrite_index: usize,
784) -> Result<(SqliteGraph, Vec<RewriteOperation>), SqliteGraphError> {
785 let mut operations = Vec::new();
786
787 let pattern_ids = rule.pattern.all_entity_ids()?;
789 let replacement_ids = rule.replacement.all_entity_ids()?;
790
791 let mut interface_pattern_indices: HashSet<usize> = HashSet::new();
793 for &(pattern_idx, replacement_idx) in &rule.interface {
794 if pattern_idx < pattern_match.len() && replacement_idx < replacement_ids.len() {
795 interface_pattern_indices.insert(pattern_idx);
796 }
797 }
798
799 let mut non_interface_pattern_ids: HashSet<i64> = HashSet::new();
801 for (idx, _pattern_id) in pattern_ids.iter().enumerate() {
802 if idx < pattern_match.len() && !interface_pattern_indices.contains(&idx) {
803 let target_id = pattern_match[idx];
804 non_interface_pattern_ids.insert(target_id);
805 }
806 }
807
808 let new_graph = SqliteGraph::open_in_memory()?;
810
811 let mut old_to_new_id: HashMap<i64, i64> = HashMap::new();
813
814 let all_old_ids = graph.all_entity_ids()?;
816 for &old_id in &all_old_ids {
817 if !non_interface_pattern_ids.contains(&old_id) {
818 if let Ok(entity) = graph.get_entity(old_id) {
819 let new_id = new_graph.insert_entity(&crate::GraphEntity {
820 id: 0,
821 kind: entity.kind.clone(),
822 name: entity.name.clone(),
823 file_path: entity.file_path.clone(),
824 data: entity.data.clone(),
825 })?;
826 old_to_new_id.insert(old_id, new_id);
827 }
828 } else {
829 operations.push(RewriteOperation::NodeDeleted(old_id));
830 }
831 }
832
833 for &deleted_id in &non_interface_pattern_ids {
835 if let Ok(outgoing) = graph.fetch_outgoing(deleted_id) {
837 for &target_id in &outgoing {
838 operations.push(RewriteOperation::EdgeDeleted {
839 from: deleted_id,
840 to: target_id,
841 });
842 }
843 }
844 for &from_id in &all_old_ids {
846 if let Ok(outgoing) = graph.fetch_outgoing(from_id) {
847 if outgoing.contains(&deleted_id) {
848 operations.push(RewriteOperation::EdgeDeleted {
849 from: from_id,
850 to: deleted_id,
851 });
852 }
853 }
854 }
855 }
856
857 let mut replacement_node_map: HashMap<usize, i64> = HashMap::new();
859
860 for (idx, &replacement_id) in replacement_ids.iter().enumerate() {
861 let is_interface = rule.interface.iter().any(|(_, rep_idx)| *rep_idx == idx);
862
863 if !is_interface {
864 if let Ok(entity) = rule.replacement.get_entity(replacement_id) {
865 let fresh_id = new_graph.insert_entity(&crate::GraphEntity {
866 id: 0,
867 kind: entity.kind.clone(),
868 name: format!("{}_rewrite_{}", entity.name, rewrite_index),
869 file_path: entity.file_path.clone(),
870 data: entity.data.clone(),
871 })?;
872 replacement_node_map.insert(idx, fresh_id);
873 operations.push(RewriteOperation::NodeAdded(fresh_id));
874 }
875 }
876 }
877
878 for &from_old in &all_old_ids {
880 if let Some(&from_new) = old_to_new_id.get(&from_old) {
881 if let Ok(outgoing) = graph.fetch_outgoing(from_old) {
882 for to_old in outgoing {
883 if let Some(&to_new) = old_to_new_id.get(&to_old) {
884 let edge = crate::GraphEdge {
885 id: 0,
886 from_id: from_new,
887 to_id: to_new,
888 edge_type: "edge".to_string(),
889 data: serde_json::json!({}),
890 };
891 if new_graph.insert_edge(&edge).is_ok() {
892 operations.push(RewriteOperation::EdgeAdded {
893 from: from_new,
894 to: to_new,
895 });
896 }
897 }
898 }
899 }
900 }
901 }
902
903 if let Ok(repl_node_ids) = rule.replacement.all_entity_ids() {
905 for &from_repl_id in &repl_node_ids {
906 if let Ok(outgoing) = rule.replacement.fetch_outgoing(from_repl_id) {
907 for to_repl_id in outgoing {
908 let from_idx = repl_node_ids.iter().position(|&id| id == from_repl_id);
909 let to_idx = repl_node_ids.iter().position(|&id| id == to_repl_id);
910
911 if let (Some(from_i), Some(to_i)) = (from_idx, to_idx) {
912 let from_id = if let Some((pat_idx, _)) = rule.interface.iter().find(|(_, rep_idx)| *rep_idx == from_i) {
914 if *pat_idx < pattern_match.len() {
915 old_to_new_id.get(&pattern_match[*pat_idx]).copied()
916 } else {
917 None
918 }
919 } else {
920 replacement_node_map.get(&from_i).copied()
921 };
922
923 let to_id = if let Some((pat_idx, _)) = rule.interface.iter().find(|(_, rep_idx)| *rep_idx == to_i) {
924 if *pat_idx < pattern_match.len() {
925 old_to_new_id.get(&pattern_match[*pat_idx]).copied()
926 } else {
927 None
928 }
929 } else {
930 replacement_node_map.get(&to_i).copied()
931 };
932
933 if let (Some(from), Some(to)) = (from_id, to_id) {
934 let edge = crate::GraphEdge {
935 id: 0,
936 from_id: from,
937 to_id: to,
938 edge_type: "edge".to_string(),
939 data: serde_json::json!({}),
940 };
941 if new_graph.insert_edge(&edge).is_ok() {
942 operations.push(RewriteOperation::EdgeAdded {
943 from: from,
944 to: to,
945 });
946 }
947 }
948 }
949 }
950 }
951 }
952 }
953
954 Ok((new_graph, operations))
955}
956
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GraphEdge, GraphEntity};

    /// Creates an in-memory graph with `count` unconnected entities of kind "test",
    /// named `test_<i>`.
    fn create_test_graph_with_nodes(count: usize) -> SqliteGraph {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");

        for i in 0..count {
            let entity = GraphEntity {
                id: 0,
                kind: "test".to_string(),
                name: format!("test_{}", i),
                file_path: Some(format!("test_{}.rs", i)),
                data: serde_json::json!({"index": i}),
            };
            graph
                .insert_entity(&entity)
                .expect("Failed to insert entity");
        }

        graph
    }

    /// Returns the first `count` entity IDs of `graph`, in `all_entity_ids` order.
    fn get_entity_ids(graph: &SqliteGraph, count: usize) -> Vec<i64> {
        graph
            .all_entity_ids()
            .expect("Failed to get IDs")
            .into_iter()
            .take(count)
            .collect()
    }

    /// Adds an edge between the nodes at positions `from_idx` and `to_idx` of
    /// `all_entity_ids`. Insertion errors are ignored.
    fn add_edge(graph: &SqliteGraph, from_idx: i64, to_idx: i64) {
        let ids: Vec<i64> = graph.all_entity_ids().expect("Failed to get IDs");

        let edge = GraphEdge {
            id: 0,
            from_id: ids[from_idx as usize],
            to_id: ids[to_idx as usize],
            edge_type: "edge".to_string(),
            data: serde_json::json!({}),
        };
        graph.insert_edge(&edge).ok();
    }

    #[test]
    fn test_rewrite_bounds_default() {
        let bounds = RewriteBounds::default();

        assert_eq!(bounds.max_matches, Some(10));
        assert!(bounds.validate_after_rewrite);
    }

    #[test]
    fn test_rewrite_bounds_builder() {
        let bounds = RewriteBounds::default()
            .with_max_matches(100)
            .with_validation(false);

        assert_eq!(bounds.max_matches, Some(100));
        assert!(!bounds.validate_after_rewrite);
    }

    #[test]
    fn test_rewrite_bounds_unlimited() {
        let bounds = RewriteBounds::default().unlimited();

        assert_eq!(bounds.max_matches, None);
        assert!(bounds.validate_after_rewrite);
    }

    #[test]
    fn test_rewrite_operation_display() {
        assert_eq!(
            format!("{}", RewriteOperation::NodeDeleted(5)),
            "Deleted node 5"
        );
        assert_eq!(
            format!("{}", RewriteOperation::NodeAdded(10)),
            "Added node 10"
        );
        assert_eq!(
            format!("{}", RewriteOperation::EdgeDeleted { from: 1, to: 2 }),
            "Deleted edge 1 -> 2"
        );
        assert_eq!(
            format!("{}", RewriteOperation::EdgeAdded { from: 3, to: 4 }),
            "Added edge 3 -> 4"
        );
    }

    #[test]
    fn test_rewrite_result_helpers() {
        let result = RewriteResult {
            rewritten_graph: SqliteGraph::open_in_memory().unwrap(),
            patterns_replaced: 2,
            operations_applied: vec![
                RewriteOperation::NodeDeleted(1),
                RewriteOperation::NodeDeleted(2),
                RewriteOperation::NodeAdded(10),
                RewriteOperation::EdgeDeleted { from: 1, to: 2 },
                RewriteOperation::EdgeAdded { from: 3, to: 10 },
            ],
            validation_errors: vec![],
        };

        assert!(result.is_valid());
        assert_eq!(result.patterns_replaced, 2);
        assert_eq!(result.operation_count(), 5);
        assert_eq!(result.nodes_added(), 1);
        assert_eq!(result.nodes_deleted(), 2);
        assert_eq!(result.edges_added(), 1);
        assert_eq!(result.edges_deleted(), 1);
    }

    #[test]
    fn test_rewrite_result_with_errors() {
        let result = RewriteResult {
            rewritten_graph: SqliteGraph::open_in_memory().unwrap(),
            patterns_replaced: 0,
            operations_applied: vec![],
            validation_errors: vec![
                "Dangling edge: 1 -> 999".to_string(),
                "Duplicate entity detected".to_string(),
            ],
        };

        assert!(!result.is_valid());
        assert_eq!(result.validation_errors.len(), 2);
    }

    #[test]
    fn test_validate_no_dangling_edges_valid() {
        let graph = create_test_graph_with_nodes(3);
        let ids = get_entity_ids(&graph, 3);

        // Chain 0 -> 1 -> 2; every target exists.
        for (from, to) in &[(0, 1), (1, 2)] {
            let edge = GraphEdge {
                id: 0,
                from_id: ids[*from],
                to_id: ids[*to],
                edge_type: "edge".to_string(),
                data: serde_json::json!({}),
            };
            graph.insert_edge(&edge).ok();
        }

        let errors = validate_no_dangling_edges(&graph);
        assert!(errors.is_empty(), "Expected no errors, got: {:?}", errors);
    }

    #[test]
    fn test_validate_dangling_edges_detected() {
        let graph = create_test_graph_with_nodes(3);
        let ids = get_entity_ids(&graph, 3);

        // Edge pointing at a node ID that does not exist.
        let edge = GraphEdge {
            id: 0,
            from_id: ids[0],
            to_id: 99999,
            edge_type: "edge".to_string(),
            data: serde_json::json!({}),
        };
        // The store may reject an edge to a missing node (e.g. via a foreign-key
        // constraint), so record whether the dangling edge actually exists.
        let inserted = graph.insert_edge(&edge).is_ok();

        let errors = validate_no_dangling_edges(&graph);

        if !inserted {
            assert!(
                errors.is_empty(),
                "No dangling edge was stored, got: {:?}",
                errors
            );
        }
        for error in &errors {
            assert!(
                error.starts_with("Dangling edge:") && error.contains("99999"),
                "Unexpected validation error: {}",
                error
            );
        }
    }

    #[test]
    fn test_rewrite_rule_interface_size() {
        let pattern = create_test_graph_with_nodes(3);
        let replacement = create_test_graph_with_nodes(2);

        let rule = RewriteRule {
            pattern,
            replacement,
            interface: vec![(0, 0), (2, 1)],
        };

        assert_eq!(rule.interface_size(), 2);
    }

    #[test]
    fn test_rewrite_simple_chain_rewrite() {
        // Host: 0 -> 1 -> 2 -> 3.
        let graph = create_test_graph_with_nodes(4);
        add_edge(&graph, 0, 1);
        add_edge(&graph, 1, 2);
        add_edge(&graph, 2, 3);

        // Pattern: a -> b.
        let pattern = create_test_graph_with_nodes(2);
        let pattern_ids = get_entity_ids(&pattern, 2);
        let pattern_edge = GraphEdge {
            id: 0,
            from_id: pattern_ids[0],
            to_id: pattern_ids[1],
            edge_type: "edge".to_string(),
            data: serde_json::json!({}),
        };
        pattern.insert_edge(&pattern_edge).ok();

        // Replacement: a single node, preserved from pattern node a.
        let replacement = create_test_graph_with_nodes(1);

        let rule = RewriteRule {
            pattern,
            replacement,
            interface: vec![(0, 0)],
        };

        let bounds = RewriteBounds {
            max_matches: Some(1),
            validate_after_rewrite: true,
        };

        let result = rewrite_graph_patterns(&graph, &rule, bounds).unwrap();

        assert_eq!(result.patterns_replaced, 1);
        assert!(result.is_valid());
    }

    #[test]
    fn test_rewrite_with_interface() {
        // Host: 0 -> 1 -> 2.
        let graph = create_test_graph_with_nodes(3);
        add_edge(&graph, 0, 1);
        add_edge(&graph, 1, 2);

        // Pattern: a -> b.
        let pattern = create_test_graph_with_nodes(2);
        let pattern_ids = get_entity_ids(&pattern, 2);
        let pattern_edge = GraphEdge {
            id: 0,
            from_id: pattern_ids[0],
            to_id: pattern_ids[1],
            edge_type: "edge".to_string(),
            data: serde_json::json!({}),
        };
        pattern.insert_edge(&pattern_edge).ok();

        // Replacement: a single node, preserved from pattern node a.
        let replacement = create_test_graph_with_nodes(1);

        let rule = RewriteRule {
            pattern,
            replacement,
            interface: vec![(0, 0)],
        };

        let bounds = RewriteBounds::default();

        let result = rewrite_graph_patterns(&graph, &rule, bounds).unwrap();

        assert_eq!(result.patterns_replaced, 2, "Should find 2 pattern matches");
        assert!(result.is_valid());
    }

    #[test]
    fn test_rewrite_max_matches() {
        // Host: 0 -> 1 -> 2 -> 3 -> 4.
        let graph = create_test_graph_with_nodes(5);
        for i in 0..4 {
            add_edge(&graph, i, i + 1);
        }

        // Pattern: a -> b.
        let pattern = create_test_graph_with_nodes(2);
        let pattern_ids = get_entity_ids(&pattern, 2);
        let pattern_edge = GraphEdge {
            id: 0,
            from_id: pattern_ids[0],
            to_id: pattern_ids[1],
            edge_type: "edge".to_string(),
            data: serde_json::json!({}),
        };
        pattern.insert_edge(&pattern_edge).ok();

        let replacement = create_test_graph_with_nodes(1);

        let rule = RewriteRule {
            pattern,
            replacement,
            interface: vec![(0, 0)],
        };

        // Four matches exist, but at most two may be rewritten.
        let bounds = RewriteBounds {
            max_matches: Some(2),
            validate_after_rewrite: true,
        };

        let result = rewrite_graph_patterns(&graph, &rule, bounds).unwrap();

        assert!(result.patterns_replaced <= 2);
        assert!(result.is_valid());
    }

    #[test]
    fn test_rewrite_empty_pattern() {
        let graph = create_test_graph_with_nodes(3);
        add_edge(&graph, 0, 1);

        // Single-node pattern without edges.
        let pattern = create_test_graph_with_nodes(1);

        let replacement = create_test_graph_with_nodes(1);

        let rule = RewriteRule {
            pattern,
            replacement,
            interface: vec![],
        };

        let bounds = RewriteBounds::default();

        let result = rewrite_graph_patterns(&graph, &rule, bounds).unwrap();

        assert_eq!(
            result.patterns_replaced, 3,
            "Single node pattern should match all 3 nodes"
        );
        assert!(result.is_valid());
    }

    #[test]
    fn test_rewrite_multiple_occurrences() {
        // Host: two disjoint edges, 0 -> 1 and 2 -> 3.
        let graph = create_test_graph_with_nodes(4);
        add_edge(&graph, 0, 1);
        add_edge(&graph, 2, 3);

        // Pattern: a -> b.
        let pattern = create_test_graph_with_nodes(2);
        let pattern_ids = get_entity_ids(&pattern, 2);
        let pattern_edge = GraphEdge {
            id: 0,
            from_id: pattern_ids[0],
            to_id: pattern_ids[1],
            edge_type: "edge".to_string(),
            data: serde_json::json!({}),
        };
        pattern.insert_edge(&pattern_edge).ok();

        let replacement = create_test_graph_with_nodes(1);

        let rule = RewriteRule {
            pattern,
            replacement,
            interface: vec![(0, 0)],
        };

        let bounds = RewriteBounds {
            max_matches: Some(10),
            validate_after_rewrite: true,
        };

        let result = rewrite_graph_patterns(&graph, &rule, bounds).unwrap();

        assert_eq!(result.patterns_replaced, 2);
        assert!(result.is_valid());
    }

    /// Builds a common-subexpression fixture: two identical "add" ops using x and y.
    ///
    /// NOTE(review): this only checks the fixture's node count; no rewrite is exercised.
    #[test]
    fn test_rewrite_common_subexpression_elimination() {
        let graph = SqliteGraph::open_in_memory().unwrap();

        let add1 = graph
            .insert_entity(&GraphEntity {
                id: 0,
                kind: "Op".to_string(),
                name: "Add1".to_string(),
                file_path: None,
                data: serde_json::json!({"op": "add"}),
            })
            .unwrap();

        let add2 = graph
            .insert_entity(&GraphEntity {
                id: 0,
                kind: "Op".to_string(),
                name: "Add2".to_string(),
                file_path: None,
                data: serde_json::json!({"op": "add"}),
            })
            .unwrap();

        let x = graph
            .insert_entity(&GraphEntity {
                id: 0,
                kind: "Var".to_string(),
                name: "x".to_string(),
                file_path: None,
                data: serde_json::json!({}),
            })
            .unwrap();

        let y = graph
            .insert_entity(&GraphEntity {
                id: 0,
                kind: "Var".to_string(),
                name: "y".to_string(),
                file_path: None,
                data: serde_json::json!({}),
            })
            .unwrap();

        let _ = graph.insert_edge(&GraphEdge {
            id: 0,
            from_id: add1,
            to_id: x,
            edge_type: "uses".to_string(),
            data: serde_json::json!({}),
        });

        let _ = graph.insert_edge(&GraphEdge {
            id: 0,
            from_id: add1,
            to_id: y,
            edge_type: "uses".to_string(),
            data: serde_json::json!({}),
        });

        let _ = graph.insert_edge(&GraphEdge {
            id: 0,
            from_id: add2,
            to_id: x,
            edge_type: "uses".to_string(),
            data: serde_json::json!({}),
        });

        let _ = graph.insert_edge(&GraphEdge {
            id: 0,
            from_id: add2,
            to_id: y,
            edge_type: "uses".to_string(),
            data: serde_json::json!({}),
        });

        let original_node_count = graph.all_entity_ids().unwrap().len();

        assert_eq!(original_node_count, 4);
    }

    #[test]
    fn test_rewrite_progress_callback() {
        use crate::progress::NoProgress;

        let graph = create_test_graph_with_nodes(3);
        add_edge(&graph, 0, 1);
        add_edge(&graph, 1, 2);

        let pattern = create_test_graph_with_nodes(2);
        let pattern_ids = get_entity_ids(&pattern, 2);
        let pattern_edge = GraphEdge {
            id: 0,
            from_id: pattern_ids[0],
            to_id: pattern_ids[1],
            edge_type: "edge".to_string(),
            data: serde_json::json!({}),
        };
        pattern.insert_edge(&pattern_edge).ok();

        let replacement = create_test_graph_with_nodes(1);

        let rule = RewriteRule {
            pattern,
            replacement,
            interface: vec![(0, 0)],
        };

        let progress = NoProgress;
        let bounds = RewriteBounds::default();

        let result =
            rewrite_graph_patterns_with_progress(&graph, &rule, bounds, &progress).unwrap();

        assert_eq!(result.patterns_replaced, 2, "Should find 2 pattern matches");
        assert!(result.is_valid());
    }
}