Skip to main content

ruvector_dag/attention/
mincut_gated.rs

1//! MinCut Gated Attention: Gates attention by graph cut criticality
2
3use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
4use crate::dag::QueryDag;
5use std::collections::{HashMap, HashSet, VecDeque};
6
/// Strategy used to assign a capacity to each DAG edge when building the
/// flow network for the min-cut computation.
///
/// The capacity of an edge is always derived from the node the edge leaves.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlowCapacity {
    /// Every edge has capacity `1.0`.
    UnitCapacity,
    /// An edge's capacity is the `estimated_cost` of the node it leaves.
    CostBased,
    /// An edge's capacity is the `estimated_rows` of the node it leaves.
    RowBased,
}
13
14#[derive(Debug, Clone)]
15pub struct MinCutConfig {
16    pub gate_threshold: f32,
17    pub flow_capacity: FlowCapacity,
18}
19
20impl Default for MinCutConfig {
21    fn default() -> Self {
22        Self {
23            gate_threshold: 0.5,
24            flow_capacity: FlowCapacity::UnitCapacity,
25        }
26    }
27}
28
29pub struct MinCutGatedAttention {
30    config: MinCutConfig,
31}
32
33impl MinCutGatedAttention {
34    pub fn new(config: MinCutConfig) -> Self {
35        Self { config }
36    }
37
38    pub fn with_defaults() -> Self {
39        Self::new(MinCutConfig::default())
40    }
41
42    /// Compute min-cut between leaves and root using Ford-Fulkerson
43    fn compute_min_cut(&self, dag: &QueryDag) -> HashSet<usize> {
44        let mut cut_nodes = HashSet::new();
45
46        // Build capacity matrix from the DAG structure
47        let mut capacity: HashMap<(usize, usize), f64> = HashMap::new();
48        for node_id in 0..dag.node_count() {
49            if dag.get_node(node_id).is_none() {
50                continue;
51            }
52            for &child in dag.children(node_id) {
53                let cap = match self.config.flow_capacity {
54                    FlowCapacity::UnitCapacity => 1.0,
55                    FlowCapacity::CostBased => dag
56                        .get_node(node_id)
57                        .map(|n| n.estimated_cost)
58                        .unwrap_or(1.0),
59                    FlowCapacity::RowBased => dag
60                        .get_node(node_id)
61                        .map(|n| n.estimated_rows)
62                        .unwrap_or(1.0),
63                };
64                capacity.insert((node_id, child), cap);
65            }
66        }
67
68        // Find source (root) and sink (any leaf)
69        let source = match dag.root() {
70            Some(root) => root,
71            None => return cut_nodes,
72        };
73
74        let leaves = dag.leaves();
75        if leaves.is_empty() {
76            return cut_nodes;
77        }
78
79        // Use first leaf as sink
80        let sink = leaves[0];
81
82        // Ford-Fulkerson to find max flow
83        let mut residual = capacity.clone();
84        #[allow(unused_variables, unused_assignments)]
85        let mut total_flow = 0.0;
86
87        loop {
88            // BFS to find augmenting path
89            let mut parent: HashMap<usize, usize> = HashMap::new();
90            let mut visited = HashSet::new();
91            let mut queue = VecDeque::new();
92
93            queue.push_back(source);
94            visited.insert(source);
95
96            while let Some(u) = queue.pop_front() {
97                if u == sink {
98                    break;
99                }
100
101                for v in dag.children(u) {
102                    if !visited.contains(v) && residual.get(&(u, *v)).copied().unwrap_or(0.0) > 0.0
103                    {
104                        visited.insert(*v);
105                        parent.insert(*v, u);
106                        queue.push_back(*v);
107                    }
108                }
109            }
110
111            // No augmenting path found
112            if !parent.contains_key(&sink) {
113                break;
114            }
115
116            // Find minimum capacity along the path
117            let mut path_flow = f64::INFINITY;
118            let mut v = sink;
119            while v != source {
120                let u = parent[&v];
121                path_flow = path_flow.min(residual.get(&(u, v)).copied().unwrap_or(0.0));
122                v = u;
123            }
124
125            // Update residual capacities
126            v = sink;
127            while v != source {
128                let u = parent[&v];
129                *residual.entry((u, v)).or_insert(0.0) -= path_flow;
130                *residual.entry((v, u)).or_insert(0.0) += path_flow;
131                v = u;
132            }
133
134            total_flow += path_flow;
135        }
136
137        // Find nodes reachable from source in residual graph
138        let mut reachable = HashSet::new();
139        let mut queue = VecDeque::new();
140        queue.push_back(source);
141        reachable.insert(source);
142
143        while let Some(u) = queue.pop_front() {
144            for &v in dag.children(u) {
145                if !reachable.contains(&v) && residual.get(&(u, v)).copied().unwrap_or(0.0) > 0.0 {
146                    reachable.insert(v);
147                    queue.push_back(v);
148                }
149            }
150        }
151
152        // Nodes in the cut are those with edges crossing from reachable to non-reachable
153        for node_id in 0..dag.node_count() {
154            if dag.get_node(node_id).is_none() {
155                continue;
156            }
157            for &child in dag.children(node_id) {
158                if reachable.contains(&node_id) && !reachable.contains(&child) {
159                    cut_nodes.insert(node_id);
160                    cut_nodes.insert(child);
161                }
162            }
163        }
164
165        cut_nodes
166    }
167}
168
169impl DagAttentionMechanism for MinCutGatedAttention {
170    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
171        if dag.node_count() == 0 {
172            return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
173        }
174
175        let cut_nodes = self.compute_min_cut(dag);
176        let n = dag.node_count();
177        let mut score_vec = vec![0.0; n];
178        let mut total = 0.0f32;
179
180        // Gate attention based on whether node is in cut
181        for node_id in 0..n {
182            if dag.get_node(node_id).is_none() {
183                continue;
184            }
185
186            let is_in_cut = cut_nodes.contains(&node_id);
187
188            let score = if is_in_cut {
189                // Nodes in the cut are critical bottlenecks
190                1.0
191            } else {
192                // Other nodes get reduced attention
193                self.config.gate_threshold
194            };
195
196            score_vec[node_id] = score;
197            total += score;
198        }
199
200        // Normalize to sum to 1
201        if total > 0.0 {
202            for score in score_vec.iter_mut() {
203                *score /= total;
204            }
205        }
206
207        Ok(AttentionScores::new(score_vec))
208    }
209
210    fn name(&self) -> &'static str {
211        "mincut_gated"
212    }
213
214    fn complexity(&self) -> &'static str {
215        "O(n * e^2)"
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    /// Scores for a small DAG funnelled through one join node must sum to
    /// one and each lie within `[0, 1]`.
    #[test]
    fn test_mincut_gated_attention() {
        let mut dag = QueryDag::new();

        // Two scans feed a join, which fans out to a filter and a projection.
        let scan_a = dag.add_node(OperatorNode::seq_scan(0, "table1"));
        let scan_b = dag.add_node(OperatorNode::seq_scan(0, "table2"));
        let join = dag.add_node(OperatorNode::hash_join(0, "id"));
        let filter = dag.add_node(OperatorNode::filter(0, "status = 'active'"));
        let project = dag.add_node(OperatorNode::project(0, vec!["name".to_string()]));

        // The join is the bottleneck every path passes through.
        dag.add_edge(scan_a, join).unwrap();
        dag.add_edge(scan_b, join).unwrap();
        dag.add_edge(join, filter).unwrap();
        dag.add_edge(join, project).unwrap();

        let mechanism = MinCutGatedAttention::with_defaults();
        let result = mechanism.forward(&dag).unwrap();

        // The scores form a normalised distribution.
        let mass: f32 = result.scores.iter().sum();
        assert!((mass - 1.0).abs() < 1e-5);

        // Each individual score stays within the unit interval.
        for &weight in &result.scores {
            assert!((0.0..=1.0).contains(&weight));
        }
    }
}