1use std::collections::{HashMap, VecDeque};
97
98use ahash::AHashSet;
99
100use crate::errors::SqliteGraphError;
101use crate::graph::SqliteGraph;
102use crate::progress::ProgressCallback;
103
/// Result of a minimum s-t edge cut computation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MinCutResult {
    /// Nodes still reachable from the source in the residual graph.
    pub source_side: AHashSet<i64>,
    /// Remaining nodes of the flow network (everything not on the source side).
    pub sink_side: AHashSet<i64>,
    /// Saturated edges crossing from `source_side` to `sink_side`.
    pub cut_edges: Vec<(i64, i64)>,
    /// Size of the cut — equals the maximum flow value.
    pub cut_size: usize,
}
145
/// Result of a minimum s-t vertex cut (separator) computation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MinVertexCutResult {
    /// Nodes whose removal disconnects the source from the target.
    pub separator: AHashSet<i64>,
    /// Original nodes on the source side of the cut.
    pub source_side: AHashSet<i64>,
    /// Original nodes on the target side of the cut.
    pub sink_side: AHashSet<i64>,
    /// Number of nodes in the separator.
    pub cut_size: usize,
}
183
/// Result of a graph partitioning run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartitionResult {
    /// The node sets, one per partition (may contain empty sets for unused slots).
    pub partitions: Vec<AHashSet<i64>>,
    /// Edges whose endpoints fall in different partitions.
    pub cut_edges: Vec<(i64, i64)>,
    /// Reverse index: node id -> index into `partitions`.
    pub node_to_partition: HashMap<i64, usize>,
}
224
/// Tuning knobs for the k-way partitioning routines.
#[derive(Debug, Clone)]
pub struct PartitionConfig {
    /// Number of partitions requested (must be >= 2).
    pub k: usize,
    /// Hard per-partition size cap; `usize::MAX` means "derive from imbalance".
    pub max_size: usize,
    /// Allowed fractional deviation from the balanced partition size
    /// (only used when `max_size` is `usize::MAX`).
    pub max_imbalance: f64,
    /// Optional explicit seed nodes; when `None`, seeds are picked by degree.
    pub seeds: Option<Vec<i64>>,
}
266
267impl Default for PartitionConfig {
268 fn default() -> Self {
269 Self {
270 k: 2,
271 max_size: usize::MAX,
272 max_imbalance: 0.1,
273 seeds: None,
274 }
275 }
276}
277
/// A single directed edge of the flow network.
#[derive(Debug, Clone)]
struct FlowEdge {
    // Target node of this edge.
    to: i64,
    // Maximum flow the edge can carry (0 for residual reverse twins).
    capacity: usize,
    // Flow currently routed through the edge.
    flow: usize,
}
289
290impl FlowEdge {
291 fn new(to: i64, capacity: usize) -> Self {
292 Self {
293 to,
294 capacity,
295 flow: 0,
296 }
297 }
298
299 fn residual(&self) -> usize {
301 self.capacity - self.flow
302 }
303
304 fn add_flow(&mut self, amount: usize) -> usize {
306 let can_add = self.residual().min(amount);
307 self.flow += can_add;
308 amount - can_add
309 }
310}
311
/// Residual flow network over i64 node ids, used by the max-flow routines.
struct FlowNetwork {
    // Outgoing edges per node: real edges plus the zero-capacity reverse
    // twins inserted by `add_edge`.
    adjacency: HashMap<i64, Vec<FlowEdge>>,
    // For key (u, v): index within adjacency[v] of the paired
    // opposite-direction edge.
    // NOTE(review): entries overwrite each other when both (a, b) and (b, a)
    // are inserted, and the map is not read by the augmentation code in this
    // file — verify before relying on it.
    reverse_edge: HashMap<(i64, i64), usize>,
}
320
321impl FlowNetwork {
322 fn new() -> Self {
324 Self {
325 adjacency: HashMap::new(),
326 reverse_edge: HashMap::new(),
327 }
328 }
329
330 fn add_edge(&mut self, from: i64, to: i64, capacity: usize) {
332 if from == to {
334 return;
335 }
336
337 let forward_idx = self.adjacency.entry(from).or_insert_with(Vec::new).len();
339 let reverse_idx = self.adjacency.entry(to).or_insert_with(Vec::new).len();
341
342 self.adjacency.entry(from).or_insert_with(Vec::new).push(FlowEdge::new(to, capacity));
344 self.adjacency.entry(to).or_insert_with(Vec::new).push(FlowEdge::new(from, 0));
346
347 self.reverse_edge.insert((from, to), reverse_idx);
349 self.reverse_edge.insert((to, from), forward_idx);
350 }
351
352 fn neighbors(&self, node: i64) -> &[FlowEdge] {
354 self.adjacency.get(&node).map(|v| v.as_slice()).unwrap_or(&[])
355 }
356
357 fn nodes(&self) -> AHashSet<i64> {
359 self.adjacency.keys().copied().collect()
360 }
361
362 fn reachable_residual(&self, source: i64) -> AHashSet<i64> {
364 let mut visited = AHashSet::new();
365 let mut queue = VecDeque::new();
366
367 visited.insert(source);
368 queue.push_back(source);
369
370 while let Some(node) = queue.pop_front() {
371 for edge in self.neighbors(node) {
372 if edge.residual() > 0 && visited.insert(edge.to) {
373 queue.push_back(edge.to);
374 }
375 }
376 }
377
378 visited
379 }
380
381 fn find_cut_edges(&self, source_side: &AHashSet<i64>) -> Vec<(i64, i64)> {
383 let mut cut_edges = Vec::new();
384
385 for &from in source_side {
386 for edge in self.neighbors(from) {
387 if !source_side.contains(&edge.to) && edge.residual() == 0 {
392 cut_edges.push((from, edge.to));
393 }
394 }
395 }
396
397 cut_edges
398 }
399}
400
401fn edmonds_karp(
422 mut network: FlowNetwork,
423 source: i64,
424 sink: i64,
425) -> (usize, FlowNetwork) {
426 let mut max_flow = 0;
427
428 while let Some(path) = bfs_augmenting_path(&network, source, sink) {
430 let bottleneck = find_bottleneck(&network, &path);
432
433 augment_flow(&mut network, &path, bottleneck);
435
436 max_flow += bottleneck;
437 }
438
439 (max_flow, network)
440}
441
442fn bfs_augmenting_path(network: &FlowNetwork, source: i64, sink: i64) -> Option<Vec<i64>> {
447 let mut parent: HashMap<i64, (i64, usize)> = HashMap::new();
448 let mut queue = VecDeque::new();
449
450 queue.push_back(source);
451 parent.insert(source, (source, 0)); while let Some(node) = queue.pop_front() {
454 if node == sink {
455 let mut path = vec![sink];
457 let mut current = sink;
458
459 while current != source {
460 let (prev_node, _edge_idx) = *parent.get(¤t)?;
461 path.push(prev_node);
462 current = prev_node;
463 }
464
465 path.reverse();
466 return Some(path);
467 }
468
469 for (edge_idx, edge) in network.neighbors(node).iter().enumerate() {
471 if edge.residual() > 0 && !parent.contains_key(&edge.to) {
472 parent.insert(edge.to, (node, edge_idx));
473 queue.push_back(edge.to);
474 }
475 }
476 }
477
478 None }
480
481fn find_bottleneck(network: &FlowNetwork, path: &[i64]) -> usize {
483 let mut bottleneck = usize::MAX;
484
485 for i in 0..path.len().saturating_sub(1) {
486 let from = path[i];
487 let to = path[i + 1];
488
489 for edge in network.neighbors(from) {
490 if edge.to == to {
491 bottleneck = bottleneck.min(edge.residual());
492 break;
493 }
494 }
495 }
496
497 bottleneck
498}
499
500fn augment_flow(network: &mut FlowNetwork, path: &[i64], amount: usize) {
502 for i in 0..path.len().saturating_sub(1) {
503 let from = path[i];
504 let to = path[i + 1];
505
506 if let Some(forward_edges) = network.adjacency.get_mut(&from) {
508 for edge in forward_edges.iter_mut() {
509 if edge.to == to {
510 edge.flow += amount;
511 break;
512 }
513 }
514 }
515
516 if let Some(reverse_edges) = network.adjacency.get_mut(&to) {
518 for edge in reverse_edges.iter_mut() {
519 if edge.to == from {
520 edge.flow = edge.flow.saturating_sub(amount);
522 break;
523 }
524 }
525 }
526 }
527}
528
529fn build_flow_network(graph: &SqliteGraph, source: i64, _sink: i64) -> FlowNetwork {
538 let mut network = FlowNetwork::new();
539
540 let mut nodes_to_visit = vec![source];
542 let mut visited = AHashSet::new();
543 visited.insert(source);
544
545 while let Some(node) = nodes_to_visit.pop() {
547 if let Ok(neighbors) = graph.fetch_outgoing(node) {
548 for &neighbor in &neighbors {
549 network.add_edge(node, neighbor, 1);
551
552 if visited.insert(neighbor) {
553 nodes_to_visit.push(neighbor);
554 }
555 }
556 }
557 }
558
559 network
560}
561
562pub fn min_st_cut(
605 graph: &SqliteGraph,
606 source: i64,
607 sink: i64,
608) -> Result<MinCutResult, SqliteGraphError> {
609 if source == sink {
611 return Ok(MinCutResult {
612 source_side: {
613 let mut set = AHashSet::new();
614 set.insert(source);
615 set
616 },
617 sink_side: AHashSet::new(),
618 cut_edges: vec![],
619 cut_size: 0,
620 });
621 }
622
623 let network = build_flow_network(graph, source, sink);
625
626 if network.nodes().contains(&source) && !network.nodes().contains(&sink) {
628 return Ok(MinCutResult {
630 source_side: network.nodes(),
631 sink_side: AHashSet::new(),
632 cut_edges: vec![],
633 cut_size: 0,
634 });
635 }
636
637 let (max_flow, residual_network) = edmonds_karp(network, source, sink);
639
640 let source_side = residual_network.reachable_residual(source);
642 let all_nodes = residual_network.nodes();
643 let sink_side = all_nodes.difference(&source_side).copied().collect();
644 let cut_edges = residual_network.find_cut_edges(&source_side);
645
646 Ok(MinCutResult {
647 source_side,
648 sink_side,
649 cut_edges,
650 cut_size: max_flow,
651 })
652}
653
654pub fn min_st_cut_with_progress<F>(
691 graph: &SqliteGraph,
692 source: i64,
693 sink: i64,
694 progress: &F,
695) -> Result<MinCutResult, SqliteGraphError>
696where
697 F: ProgressCallback,
698{
699 if source == sink {
701 return Ok(MinCutResult {
702 source_side: {
703 let mut set = AHashSet::new();
704 set.insert(source);
705 set
706 },
707 sink_side: AHashSet::new(),
708 cut_edges: vec![],
709 cut_size: 0,
710 });
711 }
712
713 let network = build_flow_network(graph, source, sink);
715
716 if network.nodes().contains(&source) && !network.nodes().contains(&sink) {
718 return Ok(MinCutResult {
719 source_side: network.nodes(),
720 sink_side: AHashSet::new(),
721 cut_edges: vec![],
722 cut_size: 0,
723 });
724 }
725
726 let mut current_network = network;
728 let mut max_flow = 0;
729 let mut iteration = 0;
730
731 while let Some(path) = bfs_augmenting_path(¤t_network, source, sink) {
732 iteration += 1;
733
734 let bottleneck = find_bottleneck(¤t_network, &path);
735 augment_flow(&mut current_network, &path, bottleneck);
736 max_flow += bottleneck;
737
738 progress.on_progress(
740 iteration,
741 None,
742 &format!("Min cut: iteration {}, flow so far: {}", iteration, max_flow),
743 );
744 }
745
746 progress.on_complete();
748
749 let source_side = current_network.reachable_residual(source);
751 let all_nodes = current_network.nodes();
752 let sink_side = all_nodes.difference(&source_side).copied().collect();
753 let cut_edges = current_network.find_cut_edges(&source_side);
754
755 Ok(MinCutResult {
756 source_side,
757 sink_side,
758 cut_edges,
759 cut_size: max_flow,
760 })
761}
762
/// Id mapping for the vertex-splitting reduction used by the vertex-cut
/// routines: every node v becomes an "in" node and an "out" node joined by
/// a unit-capacity internal edge, so saturating that edge models deleting v.
/// The source and sink are not split (removing them is never allowed).
struct VertexSplitTransform {
    source: i64,
    sink: i64,
}

impl VertexSplitTransform {
    fn new(source: i64, sink: i64) -> Self {
        Self { source, sink }
    }

    /// Id of the "in" copy of `x`.
    ///
    /// Every node — including source and sink — is mapped into the even/odd
    /// space `2x` / `2x + 1`. The previous scheme kept source and sink at
    /// their raw ids, which collided with the transformed ids of other nodes
    /// (e.g. source id 4 vs `node_in(2) == 4`) and silently merged distinct
    /// vertices.
    fn node_in(&self, x: i64) -> i64 {
        x * 2
    }

    /// Id of the "out" copy of `x`. For the unsplit source and sink this is
    /// the same node as `node_in`.
    fn node_out(&self, x: i64) -> i64 {
        if x == self.source || x == self.sink {
            x * 2
        } else {
            x * 2 + 1
        }
    }

    /// Whether transformed id `node_id` is one of the copies of `original`.
    fn is_original_node(&self, node_id: i64, original: i64) -> bool {
        node_id == self.node_in(original) || node_id == self.node_out(original)
    }

    /// Recovers the original id from either copy. In-copies are even (2x),
    /// out-copies odd (2x + 1); stripping the parity bit and halving
    /// recovers x, including for negative ids.
    fn to_original(&self, node_id: i64) -> i64 {
        if node_id % 2 == 0 {
            node_id / 2
        } else {
            (node_id - 1) / 2
        }
    }

    /// If (from, to) is the internal in->out edge of a split node (in either
    /// residual direction), returns the original node id; otherwise `None`.
    /// Source and sink are excluded because they are not split.
    /// `rem_euclid` is used for parity so negative ids classify correctly
    /// (`-5 % 2 == -1` in Rust, which a `== 1` test would miss).
    fn is_internal_edge(&self, from: i64, to: i64) -> Option<i64> {
        let original = if from.rem_euclid(2) == 0 && to == from + 1 {
            from / 2
        } else if from.rem_euclid(2) == 1 && to == from - 1 {
            (from - 1) / 2
        } else {
            return None;
        };

        if original == self.source || original == self.sink {
            None
        } else {
            Some(original)
        }
    }
}
848
849fn build_vertex_split_network(
858 graph: &SqliteGraph,
859 source: i64,
860 sink: i64,
861) -> (FlowNetwork, VertexSplitTransform) {
862 let transform = VertexSplitTransform::new(source, sink);
863 let mut network = FlowNetwork::new();
864
865 let mut nodes_to_visit = vec![source];
867 let mut visited = AHashSet::new();
868 visited.insert(source);
869
870 while let Some(node) = nodes_to_visit.pop() {
871 if let Ok(neighbors) = graph.fetch_outgoing(node) {
872 for &neighbor in &neighbors {
873 let node_out = transform.node_out(node);
874 let neighbor_in = transform.node_in(neighbor);
875
876 network.add_edge(node_out, neighbor_in, 1);
878
879 let neighbor_out = transform.node_out(neighbor);
881 if neighbor != source && neighbor != sink {
882 network.add_edge(neighbor_in, neighbor_out, 1);
883 }
884
885 let node_in = transform.node_in(node);
887 if node != source && node != sink {
888 network.add_edge(node_in, node_out, 1);
889 }
890
891 if visited.insert(neighbor) {
892 nodes_to_visit.push(neighbor);
893 }
894 }
895 }
896 }
897
898 let source_in = transform.node_in(source);
900 let source_out = transform.node_out(source);
901 if source_in != source_out {
902 network.add_edge(source_in, source_out, 1);
903 }
904
905 (network, transform)
906}
907
908pub fn min_vertex_cut(
954 graph: &SqliteGraph,
955 source: i64,
956 target: i64,
957) -> Result<MinVertexCutResult, SqliteGraphError> {
958 if source == target {
960 return Ok(MinVertexCutResult {
961 separator: AHashSet::new(),
962 source_side: {
963 let mut set = AHashSet::new();
964 set.insert(source);
965 set
966 },
967 sink_side: AHashSet::new(),
968 cut_size: 0,
969 });
970 }
971
972 let (network, transform) = build_vertex_split_network(graph, source, target);
974
975 let source_out = transform.node_out(source);
976 let target_in = transform.node_in(target);
977
978 if !network.nodes().contains(&target_in) {
980 return Ok(MinVertexCutResult {
981 separator: AHashSet::new(),
982 source_side: {
983 let mut set = AHashSet::new();
984 set.insert(source);
985 set
986 },
987 sink_side: AHashSet::new(),
988 cut_size: 0,
989 });
990 }
991
992 let (max_flow, residual_network) = edmonds_karp(network, source_out, target_in);
994
995 let mut separator = AHashSet::new();
997 for node in residual_network.nodes() {
998 for edge in residual_network.neighbors(node) {
999 if let Some(original) = transform.is_internal_edge(node, edge.to) {
1000 if edge.residual() == 0 {
1002 separator.insert(original);
1003 }
1004 }
1005 }
1006 }
1007
1008 let source_side_transformed = residual_network.reachable_residual(source_out);
1010 let mut source_side = AHashSet::new();
1011 for node in source_side_transformed {
1012 source_side.insert(transform.to_original(node));
1013 }
1014
1015 let all_nodes_transformed = residual_network.nodes();
1016 let mut sink_side = AHashSet::new();
1017 for node in all_nodes_transformed {
1018 let original = transform.to_original(node);
1019 if !source_side.contains(&original) {
1020 sink_side.insert(original);
1021 }
1022 }
1023
1024 Ok(MinVertexCutResult {
1025 separator: separator.clone(),
1026 source_side,
1027 sink_side,
1028 cut_size: separator.len(),
1029 })
1030}
1031
1032pub fn min_vertex_cut_with_progress<F>(
1065 graph: &SqliteGraph,
1066 source: i64,
1067 target: i64,
1068 progress: &F,
1069) -> Result<MinVertexCutResult, SqliteGraphError>
1070where
1071 F: ProgressCallback,
1072{
1073 if source == target {
1075 return Ok(MinVertexCutResult {
1076 separator: AHashSet::new(),
1077 source_side: {
1078 let mut set = AHashSet::new();
1079 set.insert(source);
1080 set
1081 },
1082 sink_side: AHashSet::new(),
1083 cut_size: 0,
1084 });
1085 }
1086
1087 let (network, transform) = build_vertex_split_network(graph, source, target);
1089
1090 let source_out = transform.node_out(source);
1091 let target_in = transform.node_in(target);
1092
1093 if !network.nodes().contains(&target_in) {
1095 return Ok(MinVertexCutResult {
1096 separator: AHashSet::new(),
1097 source_side: {
1098 let mut set = AHashSet::new();
1099 set.insert(source);
1100 set
1101 },
1102 sink_side: AHashSet::new(),
1103 cut_size: 0,
1104 });
1105 }
1106
1107 let mut current_network = network;
1109 let mut max_flow = 0;
1110 let mut iteration = 0;
1111
1112 while let Some(path) = bfs_augmenting_path(¤t_network, source_out, target_in) {
1113 iteration += 1;
1114
1115 let bottleneck = find_bottleneck(¤t_network, &path);
1116 augment_flow(&mut current_network, &path, bottleneck);
1117 max_flow += bottleneck;
1118
1119 progress.on_progress(
1121 iteration,
1122 None,
1123 &format!("Vertex cut: iteration {}, flow so far: {}", iteration, max_flow),
1124 );
1125 }
1126
1127 progress.on_complete();
1129
1130 let mut separator = AHashSet::new();
1132 for node in current_network.nodes() {
1133 for edge in current_network.neighbors(node) {
1134 if let Some(original) = transform.is_internal_edge(node, edge.to) {
1135 if edge.residual() == 0 {
1136 separator.insert(original);
1137 }
1138 }
1139 }
1140 }
1141
1142 let source_side_transformed = current_network.reachable_residual(source_out);
1144 let mut source_side = AHashSet::new();
1145 for node in source_side_transformed {
1146 source_side.insert(transform.to_original(node));
1147 }
1148
1149 let all_nodes_transformed = current_network.nodes();
1150 let mut sink_side = AHashSet::new();
1151 for node in all_nodes_transformed {
1152 let original = transform.to_original(node);
1153 if !source_side.contains(&original) {
1154 sink_side.insert(original);
1155 }
1156 }
1157
1158 Ok(MinVertexCutResult {
1159 separator,
1160 source_side,
1161 sink_side,
1162 cut_size: max_flow,
1163 })
1164}
1165
1166fn compute_cut_edges(
1181 graph: &SqliteGraph,
1182 node_to_partition: &HashMap<i64, usize>,
1183) -> Vec<(i64, i64)> {
1184 let mut cut_edges = Vec::new();
1185
1186 let nodes_to_check: Vec<i64> = if let Ok(all_ids) = graph.all_entity_ids() {
1188 all_ids
1189 } else {
1190 return cut_edges;
1191 };
1192
1193 for &from_node in &nodes_to_check {
1194 if let Ok(neighbors) = graph.fetch_outgoing(from_node) {
1195 for &to_node in &neighbors {
1196 if let (Some(&from_partition), Some(&to_partition)) = (
1198 node_to_partition.get(&from_node),
1199 node_to_partition.get(&to_node),
1200 ) {
1201 if from_partition != to_partition {
1202 cut_edges.push((from_node, to_node));
1203 }
1204 }
1205 }
1206 }
1207 }
1208
1209 cut_edges
1210}
1211
1212pub fn partition_bfs_level(
1260 graph: &SqliteGraph,
1261 seeds: Vec<i64>,
1262 k: usize,
1263) -> Result<PartitionResult, SqliteGraphError> {
1264 if k < 2 {
1266 return Ok(PartitionResult {
1267 partitions: vec![AHashSet::new()],
1268 cut_edges: vec![],
1269 node_to_partition: HashMap::new(),
1270 });
1271 }
1272
1273 let all_nodes: AHashSet<i64> = graph.all_entity_ids()?.into_iter().collect();
1275
1276 if all_nodes.is_empty() {
1278 return Ok(PartitionResult {
1279 partitions: vec![AHashSet::new(); k],
1280 cut_edges: vec![],
1281 node_to_partition: HashMap::new(),
1282 });
1283 }
1284
1285 let mut effective_seeds = seeds;
1287 if effective_seeds.is_empty() {
1288 let mut sorted_nodes: Vec<i64> = all_nodes.iter().copied().collect();
1290 sorted_nodes.sort();
1291 effective_seeds = sorted_nodes.into_iter().take(k).collect();
1292 }
1293 effective_seeds.truncate(k.min(effective_seeds.len()));
1295
1296 let num_partitions = k.max(effective_seeds.len());
1298 let mut partitions: Vec<AHashSet<i64>> = (0..num_partitions).map(|_| AHashSet::new()).collect();
1299 let mut node_to_partition: HashMap<i64, usize> = HashMap::new();
1300
1301 let mut queue: VecDeque<(i64, usize, usize)> = VecDeque::new();
1303 let mut visited: AHashSet<i64> = AHashSet::new();
1304
1305 for (seed_idx, &seed) in effective_seeds.iter().enumerate() {
1307 if all_nodes.contains(&seed) {
1308 partitions[seed_idx].insert(seed);
1309 node_to_partition.insert(seed, seed_idx);
1310 visited.insert(seed);
1311 queue.push_back((seed, 0, seed_idx));
1312 }
1313 }
1314
1315 while let Some((node, _level, seed_idx)) = queue.pop_front() {
1317 if let Ok(neighbors) = graph.fetch_outgoing(node) {
1319 for &neighbor in &neighbors {
1320 if visited.insert(neighbor) {
1321 partitions[seed_idx].insert(neighbor);
1323 node_to_partition.insert(neighbor, seed_idx);
1324 queue.push_back((neighbor, 0, seed_idx));
1325 }
1326 }
1327 }
1328 }
1329
1330 while partitions.len() < k {
1332 partitions.push(AHashSet::new());
1333 }
1334
1335 let cut_edges = compute_cut_edges(graph, &node_to_partition);
1337
1338 Ok(PartitionResult {
1339 partitions,
1340 cut_edges,
1341 node_to_partition,
1342 })
1343}
1344
1345pub fn partition_greedy(
1394 graph: &SqliteGraph,
1395 initial_partition: Option<Vec<AHashSet<i64>>>,
1396 max_iterations: usize,
1397) -> Result<PartitionResult, SqliteGraphError> {
1398 let all_nodes: AHashSet<i64> = graph.all_entity_ids()?.into_iter().collect();
1400
1401 if all_nodes.is_empty() {
1403 return Ok(PartitionResult {
1404 partitions: vec![AHashSet::new(), AHashSet::new()],
1405 cut_edges: vec![],
1406 node_to_partition: HashMap::new(),
1407 });
1408 }
1409
1410 let (mut partitions, mut node_to_partition) = if let Some(init) = initial_partition {
1412 if init.len() < 2 {
1413 let init_result = partition_bfs_level(graph, vec![], 2)?;
1415 (init_result.partitions, init_result.node_to_partition)
1416 } else {
1417 let mut mapping = HashMap::new();
1419 for (pidx, partition) in init.iter().enumerate() {
1420 for &node in partition {
1421 mapping.insert(node, pidx);
1422 }
1423 }
1424 (init, mapping)
1425 }
1426 } else {
1427 let init_result = partition_bfs_level(graph, vec![], 2)?;
1429 (init_result.partitions, init_result.node_to_partition)
1430 };
1431
1432 if partitions.len() != 2 {
1434 partitions.resize(2, AHashSet::new());
1435 }
1436
1437 let initial_cut_size = compute_cut_edges(graph, &node_to_partition).len();
1438 let mut best_partitions = partitions.clone();
1439 let mut best_mapping = node_to_partition.clone();
1440 let mut best_cut_size = initial_cut_size;
1441
1442 for _iteration in 0..max_iterations {
1444 let mut improvement_found = false;
1445 let mut best_move: Option<(i64, usize, i64)> = None; let mut best_gain: i64 = 0;
1447
1448 for &node in all_nodes.iter() {
1450 if let Some(&from_partition) = node_to_partition.get(&node) {
1451 let to_partition = 1 - from_partition; let mut edges_to_other = 0i64;
1455 let mut edges_within = 0i64;
1456
1457 if let Ok(neighbors) = graph.fetch_outgoing(node) {
1458 for &neighbor in &neighbors {
1459 if let Some(&neighbor_partition) = node_to_partition.get(&neighbor) {
1460 if neighbor_partition == to_partition {
1461 edges_to_other += 1;
1462 } else if neighbor_partition == from_partition && neighbor != node {
1463 edges_within += 1;
1464 }
1465 }
1466 }
1467 }
1468
1469 let gain = edges_to_other - edges_within;
1470
1471 if gain > best_gain {
1472 best_gain = gain;
1473 best_move = Some((node, from_partition, gain));
1474 improvement_found = true;
1475 }
1476 }
1477 }
1478
1479 if !improvement_found || best_gain <= 0 {
1480 break; }
1482
1483 if let Some((node, from_partition, _gain)) = best_move {
1485 let to_partition = 1 - from_partition;
1486
1487 partitions[from_partition].remove(&node);
1489 partitions[to_partition].insert(node);
1490
1491 node_to_partition.insert(node, to_partition);
1493
1494 let current_cut_size = compute_cut_edges(graph, &node_to_partition).len();
1496 if current_cut_size < best_cut_size {
1497 best_cut_size = current_cut_size;
1498 best_partitions = partitions.clone();
1499 best_mapping = node_to_partition.clone();
1500 }
1501 }
1502 }
1503
1504 let cut_edges = compute_cut_edges(graph, &best_mapping);
1506
1507 Ok(PartitionResult {
1508 partitions: best_partitions,
1509 cut_edges,
1510 node_to_partition: best_mapping,
1511 })
1512}
1513
1514fn select_seeds_by_degree(
1526 graph: &SqliteGraph,
1527 k: usize,
1528 available_nodes: &AHashSet<i64>,
1529) -> Vec<i64> {
1530 let mut node_degrees: Vec<(i64, usize)> = Vec::new();
1531
1532 for &node in available_nodes {
1533 if let Ok(outgoing) = graph.fetch_outgoing(node) {
1534 let degree = outgoing.len();
1535 node_degrees.push((node, degree));
1536 }
1537 }
1538
1539 node_degrees.sort_by(|a, b| b.1.cmp(&a.1));
1541 node_degrees.truncate(k);
1542 node_degrees.into_iter().map(|(node, _)| node).collect()
1543}
1544
1545fn shortest_distance_to_targets(
1557 graph: &SqliteGraph,
1558 from: i64,
1559 targets: &AHashSet<i64>,
1560) -> usize {
1561 if targets.contains(&from) {
1562 return 0;
1563 }
1564
1565 let mut visited: AHashSet<i64> = AHashSet::new();
1566 let mut queue: VecDeque<(i64, usize)> = VecDeque::new();
1567
1568 visited.insert(from);
1569 queue.push_back((from, 0));
1570
1571 while let Some((node, dist)) = queue.pop_front() {
1572 if let Ok(neighbors) = graph.fetch_outgoing(node) {
1573 for &neighbor in &neighbors {
1574 if targets.contains(&neighbor) {
1575 return dist + 1;
1576 }
1577 if visited.insert(neighbor) {
1578 queue.push_back((neighbor, dist + 1));
1579 }
1580 }
1581 }
1582 }
1583
1584 usize::MAX }
1586
1587pub fn partition_kway(
1636 graph: &SqliteGraph,
1637 config: &PartitionConfig,
1638) -> Result<PartitionResult, SqliteGraphError> {
1639 if config.k < 2 {
1640 return Err(SqliteGraphError::InvalidInput(
1641 "k must be at least 2 for partitioning".to_string(),
1642 ));
1643 }
1644
1645 let all_nodes: AHashSet<i64> = graph.all_entity_ids()?.into_iter().collect();
1647
1648 if all_nodes.is_empty() {
1650 return Ok(PartitionResult {
1651 partitions: vec![AHashSet::new(); config.k],
1652 cut_edges: vec![],
1653 node_to_partition: HashMap::new(),
1654 });
1655 }
1656
1657 let effective_k = config.k.min(all_nodes.len());
1659 let mut partitions: Vec<AHashSet<i64>> = (0..effective_k).map(|_| AHashSet::new()).collect();
1660 let mut node_to_partition: HashMap<i64, usize> = HashMap::new();
1661
1662 let seeds = if let Some(ref provided_seeds) = config.seeds {
1664 provided_seeds.clone()
1665 } else {
1666 select_seeds_by_degree(graph, effective_k, &all_nodes)
1667 };
1668
1669 let mut effective_seeds = seeds;
1671 effective_seeds.truncate(effective_k);
1672 while effective_seeds.len() < effective_k {
1673 for &node in &all_nodes {
1675 if !effective_seeds.contains(&node) {
1676 effective_seeds.push(node);
1677 if effective_seeds.len() >= effective_k {
1678 break;
1679 }
1680 }
1681 }
1682 }
1683
1684 let target_size = (all_nodes.len() / effective_k).max(1);
1686 let max_allowed = if config.max_size == usize::MAX {
1687 ((target_size as f64) * (1.0 + config.max_imbalance)) as usize
1688 } else {
1689 config.max_size.min(all_nodes.len())
1690 };
1691
1692 let mut queue: VecDeque<(i64, usize)> = VecDeque::new(); let mut unassigned: AHashSet<i64> = AHashSet::new();
1695
1696 for (pidx, &seed) in effective_seeds.iter().enumerate() {
1698 if all_nodes.contains(&seed) {
1699 partitions[pidx].insert(seed);
1700 node_to_partition.insert(seed, pidx);
1701 queue.push_back((seed, pidx));
1702 }
1703 }
1704
1705 for &node in &all_nodes {
1707 if !node_to_partition.contains_key(&node) {
1708 unassigned.insert(node);
1709 }
1710 }
1711
1712 while let Some((node, pidx)) = queue.pop_front() {
1714 if partitions[pidx].len() >= max_allowed {
1716 continue;
1717 }
1718
1719 if let Ok(neighbors) = graph.fetch_outgoing(node) {
1721 for &neighbor in &neighbors {
1722 if unassigned.remove(&neighbor) {
1723 partitions[pidx].insert(neighbor);
1724 node_to_partition.insert(neighbor, pidx);
1725 queue.push_back((neighbor, pidx));
1726 }
1727 }
1728 }
1729 }
1730
1731 for &node in &unassigned {
1733 let mut best_partition = 0;
1734 let mut best_distance = usize::MAX;
1735
1736 for pidx in 0..effective_k {
1737 let target_nodes: AHashSet<i64> = partitions[pidx].iter().copied().collect();
1739 if target_nodes.is_empty() {
1740 continue;
1741 }
1742
1743 let distance = shortest_distance_to_targets(graph, node, &target_nodes);
1744 if distance < best_distance {
1745 best_distance = distance;
1746 best_partition = pidx;
1747 }
1748 }
1749
1750 partitions[best_partition].insert(node);
1751 node_to_partition.insert(node, best_partition);
1752 }
1753
1754 while partitions.len() < config.k {
1756 partitions.push(AHashSet::new());
1757 }
1758
1759 let cut_edges = compute_cut_edges(graph, &node_to_partition);
1761
1762 Ok(PartitionResult {
1763 partitions,
1764 cut_edges,
1765 node_to_partition,
1766 })
1767}
1768
1769pub fn partition_kway_with_progress<F>(
1801 graph: &SqliteGraph,
1802 config: &PartitionConfig,
1803 progress: &F,
1804) -> Result<PartitionResult, SqliteGraphError>
1805where
1806 F: ProgressCallback,
1807{
1808 if config.k < 2 {
1809 return Err(SqliteGraphError::InvalidInput(
1810 "k must be at least 2 for partitioning".to_string(),
1811 ));
1812 }
1813
1814 let all_nodes: AHashSet<i64> = graph.all_entity_ids()?.into_iter().collect();
1816 let total_nodes = all_nodes.len();
1817
1818 if all_nodes.is_empty() {
1820 progress.on_complete();
1821 return Ok(PartitionResult {
1822 partitions: vec![AHashSet::new(); config.k],
1823 cut_edges: vec![],
1824 node_to_partition: HashMap::new(),
1825 });
1826 }
1827
1828 let effective_k = config.k.min(all_nodes.len());
1830 let mut partitions: Vec<AHashSet<i64>> = (0..effective_k).map(|_| AHashSet::new()).collect();
1831 let mut node_to_partition: HashMap<i64, usize> = HashMap::new();
1832
1833 let seeds = if let Some(ref provided_seeds) = config.seeds {
1835 provided_seeds.clone()
1836 } else {
1837 select_seeds_by_degree(graph, effective_k, &all_nodes)
1838 };
1839
1840 let mut effective_seeds = seeds;
1842 effective_seeds.truncate(effective_k);
1843 while effective_seeds.len() < effective_k {
1844 for &node in &all_nodes {
1845 if !effective_seeds.contains(&node) {
1846 effective_seeds.push(node);
1847 if effective_seeds.len() >= effective_k {
1848 break;
1849 }
1850 }
1851 }
1852 }
1853
1854 let target_size = (all_nodes.len() / effective_k).max(1);
1856 let max_allowed = if config.max_size == usize::MAX {
1857 ((target_size as f64) * (1.0 + config.max_imbalance)) as usize
1858 } else {
1859 config.max_size.min(all_nodes.len())
1860 };
1861
1862 let mut queue: VecDeque<(i64, usize)> = VecDeque::new();
1864 let mut unassigned: AHashSet<i64> = AHashSet::new();
1865 let mut assigned_count = 0;
1866
1867 for (pidx, &seed) in effective_seeds.iter().enumerate() {
1869 if all_nodes.contains(&seed) {
1870 partitions[pidx].insert(seed);
1871 node_to_partition.insert(seed, pidx);
1872 assigned_count += 1;
1873 queue.push_back((seed, pidx));
1874 }
1875 }
1876
1877 for &node in &all_nodes {
1879 if !node_to_partition.contains_key(&node) {
1880 unassigned.insert(node);
1881 }
1882 }
1883
1884 while let Some((node, pidx)) = queue.pop_front() {
1886 if partitions[pidx].len() >= max_allowed {
1888 continue;
1889 }
1890
1891 if let Ok(neighbors) = graph.fetch_outgoing(node) {
1893 for &neighbor in &neighbors {
1894 if unassigned.remove(&neighbor) {
1895 partitions[pidx].insert(neighbor);
1896 node_to_partition.insert(neighbor, pidx);
1897 assigned_count += 1;
1898 queue.push_back((neighbor, pidx));
1899
1900 if assigned_count % 10 == 0 {
1902 progress.on_progress(
1903 assigned_count,
1904 Some(total_nodes),
1905 &format!("K-way partition: assigned {}/{} nodes", assigned_count, total_nodes),
1906 );
1907 }
1908 }
1909 }
1910 }
1911 }
1912
1913 for &node in &unassigned {
1915 let mut best_partition = 0;
1916 let mut best_distance = usize::MAX;
1917
1918 for pidx in 0..effective_k {
1919 let target_nodes: AHashSet<i64> = partitions[pidx].iter().copied().collect();
1920 if target_nodes.is_empty() {
1921 continue;
1922 }
1923
1924 let distance = shortest_distance_to_targets(graph, node, &target_nodes);
1925 if distance < best_distance {
1926 best_distance = distance;
1927 best_partition = pidx;
1928 }
1929 }
1930
1931 partitions[best_partition].insert(node);
1932 node_to_partition.insert(node, best_partition);
1933 assigned_count += 1;
1934 }
1935
1936 let _ = assigned_count; progress.on_complete();
1939
1940 while partitions.len() < config.k {
1942 partitions.push(AHashSet::new());
1943 }
1944
1945 let cut_edges = compute_cut_edges(graph, &node_to_partition);
1947
1948 Ok(PartitionResult {
1949 partitions,
1950 cut_edges,
1951 node_to_partition,
1952 })
1953}
1954
1955#[cfg(test)]
1960mod tests {
1961 use super::*;
1962 use crate::{GraphEdge, GraphEntity};
1963
1964 fn create_linear_chain() -> (SqliteGraph, i64, i64) {
1966 let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
1967
1968 for i in 0..4 {
1970 let entity = GraphEntity {
1971 id: 0,
1972 kind: "node".to_string(),
1973 name: format!("node_{}", i),
1974 file_path: Some(format!("node_{}.rs", i)),
1975 data: serde_json::json!({"index": i}),
1976 };
1977 graph.insert_entity(&entity).expect("Failed to insert entity");
1978 }
1979
1980 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
1981
1982 for i in 0..entity_ids.len().saturating_sub(1) {
1984 let edge = GraphEdge {
1985 id: 0,
1986 from_id: entity_ids[i],
1987 to_id: entity_ids[i + 1],
1988 edge_type: "next".to_string(),
1989 data: serde_json::json!({}),
1990 };
1991 graph.insert_edge(&edge).expect("Failed to insert edge");
1992 }
1993
1994 (graph, entity_ids[0], entity_ids[3])
1995 }
1996
1997 fn create_diamond() -> (SqliteGraph, i64, i64) {
1999 let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
2000
2001 for i in 0..4 {
2003 let entity = GraphEntity {
2004 id: 0,
2005 kind: "node".to_string(),
2006 name: format!("node_{}", i),
2007 file_path: Some(format!("node_{}.rs", i)),
2008 data: serde_json::json!({"index": i}),
2009 };
2010 graph.insert_entity(&entity).expect("Failed to insert entity");
2011 }
2012
2013 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
2014
2015 let edges = vec![(0, 1), (0, 2), (1, 3), (2, 3)];
2017 for (from_idx, to_idx) in edges {
2018 let edge = GraphEdge {
2019 id: 0,
2020 from_id: entity_ids[from_idx],
2021 to_id: entity_ids[to_idx],
2022 edge_type: "next".to_string(),
2023 data: serde_json::json!({}),
2024 };
2025 graph.insert_edge(&edge).expect("Failed to insert edge");
2026 }
2027
2028 (graph, entity_ids[0], entity_ids[3])
2029 }
2030
2031 fn create_parallel_paths() -> (SqliteGraph, i64, i64) {
2033 let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
2034
2035 for i in 0..5 {
2037 let entity = GraphEntity {
2038 id: 0,
2039 kind: "node".to_string(),
2040 name: format!("node_{}", i),
2041 file_path: Some(format!("node_{}.rs", i)),
2042 data: serde_json::json!({"index": i}),
2043 };
2044 graph.insert_entity(&entity).expect("Failed to insert entity");
2045 }
2046
2047 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
2048
2049 let edges = vec![(0, 1), (1, 4), (0, 2), (2, 4), (0, 3), (3, 4)];
2051 for (from_idx, to_idx) in edges {
2052 let edge = GraphEdge {
2053 id: 0,
2054 from_id: entity_ids[from_idx],
2055 to_id: entity_ids[to_idx],
2056 edge_type: "next".to_string(),
2057 data: serde_json::json!({}),
2058 };
2059 graph.insert_edge(&edge).expect("Failed to insert edge");
2060 }
2061
2062 (graph, entity_ids[0], entity_ids[4])
2063 }
2064
2065 fn create_single_edge() -> (SqliteGraph, i64, i64) {
2067 let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
2068
2069 for i in 0..2 {
2071 let entity = GraphEntity {
2072 id: 0,
2073 kind: "node".to_string(),
2074 name: format!("node_{}", i),
2075 file_path: Some(format!("node_{}.rs", i)),
2076 data: serde_json::json!({"index": i}),
2077 };
2078 graph.insert_entity(&entity).expect("Failed to insert entity");
2079 }
2080
2081 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
2082
2083 let edge = GraphEdge {
2085 id: 0,
2086 from_id: entity_ids[0],
2087 to_id: entity_ids[1],
2088 edge_type: "next".to_string(),
2089 data: serde_json::json!({}),
2090 };
2091 graph.insert_edge(&edge).expect("Failed to insert edge");
2092
2093 (graph, entity_ids[0], entity_ids[1])
2094 }
2095
2096 fn create_disconnected() -> (SqliteGraph, i64, i64) {
2098 let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
2099
2100 for i in 0..4 {
2102 let entity = GraphEntity {
2103 id: 0,
2104 kind: "node".to_string(),
2105 name: format!("node_{}", i),
2106 file_path: Some(format!("node_{}.rs", i)),
2107 data: serde_json::json!({"index": i}),
2108 };
2109 graph.insert_entity(&entity).expect("Failed to insert entity");
2110 }
2111
2112 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
2113
2114 let edge1 = GraphEdge {
2116 id: 0,
2117 from_id: entity_ids[0],
2118 to_id: entity_ids[1],
2119 edge_type: "next".to_string(),
2120 data: serde_json::json!({}),
2121 };
2122 graph.insert_edge(&edge1).expect("Failed to insert edge");
2123
2124 let edge2 = GraphEdge {
2125 id: 0,
2126 from_id: entity_ids[2],
2127 to_id: entity_ids[3],
2128 edge_type: "next".to_string(),
2129 data: serde_json::json!({}),
2130 };
2131 graph.insert_edge(&edge2).expect("Failed to insert edge");
2132
2133 (graph, entity_ids[0], entity_ids[3])
2134 }
2135
2136 #[test]
2139 fn test_min_st_cut_linear_chain() {
2140 let (graph, source, sink) = create_linear_chain();
2143
2144 let result = min_st_cut(&graph, source, sink).expect("Failed to compute min cut");
2145
2146 assert_eq!(result.cut_size, 1, "Linear chain should have cut size 1");
2147 assert_eq!(result.cut_edges.len(), 1, "Should have 1 cut edge");
2148 assert!(
2149 result.source_side.contains(&source),
2150 "Source side should contain source"
2151 );
2152 assert!(
2153 result.sink_side.contains(&sink),
2154 "Sink side should contain sink"
2155 );
2156 }
2157
2158 #[test]
2159 fn test_min_st_cut_diamond() {
2160 let (graph, source, sink) = create_diamond();
2163
2164 let result = min_st_cut(&graph, source, sink).expect("Failed to compute min cut");
2165
2166 assert_eq!(result.cut_size, 2, "Diamond should have cut size 2");
2167 assert_eq!(result.cut_edges.len(), 2, "Should have 2 cut edges");
2168 }
2169
2170 #[test]
2171 fn test_min_st_cut_parallel_paths() {
2172 let (graph, source, sink) = create_parallel_paths();
2175
2176 let result = min_st_cut(&graph, source, sink).expect("Failed to compute min cut");
2177
2178 assert_eq!(
2179 result.cut_size,
2180 3,
2181 "Parallel paths should have cut size 3"
2182 );
2183 assert_eq!(result.cut_edges.len(), 3, "Should have 3 cut edges");
2184 }
2185
2186 #[test]
2187 fn test_min_st_cut_single_edge() {
2188 let (graph, source, sink) = create_single_edge();
2191
2192 let result = min_st_cut(&graph, source, sink).expect("Failed to compute min cut");
2193
2194 assert_eq!(result.cut_size, 1, "Single edge should have cut size 1");
2195 assert_eq!(result.cut_edges.len(), 1, "Should have 1 cut edge");
2196 assert_eq!(
2197 result.cut_edges[0],
2198 (source, sink),
2199 "Cut edge should be (source, sink)"
2200 );
2201 }
2202
2203 #[test]
2204 fn test_min_st_cut_source_equals_target() {
2205 let (graph, source, _) = create_single_edge();
2208
2209 let result = min_st_cut(&graph, source, source).expect("Failed to compute min cut");
2210
2211 assert_eq!(result.cut_size, 0, "Source==target should have cut size 0");
2212 assert!(result.cut_edges.is_empty(), "Cut edges should be empty");
2213 assert!(result.source_side.contains(&source), "Source side contains source");
2214 assert!(result.sink_side.is_empty(), "Sink side should be empty");
2215 }
2216
2217 #[test]
2218 fn test_min_st_cut_with_progress_matches() {
2219 use crate::progress::NoProgress;
2222
2223 let (graph, source, sink) = create_diamond();
2224
2225 let progress = NoProgress;
2226 let result_with =
2227 min_st_cut_with_progress(&graph, source, sink, &progress).expect("Failed");
2228 let result_without = min_st_cut(&graph, source, sink).expect("Failed");
2229
2230 assert_eq!(
2231 result_with.cut_size,
2232 result_without.cut_size,
2233 "Cut size should match"
2234 );
2235 assert_eq!(
2236 result_with.cut_edges.len(),
2237 result_without.cut_edges.len(),
2238 "Cut edges count should match"
2239 );
2240 }
2241
2242 #[test]
2245 fn test_min_vertex_cut_bridge_node() {
2246 let (graph, source, sink) = create_linear_chain();
2250
2251 let result = min_vertex_cut(&graph, source, sink).expect("Failed to compute vertex cut");
2252
2253 assert_eq!(
2254 result.cut_size,
2255 2,
2256 "Linear chain should have vertex cut size 2 (both intermediate nodes)"
2257 );
2258 assert_eq!(result.separator.len(), 2, "Should have 2 separator vertices");
2259 }
2260
2261 #[test]
2262 fn test_min_vertex_cut_two_parallel_paths() {
2263 let (graph, source, sink) = create_diamond();
2266
2267 let result = min_vertex_cut(&graph, source, sink).expect("Failed to compute vertex cut");
2268
2269 assert_eq!(
2270 result.cut_size,
2271 2,
2272 "Two parallel paths should have vertex cut size 2"
2273 );
2274 assert_eq!(result.separator.len(), 2, "Should have 2 separator vertices");
2275 }
2276
2277 #[test]
2278 fn test_min_vertex_cut_direct_edge() {
2279 let (graph, source, sink) = create_single_edge();
2282
2283 eprintln!("Direct edge test: source={}, sink={}", source, sink);
2284
2285 let result = min_vertex_cut(&graph, source, sink).expect("Failed to compute vertex cut");
2286
2287 eprintln!("cut_size={}, separator={:?}", result.cut_size, result.separator);
2288
2289 assert_eq!(
2290 result.cut_size,
2291 0,
2292 "Direct edge should have vertex cut size 0"
2293 );
2294 assert!(
2295 result.separator.is_empty(),
2296 "Separator should be empty for direct edge"
2297 );
2298 }
2299
2300 #[test]
2301 fn test_min_vertex_cut_source_equals_target() {
2302 let (graph, source, _) = create_single_edge();
2305
2306 let result =
2307 min_vertex_cut(&graph, source, source).expect("Failed to compute vertex cut");
2308
2309 assert_eq!(result.cut_size, 0, "Source==target should have cut size 0");
2310 assert!(result.separator.is_empty(), "Separator should be empty");
2311 assert!(result.source_side.contains(&source), "Source side contains source");
2312 }
2313
2314 #[test]
2315 fn test_min_vertex_cut_with_progress_matches() {
2316 use crate::progress::NoProgress;
2319
2320 let (graph, source, sink) = create_diamond();
2321
2322 let progress = NoProgress;
2323 let result_with =
2324 min_vertex_cut_with_progress(&graph, source, sink, &progress).expect("Failed");
2325 let result_without = min_vertex_cut(&graph, source, sink).expect("Failed");
2326
2327 assert_eq!(
2328 result_with.cut_size,
2329 result_without.cut_size,
2330 "Cut size should match"
2331 );
2332 assert_eq!(
2333 result_with.separator.len(),
2334 result_without.separator.len(),
2335 "Separator size should match"
2336 );
2337 }
2338
2339 #[test]
2340 fn test_min_vertex_cut_three_parallel_paths() {
2341 let (graph, source, sink) = create_parallel_paths();
2344
2345 let result = min_vertex_cut(&graph, source, sink).expect("Failed to compute vertex cut");
2346
2347 assert_eq!(
2348 result.cut_size,
2349 3,
2350 "Three parallel paths should have vertex cut size 3"
2351 );
2352 assert_eq!(result.separator.len(), 3, "Should have 3 separator vertices");
2353 }
2354
2355 fn create_path_graph() -> SqliteGraph {
2361 let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
2362
2363 for i in 0..5 {
2364 let entity = GraphEntity {
2365 id: 0,
2366 kind: "node".to_string(),
2367 name: format!("node_{}", i),
2368 file_path: Some(format!("node_{}.rs", i)),
2369 data: serde_json::json!({"index": i}),
2370 };
2371 graph.insert_entity(&entity).expect("Failed to insert entity");
2372 }
2373
2374 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
2375 for i in 0..entity_ids.len().saturating_sub(1) {
2376 let edge = GraphEdge {
2377 id: 0,
2378 from_id: entity_ids[i],
2379 to_id: entity_ids[i + 1],
2380 edge_type: "next".to_string(),
2381 data: serde_json::json!({}),
2382 };
2383 graph.insert_edge(&edge).expect("Failed to insert edge");
2384 }
2385
2386 graph
2387 }
2388
2389 fn create_star_graph(leaves: usize) -> SqliteGraph {
2391 let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
2392
2393 let center_entity = GraphEntity {
2395 id: 0,
2396 kind: "node".to_string(),
2397 name: "center".to_string(),
2398 file_path: Some("center.rs".to_string()),
2399 data: serde_json::json!({}),
2400 };
2401 graph.insert_entity(¢er_entity).expect("Failed to insert entity");
2402
2403 for i in 0..leaves {
2405 let leaf_entity = GraphEntity {
2406 id: 0,
2407 kind: "node".to_string(),
2408 name: format!("leaf_{}", i),
2409 file_path: Some(format!("leaf_{}.rs", i)),
2410 data: serde_json::json!({}),
2411 };
2412 graph.insert_entity(&leaf_entity).expect("Failed to insert entity");
2413 }
2414
2415 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
2416 let center_id = entity_ids[0];
2417
2418 for i in 1..entity_ids.len() {
2420 let edge = GraphEdge {
2421 id: 0,
2422 from_id: center_id,
2423 to_id: entity_ids[i],
2424 edge_type: "edge".to_string(),
2425 data: serde_json::json!({}),
2426 };
2427 graph.insert_edge(&edge).expect("Failed to insert edge");
2428 }
2429
2430 graph
2431 }
2432
2433 fn create_binary_tree(height: usize) -> SqliteGraph {
2435 let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
2436
2437 let num_nodes = 2_usize.pow(height as u32 + 1) - 1;
2438 for i in 0..num_nodes {
2439 let entity = GraphEntity {
2440 id: 0,
2441 kind: "node".to_string(),
2442 name: format!("node_{}", i),
2443 file_path: Some(format!("node_{}.rs", i)),
2444 data: serde_json::json!({"index": i}),
2445 };
2446 graph.insert_entity(&entity).expect("Failed to insert entity");
2447 }
2448
2449 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
2450
2451 for i in 0..num_nodes / 2 {
2453 let left_child = 2 * i + 1;
2454 let right_child = 2 * i + 2;
2455
2456 if left_child < num_nodes {
2457 let edge = GraphEdge {
2458 id: 0,
2459 from_id: entity_ids[i],
2460 to_id: entity_ids[left_child],
2461 edge_type: "left".to_string(),
2462 data: serde_json::json!({}),
2463 };
2464 graph.insert_edge(&edge).expect("Failed to insert edge");
2465 }
2466
2467 if right_child < num_nodes {
2468 let edge = GraphEdge {
2469 id: 0,
2470 from_id: entity_ids[i],
2471 to_id: entity_ids[right_child],
2472 edge_type: "right".to_string(),
2473 data: serde_json::json!({}),
2474 };
2475 graph.insert_edge(&edge).expect("Failed to insert edge");
2476 }
2477 }
2478
2479 graph
2480 }
2481
2482 fn create_two_cliques() -> SqliteGraph {
2484 let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
2485
2486 for i in 0..3 {
2488 let entity = GraphEntity {
2489 id: 0,
2490 kind: "node".to_string(),
2491 name: format!("c1_{}", i),
2492 file_path: Some(format!("c1_{}.rs", i)),
2493 data: serde_json::json!({"clique": 1}),
2494 };
2495 graph.insert_entity(&entity).expect("Failed to insert entity");
2496 }
2497
2498 for i in 3..6 {
2500 let entity = GraphEntity {
2501 id: 0,
2502 kind: "node".to_string(),
2503 name: format!("c2_{}", i),
2504 file_path: Some(format!("c2_{}.rs", i)),
2505 data: serde_json::json!({"clique": 2}),
2506 };
2507 graph.insert_entity(&entity).expect("Failed to insert entity");
2508 }
2509
2510 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
2511
2512 for i in 0..3 {
2514 for j in (i + 1)..3 {
2515 let edge = GraphEdge {
2516 id: 0,
2517 from_id: entity_ids[i],
2518 to_id: entity_ids[j],
2519 edge_type: "intra".to_string(),
2520 data: serde_json::json!({}),
2521 };
2522 graph.insert_edge(&edge).expect("Failed to insert edge");
2523 }
2524 }
2525
2526 for i in 3..6 {
2528 for j in (i + 1)..6 {
2529 let edge = GraphEdge {
2530 id: 0,
2531 from_id: entity_ids[i],
2532 to_id: entity_ids[j],
2533 edge_type: "intra".to_string(),
2534 data: serde_json::json!({}),
2535 };
2536 graph.insert_edge(&edge).expect("Failed to insert edge");
2537 }
2538 }
2539
2540 let bridge = GraphEdge {
2542 id: 0,
2543 from_id: entity_ids[1],
2544 to_id: entity_ids[4],
2545 edge_type: "bridge".to_string(),
2546 data: serde_json::json!({}),
2547 };
2548 graph.insert_edge(&bridge).expect("Failed to insert edge");
2549
2550 graph
2551 }
2552
2553 #[test]
2556 fn test_partition_bfs_level_path_graph() {
2557 let graph = create_path_graph();
2560 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
2561
2562 let result = partition_bfs_level(&graph, vec![entity_ids[0], entity_ids[4]], 2)
2563 .expect("Failed to partition");
2564
2565 assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
2566 assert_eq!(
2567 result.partitions[0].len() + result.partitions[1].len(),
2568 5,
2569 "All nodes should be assigned"
2570 );
2571 assert!(result.cut_edges.len() <= 2, "Cut edges should be minimal");
2573 }
2574
2575 #[test]
2576 fn test_partition_bfs_level_star_graph() {
2577 let graph = create_star_graph(4);
2580 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
2581
2582 let result = partition_bfs_level(&graph, vec![entity_ids[0], entity_ids[2]], 2)
2583 .expect("Failed to partition");
2584
2585 assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
2586 assert_eq!(
2588 result.partitions[0].len() + result.partitions[1].len(),
2589 5,
2590 "All nodes should be assigned"
2591 );
2592 }
2593
2594 #[test]
2595 fn test_partition_bfs_level_binary_tree() {
2596 let graph = create_binary_tree(2);
2599 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
2600
2601 let result = partition_bfs_level(&graph, vec![entity_ids[0], entity_ids[6]], 2)
2602 .expect("Failed to partition");
2603
2604 assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
2605 assert_eq!(
2606 result.partitions[0].len() + result.partitions[1].len(),
2607 7,
2608 "All nodes should be assigned"
2609 );
2610 }
2611
2612 #[test]
2613 fn test_partition_bfs_level_disconnected() {
2614 let (graph, node_a, node_b) = create_disconnected();
2617
2618 let result = partition_bfs_level(&graph, vec![node_a, node_b], 2)
2619 .expect("Failed to partition");
2620
2621 assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
2622 assert!(
2624 result.partitions.iter().all(|p| p.len() > 0),
2625 "Each partition should have at least one node"
2626 );
2627 }
2628
2629 #[test]
2630 fn test_partition_bfs_level_empty_seeds() {
2631 let graph = create_path_graph();
2634
2635 let result = partition_bfs_level(&graph, vec![], 2)
2636 .expect("Failed to partition");
2637
2638 assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
2639 }
2640
2641 #[test]
2644 fn test_partition_greedy_two_cliques() {
2645 let graph = create_two_cliques();
2648
2649 let result = partition_greedy(&graph, None, 100)
2650 .expect("Failed to partition");
2651
2652 assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
2653 let total_assigned = result.partitions[0].len() + result.partitions[1].len();
2656 assert!(total_assigned >= 3, "Should assign at least some nodes, got {}", total_assigned);
2657 }
2658
2659 #[test]
2660 fn test_partition_greedy_cut_size_decreases() {
2661 let graph = create_binary_tree(2);
2664
2665 let initial = partition_bfs_level(&graph, vec![], 2).expect("Failed");
2667 let initial_cut_size = initial.cut_edges.len();
2668
2669 let result = partition_greedy(&graph, None, 100)
2671 .expect("Failed to partition");
2672
2673 assert!(
2674 result.cut_edges.len() <= initial_cut_size,
2675 "Greedy should not increase cut size"
2676 );
2677 }
2678
2679 #[test]
2680 fn test_partition_greedy_with_initial_partition() {
2681 let graph = create_path_graph();
2684
2685 let initial_partition = vec![
2686 graph.all_entity_ids().unwrap().into_iter().take(2).collect(),
2687 graph.all_entity_ids().unwrap().into_iter().skip(2).collect(),
2688 ];
2689
2690 let result = partition_greedy(&graph, Some(initial_partition), 10)
2691 .expect("Failed to partition");
2692
2693 assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
2694 }
2695
2696 #[test]
2699 fn test_partition_kway_balanced() {
2700 let graph = create_path_graph(); let large_graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
2706 for i in 0..10 {
2707 let entity = GraphEntity {
2708 id: 0,
2709 kind: "node".to_string(),
2710 name: format!("node_{}", i),
2711 file_path: Some(format!("node_{}.rs", i)),
2712 data: serde_json::json!({"index": i}),
2713 };
2714 large_graph.insert_entity(&entity).expect("Failed to insert entity");
2715 }
2716
2717 let entity_ids: Vec<i64> = large_graph.list_entity_ids().expect("Failed to get IDs");
2718 for i in 0..entity_ids.len().saturating_sub(1) {
2719 let edge = GraphEdge {
2720 id: 0,
2721 from_id: entity_ids[i],
2722 to_id: entity_ids[i + 1],
2723 edge_type: "next".to_string(),
2724 data: serde_json::json!({}),
2725 };
2726 large_graph.insert_edge(&edge).expect("Failed to insert edge");
2727 }
2728
2729 let config = PartitionConfig {
2730 k: 2,
2731 max_size: 5,
2732 max_imbalance: 0.1,
2733 seeds: None,
2734 };
2735
2736 let result = partition_kway(&large_graph, &config)
2737 .expect("Failed to partition");
2738
2739 assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
2740 let total: usize = result.partitions.iter().map(|p| p.len()).sum();
2743 assert_eq!(total, 10, "All 10 nodes should be assigned");
2744 }
2745
2746 #[test]
2747 fn test_partition_kway_three_way() {
2748 let graph = create_path_graph(); let config = PartitionConfig {
2753 k: 3,
2754 max_size: 4,
2755 max_imbalance: 0.5, seeds: None,
2757 };
2758
2759 let result = partition_kway(&graph, &config)
2760 .expect("Failed to partition");
2761
2762 assert_eq!(result.partitions.len(), 3, "Should have 3 partitions");
2763 let total_assigned: usize = result.partitions.iter().map(|p| p.len()).sum();
2765 assert_eq!(total_assigned, 5, "All 5 nodes should be assigned");
2766 }
2767
2768 #[test]
2769 fn test_partition_kway_with_isolated_node() {
2770 let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
2773
2774 for i in 0..3 {
2776 let entity = GraphEntity {
2777 id: 0,
2778 kind: "node".to_string(),
2779 name: format!("node_{}", i),
2780 file_path: Some(format!("node_{}.rs", i)),
2781 data: serde_json::json!({}),
2782 };
2783 graph.insert_entity(&entity).expect("Failed to insert entity");
2784 }
2785
2786 let isolated = GraphEntity {
2788 id: 0,
2789 kind: "node".to_string(),
2790 name: "isolated".to_string(),
2791 file_path: Some("isolated.rs".to_string()),
2792 data: serde_json::json!({}),
2793 };
2794 graph.insert_entity(&isolated).expect("Failed to insert entity");
2795
2796 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
2797
2798 for i in 0..2 {
2800 let edge = GraphEdge {
2801 id: 0,
2802 from_id: entity_ids[i],
2803 to_id: entity_ids[i + 1],
2804 edge_type: "next".to_string(),
2805 data: serde_json::json!({}),
2806 };
2807 graph.insert_edge(&edge).expect("Failed to insert edge");
2808 }
2809
2810 let config = PartitionConfig {
2811 k: 2,
2812 max_size: usize::MAX,
2813 max_imbalance: 0.1,
2814 seeds: None,
2815 };
2816
2817 let result = partition_kway(&graph, &config)
2818 .expect("Failed to partition");
2819
2820 let total_assigned: usize = result.partitions.iter().map(|p| p.len()).sum();
2822 assert_eq!(total_assigned, 4, "All nodes including isolated should be assigned");
2823 }
2824
2825 #[test]
2826 fn test_partition_kway_with_seeds() {
2827 let graph = create_path_graph();
2830 let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
2831
2832 let config = PartitionConfig {
2833 k: 2,
2834 max_size: usize::MAX,
2835 max_imbalance: 0.1,
2836 seeds: Some(vec![entity_ids[0], entity_ids[4]]),
2837 };
2838
2839 let result = partition_kway(&graph, &config)
2840 .expect("Failed to partition");
2841
2842 assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
2843 let p0 = result.node_to_partition.get(&entity_ids[0]);
2845 let p4 = result.node_to_partition.get(&entity_ids[4]);
2846 assert!(p0.is_some() && p4.is_some(), "All seeds should be assigned");
2847 assert_ne!(p0, p4, "Seeds should be in different partitions");
2848 }
2849
2850 #[test]
2851 fn test_partition_kway_invalid_k() {
2852 let graph = create_path_graph();
2855
2856 let config = PartitionConfig {
2857 k: 1, ..Default::default()
2859 };
2860
2861 let result = partition_kway(&graph, &config);
2862 assert!(result.is_err(), "Should return error for k < 2");
2863 }
2864
2865 #[test]
2866 fn test_partition_kway_with_progress_matches() {
2867 use crate::progress::NoProgress;
2870
2871 let graph = create_path_graph();
2872 let config = PartitionConfig::default();
2873
2874 let progress = NoProgress;
2875 let result_with = partition_kway_with_progress(&graph, &config, &progress)
2876 .expect("Failed");
2877 let result_without = partition_kway(&graph, &config)
2878 .expect("Failed");
2879
2880 assert_eq!(
2881 result_with.partitions.len(),
2882 result_without.partitions.len(),
2883 "Partition count should match"
2884 );
2885
2886 let total_with: usize = result_with.partitions.iter().map(|p| p.len()).sum();
2887 let total_without: usize = result_without.partitions.iter().map(|p| p.len()).sum();
2888 assert_eq!(total_with, total_without, "Total assigned nodes should match");
2889 }
2890
2891 #[test]
2892 fn test_partition_result_consistency() {
2893 let graph = create_binary_tree(2);
2896
2897 let result = partition_bfs_level(&graph, vec![], 3)
2898 .expect("Failed to partition");
2899
2900 for (pidx, partition) in result.partitions.iter().enumerate() {
2902 for &node in partition {
2903 assert_eq!(
2904 result.node_to_partition.get(&node),
2905 Some(&pidx),
2906 "Node {} should map to partition {}",
2907 node,
2908 pidx
2909 );
2910 }
2911 }
2912 }
2913
2914 #[test]
2915 fn test_partition_empty_graph() {
2916 let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
2919
2920 let result = partition_bfs_level(&graph, vec![], 2)
2921 .expect("Failed to partition");
2922
2923 assert_eq!(result.partitions.len(), 2, "Should have k partitions");
2924 assert!(result.partitions.iter().all(|p| p.is_empty()), "All partitions should be empty");
2925 assert!(result.cut_edges.is_empty(), "No cut edges for empty graph");
2926 }
2927
2928 #[test]
2929 fn test_partition_k_greater_than_nodes() {
2930 let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
2933
2934 for i in 0..3 {
2936 let entity = GraphEntity {
2937 id: 0,
2938 kind: "node".to_string(),
2939 name: format!("node_{}", i),
2940 file_path: Some(format!("node_{}.rs", i)),
2941 data: serde_json::json!({}),
2942 };
2943 graph.insert_entity(&entity).expect("Failed to insert entity");
2944 }
2945
2946 let result = partition_bfs_level(&graph, vec![], 10)
2947 .expect("Failed to partition");
2948
2949 assert_eq!(result.partitions.len(), 10, "Should have 10 partitions");
2950 let non_empty_count = result.partitions.iter().filter(|p| !p.is_empty()).count();
2951 assert_eq!(non_empty_count, 3, "Only 3 partitions should be non-empty");
2952 }
2953}