Skip to main content

ruvector_dag/attention/
parallel_branch.rs

//! Parallel Branch Attention: coordinates attention across parallel execution branches.
//!
//! This mechanism identifies parallel branches in the DAG and distributes attention
//! to balance workload and minimize synchronization overhead.
5
6use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
7use crate::dag::QueryDag;
8use std::collections::{HashMap, HashSet};
9
/// Tuning parameters for the parallel-branch attention mechanism.
#[derive(Debug, Clone)]
pub struct ParallelBranchConfig {
    /// Maximum number of parallel branches to consider.
    pub max_branches: usize,
    /// Fraction by which a node's score is reduced for each synchronizing edge
    /// between two different branches (the score is multiplied by `1 - sync_penalty`).
    pub sync_penalty: f32,
    /// How strongly branch cost imbalance dampens the attention given to branch nodes.
    pub balance_weight: f32,
    /// Softmax temperature; smaller values produce a sharper distribution.
    pub temperature: f32,
}

impl Default for ParallelBranchConfig {
    /// Stock configuration: `max_branches` 8, `sync_penalty` 0.2,
    /// `balance_weight` 0.5 and a softmax `temperature` of 0.1.
    fn default() -> Self {
        ParallelBranchConfig {
            temperature: 0.1,
            balance_weight: 0.5,
            sync_penalty: 0.2,
            max_branches: 8,
        }
    }
}
32
33pub struct ParallelBranchAttention {
34    config: ParallelBranchConfig,
35}
36
37impl ParallelBranchAttention {
38    pub fn new(config: ParallelBranchConfig) -> Self {
39        Self { config }
40    }
41
42    /// Detect parallel branches (nodes with same parent, no edges between them)
43    fn detect_branches(&self, dag: &QueryDag) -> Vec<Vec<usize>> {
44        let n = dag.node_count();
45        let mut children_of: HashMap<usize, Vec<usize>> = HashMap::new();
46        let mut parents_of: HashMap<usize, Vec<usize>> = HashMap::new();
47
48        // Build parent-child relationships from adjacency
49        for node_id in dag.node_ids() {
50            let children = dag.children(node_id);
51            if !children.is_empty() {
52                for &child in children {
53                    children_of
54                        .entry(node_id)
55                        .or_insert_with(Vec::new)
56                        .push(child);
57                    parents_of
58                        .entry(child)
59                        .or_insert_with(Vec::new)
60                        .push(node_id);
61                }
62            }
63        }
64
65        let mut branches = Vec::new();
66        let mut visited = HashSet::new();
67
68        // For each node, check if its children form parallel branches
69        for node_id in 0..n {
70            if let Some(children) = children_of.get(&node_id) {
71                if children.len() > 1 {
72                    // Check if children are truly parallel (no edges between them)
73                    let mut parallel_group = Vec::new();
74
75                    for &child in children {
76                        if !visited.contains(&child) {
77                            // Check if this child has edges to any siblings
78                            let child_children = dag.children(child);
79                            let has_sibling_edge = children
80                                .iter()
81                                .any(|&other| other != child && child_children.contains(&other));
82
83                            if !has_sibling_edge {
84                                parallel_group.push(child);
85                                visited.insert(child);
86                            }
87                        }
88                    }
89
90                    if parallel_group.len() > 1 {
91                        branches.push(parallel_group);
92                    }
93                }
94            }
95        }
96
97        branches
98    }
99
100    /// Compute branch balance score (lower is better balanced)
101    fn branch_balance(&self, branches: &[Vec<usize>], dag: &QueryDag) -> f32 {
102        if branches.is_empty() {
103            return 1.0;
104        }
105
106        let mut total_variance = 0.0;
107
108        for branch in branches {
109            if branch.len() <= 1 {
110                continue;
111            }
112
113            // Compute costs for each node in the branch
114            let costs: Vec<f64> = branch
115                .iter()
116                .filter_map(|&id| dag.get_node(id).map(|n| n.estimated_cost))
117                .collect();
118
119            if costs.is_empty() {
120                continue;
121            }
122
123            // Compute variance
124            let mean = costs.iter().sum::<f64>() / costs.len() as f64;
125            let variance =
126                costs.iter().map(|&c| (c - mean).powi(2)).sum::<f64>() / costs.len() as f64;
127
128            total_variance += variance as f32;
129        }
130
131        // Normalize by number of branches
132        if branches.is_empty() {
133            1.0
134        } else {
135            (total_variance / branches.len() as f32).sqrt()
136        }
137    }
138
139    /// Compute criticality score for a branch
140    fn branch_criticality(&self, branch: &[usize], dag: &QueryDag) -> f32 {
141        if branch.is_empty() {
142            return 0.0;
143        }
144
145        // Sum of costs in the branch
146        let total_cost: f64 = branch
147            .iter()
148            .filter_map(|&id| dag.get_node(id).map(|n| n.estimated_cost))
149            .sum();
150
151        // Average rows (higher rows = more critical for filtering)
152        let avg_rows: f64 = branch
153            .iter()
154            .filter_map(|&id| dag.get_node(id).map(|n| n.estimated_rows))
155            .sum::<f64>()
156            / branch.len().max(1) as f64;
157
158        // Criticality is high cost + high row count
159        (total_cost * (avg_rows / 1000.0).min(1.0)) as f32
160    }
161
162    /// Compute attention scores based on parallel branch analysis
163    fn compute_branch_attention(&self, dag: &QueryDag, branches: &[Vec<usize>]) -> Vec<f32> {
164        let n = dag.node_count();
165        let mut scores = vec![0.0; n];
166
167        // Base score for nodes not in any branch
168        let base_score = 0.5;
169        for i in 0..n {
170            scores[i] = base_score;
171        }
172
173        // Compute balance metric
174        let balance_penalty = self.branch_balance(branches, dag);
175
176        // Assign scores based on branch criticality
177        for branch in branches {
178            let criticality = self.branch_criticality(branch, dag);
179
180            // Higher criticality = higher attention
181            // Apply balance penalty
182            let branch_score = criticality * (1.0 - self.config.balance_weight * balance_penalty);
183
184            for &node_id in branch {
185                if node_id < n {
186                    scores[node_id] = branch_score;
187                }
188            }
189        }
190
191        // Apply sync penalty to nodes that synchronize branches
192        for from in dag.node_ids() {
193            for &to in dag.children(from) {
194                if from < n && to < n {
195                    // Check if this edge connects different branches
196                    let from_branch = branches.iter().position(|b| b.iter().any(|&x| x == from));
197                    let to_branch = branches.iter().position(|b| b.iter().any(|&x| x == to));
198
199                    if from_branch.is_some() && to_branch.is_some() && from_branch != to_branch {
200                        scores[to] *= 1.0 - self.config.sync_penalty;
201                    }
202                }
203            }
204        }
205
206        // Normalize using softmax
207        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
208        let exp_sum: f32 = scores
209            .iter()
210            .map(|&s| ((s - max_score) / self.config.temperature).exp())
211            .sum();
212
213        if exp_sum > 0.0 {
214            for score in scores.iter_mut() {
215                *score = ((*score - max_score) / self.config.temperature).exp() / exp_sum;
216            }
217        } else {
218            // Uniform if all scores are too low
219            let uniform = 1.0 / n as f32;
220            scores.fill(uniform);
221        }
222
223        scores
224    }
225}
226
227impl DagAttentionMechanism for ParallelBranchAttention {
228    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
229        if dag.node_count() == 0 {
230            return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
231        }
232
233        // Step 1: Detect parallel branches
234        let branches = self.detect_branches(dag);
235
236        // Step 2: Compute attention based on branches
237        let scores = self.compute_branch_attention(dag, &branches);
238
239        // Step 3: Build result
240        let mut result = AttentionScores::new(scores)
241            .with_metadata("mechanism".to_string(), "parallel_branch".to_string())
242            .with_metadata("num_branches".to_string(), branches.len().to_string());
243
244        let balance = self.branch_balance(&branches, dag);
245        result
246            .metadata
247            .insert("balance_score".to_string(), format!("{:.4}", balance));
248
249        Ok(result)
250    }
251
252    fn name(&self) -> &'static str {
253        "parallel_branch"
254    }
255
256    fn complexity(&self) -> &'static str {
257        "O(n² + b·n)"
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    #[test]
    fn test_detect_branches() {
        let config = ParallelBranchConfig::default();
        let attention = ParallelBranchAttention::new(config);

        let mut dag = QueryDag::new();
        for i in 0..4 {
            dag.add_node(OperatorNode::new(i, OperatorType::Scan));
        }

        // Diamond: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3
        dag.add_edge(0, 1).unwrap();
        dag.add_edge(0, 2).unwrap();
        dag.add_edge(1, 3).unwrap();
        dag.add_edge(2, 3).unwrap();

        let branches = attention.detect_branches(&dag);

        // Only node 0 has more than one child, and its children 1 and 2 are edge-free.
        assert_eq!(branches.len(), 1);
        let mut group = branches[0].clone();
        group.sort_unstable();
        assert_eq!(group, vec![1, 2]);
    }

    #[test]
    fn test_parallel_attention() {
        let config = ParallelBranchConfig::default();
        let attention = ParallelBranchAttention::new(config);

        let mut dag = QueryDag::new();
        for i in 0..3 {
            let mut node = OperatorNode::new(i, OperatorType::Scan);
            node.estimated_cost = (i + 1) as f64;
            dag.add_node(node);
        }
        dag.add_edge(0, 1).unwrap();
        dag.add_edge(0, 2).unwrap();

        let result = attention.forward(&dag).unwrap();
        assert_eq!(result.scores.len(), 3);

        // Softmax output must form a probability distribution.
        let total: f32 = result.scores.iter().sum();
        assert!((total - 1.0).abs() < 1e-5, "scores should sum to 1, got {}", total);

        assert_eq!(
            result.metadata.get("num_branches").map(String::as_str),
            Some("1")
        );
    }

    #[test]
    fn test_empty_dag_is_rejected() {
        let attention = ParallelBranchAttention::new(ParallelBranchConfig::default());
        assert!(attention.forward(&QueryDag::new()).is_err());
    }
}