ruvector_dag/attention/
parallel_branch.rs1use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
7use crate::dag::QueryDag;
8use std::collections::{HashMap, HashSet};
9
/// Tunable parameters for [`ParallelBranchAttention`].
#[derive(Debug, Clone)]
pub struct ParallelBranchConfig {
    /// Upper bound on the number of parallel branches.
    ///
    /// NOTE(review): none of the code visible alongside this struct reads this
    /// field — confirm where (or whether) the cap is meant to be enforced.
    pub max_branches: usize,
    /// Fraction removed from a node's score for every incoming edge that
    /// crosses from one parallel branch group into a different one
    /// (the score is multiplied by `1.0 - sync_penalty`).
    pub sync_penalty: f32,
    /// Multiplier applied to the branch imbalance measure before it is
    /// subtracted from `1.0` to scale branch scores.
    pub balance_weight: f32,
    /// Divisor used inside the softmax that normalizes the final scores.
    pub temperature: f32,
}

impl Default for ParallelBranchConfig {
    /// Returns the stock configuration: 8 branches, a 0.2 synchronization
    /// penalty, a 0.5 balance weight and a softmax temperature of 0.1.
    fn default() -> Self {
        ParallelBranchConfig {
            max_branches: 8,
            sync_penalty: 0.2,
            balance_weight: 0.5,
            temperature: 0.1,
        }
    }
}
32
33pub struct ParallelBranchAttention {
34 config: ParallelBranchConfig,
35}
36
37impl ParallelBranchAttention {
38 pub fn new(config: ParallelBranchConfig) -> Self {
39 Self { config }
40 }
41
42 fn detect_branches(&self, dag: &QueryDag) -> Vec<Vec<usize>> {
44 let n = dag.node_count();
45 let mut children_of: HashMap<usize, Vec<usize>> = HashMap::new();
46 let mut parents_of: HashMap<usize, Vec<usize>> = HashMap::new();
47
48 for node_id in dag.node_ids() {
50 let children = dag.children(node_id);
51 if !children.is_empty() {
52 for &child in children {
53 children_of
54 .entry(node_id)
55 .or_insert_with(Vec::new)
56 .push(child);
57 parents_of
58 .entry(child)
59 .or_insert_with(Vec::new)
60 .push(node_id);
61 }
62 }
63 }
64
65 let mut branches = Vec::new();
66 let mut visited = HashSet::new();
67
68 for node_id in 0..n {
70 if let Some(children) = children_of.get(&node_id) {
71 if children.len() > 1 {
72 let mut parallel_group = Vec::new();
74
75 for &child in children {
76 if !visited.contains(&child) {
77 let child_children = dag.children(child);
79 let has_sibling_edge = children
80 .iter()
81 .any(|&other| other != child && child_children.contains(&other));
82
83 if !has_sibling_edge {
84 parallel_group.push(child);
85 visited.insert(child);
86 }
87 }
88 }
89
90 if parallel_group.len() > 1 {
91 branches.push(parallel_group);
92 }
93 }
94 }
95 }
96
97 branches
98 }
99
100 fn branch_balance(&self, branches: &[Vec<usize>], dag: &QueryDag) -> f32 {
102 if branches.is_empty() {
103 return 1.0;
104 }
105
106 let mut total_variance = 0.0;
107
108 for branch in branches {
109 if branch.len() <= 1 {
110 continue;
111 }
112
113 let costs: Vec<f64> = branch
115 .iter()
116 .filter_map(|&id| dag.get_node(id).map(|n| n.estimated_cost))
117 .collect();
118
119 if costs.is_empty() {
120 continue;
121 }
122
123 let mean = costs.iter().sum::<f64>() / costs.len() as f64;
125 let variance =
126 costs.iter().map(|&c| (c - mean).powi(2)).sum::<f64>() / costs.len() as f64;
127
128 total_variance += variance as f32;
129 }
130
131 if branches.is_empty() {
133 1.0
134 } else {
135 (total_variance / branches.len() as f32).sqrt()
136 }
137 }
138
139 fn branch_criticality(&self, branch: &[usize], dag: &QueryDag) -> f32 {
141 if branch.is_empty() {
142 return 0.0;
143 }
144
145 let total_cost: f64 = branch
147 .iter()
148 .filter_map(|&id| dag.get_node(id).map(|n| n.estimated_cost))
149 .sum();
150
151 let avg_rows: f64 = branch
153 .iter()
154 .filter_map(|&id| dag.get_node(id).map(|n| n.estimated_rows))
155 .sum::<f64>()
156 / branch.len().max(1) as f64;
157
158 (total_cost * (avg_rows / 1000.0).min(1.0)) as f32
160 }
161
162 fn compute_branch_attention(&self, dag: &QueryDag, branches: &[Vec<usize>]) -> Vec<f32> {
164 let n = dag.node_count();
165 let mut scores = vec![0.0; n];
166
167 let base_score = 0.5;
169 for i in 0..n {
170 scores[i] = base_score;
171 }
172
173 let balance_penalty = self.branch_balance(branches, dag);
175
176 for branch in branches {
178 let criticality = self.branch_criticality(branch, dag);
179
180 let branch_score = criticality * (1.0 - self.config.balance_weight * balance_penalty);
183
184 for &node_id in branch {
185 if node_id < n {
186 scores[node_id] = branch_score;
187 }
188 }
189 }
190
191 for from in dag.node_ids() {
193 for &to in dag.children(from) {
194 if from < n && to < n {
195 let from_branch = branches.iter().position(|b| b.iter().any(|&x| x == from));
197 let to_branch = branches.iter().position(|b| b.iter().any(|&x| x == to));
198
199 if from_branch.is_some() && to_branch.is_some() && from_branch != to_branch {
200 scores[to] *= 1.0 - self.config.sync_penalty;
201 }
202 }
203 }
204 }
205
206 let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
208 let exp_sum: f32 = scores
209 .iter()
210 .map(|&s| ((s - max_score) / self.config.temperature).exp())
211 .sum();
212
213 if exp_sum > 0.0 {
214 for score in scores.iter_mut() {
215 *score = ((*score - max_score) / self.config.temperature).exp() / exp_sum;
216 }
217 } else {
218 let uniform = 1.0 / n as f32;
220 scores.fill(uniform);
221 }
222
223 scores
224 }
225}
226
227impl DagAttentionMechanism for ParallelBranchAttention {
228 fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
229 if dag.node_count() == 0 {
230 return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
231 }
232
233 let branches = self.detect_branches(dag);
235
236 let scores = self.compute_branch_attention(dag, &branches);
238
239 let mut result = AttentionScores::new(scores)
241 .with_metadata("mechanism".to_string(), "parallel_branch".to_string())
242 .with_metadata("num_branches".to_string(), branches.len().to_string());
243
244 let balance = self.branch_balance(&branches, dag);
245 result
246 .metadata
247 .insert("balance_score".to_string(), format!("{:.4}", balance));
248
249 Ok(result)
250 }
251
252 fn name(&self) -> &'static str {
253 "parallel_branch"
254 }
255
256 fn complexity(&self) -> &'static str {
257 "O(n² + b·n)"
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    /// Builds the mechanism under test with the default configuration.
    fn default_attention() -> ParallelBranchAttention {
        ParallelBranchAttention::new(ParallelBranchConfig::default())
    }

    /// A diamond-shaped DAG (0 -> {1, 2} -> 3) must produce at least one
    /// branch group.
    #[test]
    fn test_detect_branches() {
        let mechanism = default_attention();

        let mut graph = QueryDag::new();
        for id in 0..4 {
            graph.add_node(OperatorNode::new(id, OperatorType::Scan));
        }

        for &(parent, child) in &[(0, 1), (0, 2), (1, 3), (2, 3)] {
            graph.add_edge(parent, child).unwrap();
        }

        let groups = mechanism.detect_branches(&graph);
        assert!(!groups.is_empty());
    }

    /// A single fork with two children of different cost yields one score
    /// per node.
    #[test]
    fn test_parallel_attention() {
        let mechanism = default_attention();

        let mut graph = QueryDag::new();
        for id in 0..3 {
            let mut operator = OperatorNode::new(id, OperatorType::Scan);
            operator.estimated_cost = (id + 1) as f64;
            graph.add_node(operator);
        }

        for &(parent, child) in &[(0, 1), (0, 2)] {
            graph.add_edge(parent, child).unwrap();
        }

        let output = mechanism.forward(&graph).unwrap();
        assert_eq!(output.scores.len(), 3);
    }
}