ruvector_dag/mincut/
engine.rs1use super::local_kcut::LocalKCut;
4use crate::dag::QueryDag;
5use std::collections::{HashMap, HashSet};
6
/// Tuning knobs for [`DagMinCutEngine`].
#[derive(Debug, Clone, PartialEq)]
pub struct MinCutConfig {
    /// Presumably an approximation tolerance.
    /// NOTE(review): not read by the engine code in this file; confirm whether `LocalKCut`
    /// or callers use it.
    pub epsilon: f32,
    /// Search depth passed to `LocalKCut::compute` for every cut.
    pub local_search_depth: usize,
    /// Whether cut results are memoised per `(source, sink)` pair.
    pub cache_cuts: bool,
}

impl Default for MinCutConfig {
    /// `epsilon = 0.1`, `local_search_depth = 3`, `cache_cuts = true`.
    fn default() -> Self {
        Self {
            epsilon: 0.1,
            local_search_depth: 3,
            cache_cuts: true,
        }
    }
}

/// A directed edge in the engine's flow graph.
#[derive(Debug, Clone, PartialEq)]
pub struct FlowEdge {
    /// Node id the edge leaves.
    pub from: usize,
    /// Node id the edge enters.
    pub to: usize,
    /// Edge capacity; `DagMinCutEngine::add_edge` gives the reverse edge it creates `0.0`.
    pub capacity: f64,
    /// Flow on the edge; `DagMinCutEngine::add_edge` initialises it to `0.0`.
    pub flow: f64,
}

/// Outcome of a source/sink minimum-cut computation (produced by `LocalKCut::compute`).
#[derive(Debug, Clone, PartialEq)]
pub struct MinCutResult {
    /// Value of the cut; presumably the total capacity of `cut_edges`.
    pub cut_value: f64,
    /// Node ids on the source side of the cut.
    pub source_side: HashSet<usize>,
    /// Node ids on the sink side of the cut.
    pub sink_side: HashSet<usize>,
    /// Edges crossing the cut, as `(from, to)` node-id pairs.
    pub cut_edges: Vec<(usize, usize)>,
}

42pub struct DagMinCutEngine {
43 config: MinCutConfig,
44 adjacency: HashMap<usize, Vec<FlowEdge>>,
45 node_count: usize,
46 local_kcut: LocalKCut,
47 cached_cuts: HashMap<(usize, usize), MinCutResult>,
48}
49
50impl DagMinCutEngine {
51 pub fn new(config: MinCutConfig) -> Self {
52 Self {
53 config,
54 adjacency: HashMap::new(),
55 node_count: 0,
56 local_kcut: LocalKCut::new(),
57 cached_cuts: HashMap::new(),
58 }
59 }
60
61 pub fn build_from_dag(&mut self, dag: &QueryDag) {
63 self.adjacency.clear();
64 self.node_count = dag.node_count();
65
66 for node_id in 0..dag.node_count() {
68 if let Some(node) = dag.get_node(node_id) {
69 let capacity = node.estimated_cost.max(1.0);
70
71 for &child_id in dag.children(node_id) {
72 self.add_edge(node_id, child_id, capacity);
73 }
74 }
75 }
76 }
77
78 pub fn add_edge(&mut self, from: usize, to: usize, capacity: f64) {
79 self.adjacency.entry(from).or_default().push(FlowEdge {
80 from,
81 to,
82 capacity,
83 flow: 0.0,
84 });
85 self.adjacency.entry(to).or_default().push(FlowEdge {
87 from: to,
88 to: from,
89 capacity: 0.0,
90 flow: 0.0,
91 });
92
93 self.node_count = self.node_count.max(from + 1).max(to + 1);
94
95 self.cached_cuts.clear();
97 }
98
99 pub fn compute_mincut(&mut self, source: usize, sink: usize) -> MinCutResult {
101 if self.config.cache_cuts {
103 if let Some(cached) = self.cached_cuts.get(&(source, sink)) {
104 return cached.clone();
105 }
106 }
107
108 let result = self.local_kcut.compute(
110 &self.adjacency,
111 source,
112 sink,
113 self.config.local_search_depth,
114 );
115
116 if self.config.cache_cuts {
117 self.cached_cuts.insert((source, sink), result.clone());
118 }
119
120 result
121 }
122
123 pub fn update_edge(&mut self, from: usize, to: usize, new_capacity: f64) {
125 if let Some(edges) = self.adjacency.get_mut(&from) {
126 for edge in edges.iter_mut() {
127 if edge.to == to {
128 edge.capacity = new_capacity;
129 break;
130 }
131 }
132 }
133
134 let keys_to_remove: Vec<(usize, usize)> = self
137 .cached_cuts
138 .keys()
139 .filter(|(s, t)| self.cut_involves_edge(*s, *t, from, to))
140 .copied()
141 .collect();
142
143 for key in keys_to_remove {
144 self.cached_cuts.remove(&key);
145 }
146 }
147
148 fn cut_involves_edge(&self, _source: usize, _sink: usize, _from: usize, _to: usize) -> bool {
149 true
152 }
153
154 pub fn compute_criticality(&mut self, dag: &QueryDag) -> HashMap<usize, f64> {
156 let mut criticality = HashMap::new();
157
158 let leaves = dag.leaves();
159 let root = dag.root();
160
161 if leaves.is_empty() || root.is_none() {
162 return criticality;
163 }
164
165 let root = root.unwrap();
166
167 for node_id in 0..dag.node_count() {
169 if dag.get_node(node_id).is_none() {
170 continue;
171 }
172
173 let cut_with = self.compute_mincut(leaves[0], root);
175
176 for &child in dag.children(node_id) {
178 self.update_edge(node_id, child, f64::INFINITY);
179 }
180
181 let cut_without = self.compute_mincut(leaves[0], root);
182
183 let node = dag.get_node(node_id).unwrap();
185 for &child in dag.children(node_id) {
186 self.update_edge(node_id, child, node.estimated_cost);
187 }
188
189 let crit = (cut_without.cut_value - cut_with.cut_value) / cut_with.cut_value.max(1.0);
191 criticality.insert(node_id, crit.max(0.0));
192 }
193
194 criticality
195 }
196}