ruvector_dag/attention/
mincut_gated.rs1use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
4use crate::dag::QueryDag;
5use std::collections::{HashMap, HashSet, VecDeque};
6
/// Strategy used to assign a capacity to each DAG edge before the min-cut
/// computation.
///
/// The capacity of an edge is always derived from the node the edge starts at.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlowCapacity {
    /// Every edge has capacity `1.0`.
    UnitCapacity,
    /// An edge's capacity is the `estimated_cost` of the node it starts at.
    CostBased,
    /// An edge's capacity is the `estimated_rows` of the node it starts at.
    RowBased,
}
13
14#[derive(Debug, Clone)]
15pub struct MinCutConfig {
16 pub gate_threshold: f32,
17 pub flow_capacity: FlowCapacity,
18}
19
20impl Default for MinCutConfig {
21 fn default() -> Self {
22 Self {
23 gate_threshold: 0.5,
24 flow_capacity: FlowCapacity::UnitCapacity,
25 }
26 }
27}
28
29pub struct MinCutGatedAttention {
30 config: MinCutConfig,
31}
32
33impl MinCutGatedAttention {
34 pub fn new(config: MinCutConfig) -> Self {
35 Self { config }
36 }
37
38 pub fn with_defaults() -> Self {
39 Self::new(MinCutConfig::default())
40 }
41
42 fn compute_min_cut(&self, dag: &QueryDag) -> HashSet<usize> {
44 let mut cut_nodes = HashSet::new();
45
46 let mut capacity: HashMap<(usize, usize), f64> = HashMap::new();
48 for node_id in 0..dag.node_count() {
49 if dag.get_node(node_id).is_none() {
50 continue;
51 }
52 for &child in dag.children(node_id) {
53 let cap = match self.config.flow_capacity {
54 FlowCapacity::UnitCapacity => 1.0,
55 FlowCapacity::CostBased => dag
56 .get_node(node_id)
57 .map(|n| n.estimated_cost)
58 .unwrap_or(1.0),
59 FlowCapacity::RowBased => dag
60 .get_node(node_id)
61 .map(|n| n.estimated_rows)
62 .unwrap_or(1.0),
63 };
64 capacity.insert((node_id, child), cap);
65 }
66 }
67
68 let source = match dag.root() {
70 Some(root) => root,
71 None => return cut_nodes,
72 };
73
74 let leaves = dag.leaves();
75 if leaves.is_empty() {
76 return cut_nodes;
77 }
78
79 let sink = leaves[0];
81
82 let mut residual = capacity.clone();
84 #[allow(unused_variables, unused_assignments)]
85 let mut total_flow = 0.0;
86
87 loop {
88 let mut parent: HashMap<usize, usize> = HashMap::new();
90 let mut visited = HashSet::new();
91 let mut queue = VecDeque::new();
92
93 queue.push_back(source);
94 visited.insert(source);
95
96 while let Some(u) = queue.pop_front() {
97 if u == sink {
98 break;
99 }
100
101 for v in dag.children(u) {
102 if !visited.contains(v) && residual.get(&(u, *v)).copied().unwrap_or(0.0) > 0.0
103 {
104 visited.insert(*v);
105 parent.insert(*v, u);
106 queue.push_back(*v);
107 }
108 }
109 }
110
111 if !parent.contains_key(&sink) {
113 break;
114 }
115
116 let mut path_flow = f64::INFINITY;
118 let mut v = sink;
119 while v != source {
120 let u = parent[&v];
121 path_flow = path_flow.min(residual.get(&(u, v)).copied().unwrap_or(0.0));
122 v = u;
123 }
124
125 v = sink;
127 while v != source {
128 let u = parent[&v];
129 *residual.entry((u, v)).or_insert(0.0) -= path_flow;
130 *residual.entry((v, u)).or_insert(0.0) += path_flow;
131 v = u;
132 }
133
134 total_flow += path_flow;
135 }
136
137 let mut reachable = HashSet::new();
139 let mut queue = VecDeque::new();
140 queue.push_back(source);
141 reachable.insert(source);
142
143 while let Some(u) = queue.pop_front() {
144 for &v in dag.children(u) {
145 if !reachable.contains(&v) && residual.get(&(u, v)).copied().unwrap_or(0.0) > 0.0 {
146 reachable.insert(v);
147 queue.push_back(v);
148 }
149 }
150 }
151
152 for node_id in 0..dag.node_count() {
154 if dag.get_node(node_id).is_none() {
155 continue;
156 }
157 for &child in dag.children(node_id) {
158 if reachable.contains(&node_id) && !reachable.contains(&child) {
159 cut_nodes.insert(node_id);
160 cut_nodes.insert(child);
161 }
162 }
163 }
164
165 cut_nodes
166 }
167}
168
169impl DagAttentionMechanism for MinCutGatedAttention {
170 fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
171 if dag.node_count() == 0 {
172 return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
173 }
174
175 let cut_nodes = self.compute_min_cut(dag);
176 let n = dag.node_count();
177 let mut score_vec = vec![0.0; n];
178 let mut total = 0.0f32;
179
180 for node_id in 0..n {
182 if dag.get_node(node_id).is_none() {
183 continue;
184 }
185
186 let is_in_cut = cut_nodes.contains(&node_id);
187
188 let score = if is_in_cut {
189 1.0
191 } else {
192 self.config.gate_threshold
194 };
195
196 score_vec[node_id] = score;
197 total += score;
198 }
199
200 if total > 0.0 {
202 for score in score_vec.iter_mut() {
203 *score /= total;
204 }
205 }
206
207 Ok(AttentionScores::new(score_vec))
208 }
209
210 fn name(&self) -> &'static str {
211 "mincut_gated"
212 }
213
214 fn complexity(&self) -> &'static str {
215 "O(n * e^2)"
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    /// The scores produced for a small five-node plan must sum to one and
    /// each must lie within `[0, 1]`.
    #[test]
    fn test_mincut_gated_attention() {
        let mut plan = QueryDag::new();

        // Two scans, one join, then a filter and a projection.
        let scan_a = plan.add_node(OperatorNode::seq_scan(0, "table1"));
        let scan_b = plan.add_node(OperatorNode::seq_scan(0, "table2"));
        let join = plan.add_node(OperatorNode::hash_join(0, "id"));
        let filter = plan.add_node(OperatorNode::filter(0, "status = 'active'"));
        let project = plan.add_node(OperatorNode::project(0, vec!["name".to_string()]));

        // Both scans connect to the join, which connects to filter and project.
        plan.add_edge(scan_a, join).unwrap();
        plan.add_edge(scan_b, join).unwrap();
        plan.add_edge(join, filter).unwrap();
        plan.add_edge(join, project).unwrap();

        let mechanism = MinCutGatedAttention::with_defaults();
        let result = mechanism.forward(&plan).unwrap();

        // Normalized scores add up to one (within float tolerance).
        let total: f32 = result.scores.iter().sum();
        assert!((total - 1.0).abs() < 1e-5);

        // Every individual score is a valid weight.
        for &weight in &result.scores {
            assert!((0.0..=1.0).contains(&weight));
        }
    }
}