ruvector_dag/attention/
critical_path.rs1use super::{AttentionError, AttentionScores, DagAttention};
4use crate::dag::QueryDag;
5use std::collections::HashMap;
6
/// Tunable weights for [`CriticalPathAttention`].
///
/// Both values act on a node's raw score, before the scores are normalized
/// to sum to 1.
#[derive(Debug, Clone, PartialEq)]
pub struct CriticalPathConfig {
    /// Raw score assigned to a node that lies on the critical path.
    /// Nodes off the path start from a raw score of `1.0`.
    pub path_weight: f32,
    /// Branching factor: a node with `k > 1` children has its raw score
    /// multiplied by `1 + (k - 1) * branch_penalty`. Note that a positive
    /// value therefore *raises* the score of branching nodes.
    pub branch_penalty: f32,
}
12
13impl Default for CriticalPathConfig {
14 fn default() -> Self {
15 Self {
16 path_weight: 2.0,
17 branch_penalty: 0.5,
18 }
19 }
20}
21
22pub struct CriticalPathAttention {
23 config: CriticalPathConfig,
24 critical_path: Vec<usize>,
25}
26
27impl CriticalPathAttention {
28 pub fn new(config: CriticalPathConfig) -> Self {
29 Self {
30 config,
31 critical_path: Vec::new(),
32 }
33 }
34
35 pub fn with_defaults() -> Self {
36 Self::new(CriticalPathConfig::default())
37 }
38
39 fn compute_critical_path(&self, dag: &QueryDag) -> Vec<usize> {
41 let mut longest_path: HashMap<usize, (f64, Vec<usize>)> = HashMap::new();
42
43 for &leaf in &dag.leaves() {
45 if let Some(node) = dag.get_node(leaf) {
46 longest_path.insert(leaf, (node.estimated_cost, vec![leaf]));
47 }
48 }
49
50 if let Ok(topo_order) = dag.topological_sort() {
52 for &node_id in topo_order.iter().rev() {
53 let node = match dag.get_node(node_id) {
54 Some(n) => n,
55 None => continue,
56 };
57
58 let mut max_cost = node.estimated_cost;
59 let mut max_path = vec![node_id];
60
61 for &child in dag.children(node_id) {
63 if let Some(&(child_cost, ref child_path)) = longest_path.get(&child) {
64 let total_cost = node.estimated_cost + child_cost;
65 if total_cost > max_cost {
66 max_cost = total_cost;
67 max_path = vec![node_id];
68 max_path.extend(child_path);
69 }
70 }
71 }
72
73 longest_path.insert(node_id, (max_cost, max_path));
74 }
75 }
76
77 longest_path
79 .into_iter()
80 .max_by(|a, b| {
81 a.1 .0
82 .partial_cmp(&b.1 .0)
83 .unwrap_or(std::cmp::Ordering::Equal)
84 })
85 .map(|(_, (_, path))| path)
86 .unwrap_or_default()
87 }
88}
89
90impl DagAttention for CriticalPathAttention {
91 fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
92 if dag.node_count() == 0 {
93 return Err(AttentionError::EmptyDag);
94 }
95
96 let critical = self.compute_critical_path(dag);
97 let mut scores = HashMap::new();
98 let mut total = 0.0f32;
99
100 let node_ids: Vec<usize> = (0..dag.node_count()).collect();
102 for node_id in node_ids {
103 if dag.get_node(node_id).is_none() {
104 continue;
105 }
106
107 let is_on_critical_path = critical.contains(&node_id);
108 let num_children = dag.children(node_id).len();
109
110 let mut score = if is_on_critical_path {
111 self.config.path_weight
112 } else {
113 1.0
114 };
115
116 if num_children > 1 {
118 score *= 1.0 + (num_children as f32 - 1.0) * self.config.branch_penalty;
119 }
120
121 scores.insert(node_id, score);
122 total += score;
123 }
124
125 if total > 0.0 {
127 for score in scores.values_mut() {
128 *score /= total;
129 }
130 }
131
132 Ok(scores)
133 }
134
135 fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>) {
136 self.critical_path = self.compute_critical_path(dag);
139
140 if !execution_times.is_empty() {
142 let max_time = execution_times.values().fold(0.0f64, |a, &b| a.max(b));
143 let avg_time: f64 =
144 execution_times.values().sum::<f64>() / execution_times.len() as f64;
145
146 if max_time > 0.0 && avg_time > 0.0 {
147 let variance_ratio = max_time / avg_time;
149 if variance_ratio > 2.0 {
150 self.config.path_weight = (self.config.path_weight * 1.1).min(5.0);
151 }
152 }
153 }
154 }
155
156 fn name(&self) -> &'static str {
157 "critical_path"
158 }
159
160 fn complexity(&self) -> &'static str {
161 "O(n + e)"
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    /// Builds a three-node DAG (scan and filter both feeding a join) and
    /// checks that the scores are normalized and that every critical-path
    /// node receives a positive score.
    #[test]
    fn test_critical_path_attention() {
        let mut dag = QueryDag::new();

        let scan_id =
            dag.add_node(OperatorNode::seq_scan(0, "large_table").with_estimates(10000.0, 10.0));
        let filter_id =
            dag.add_node(OperatorNode::filter(0, "status = 'active'").with_estimates(1000.0, 1.0));
        let join_id = dag.add_node(OperatorNode::hash_join(0, "user_id").with_estimates(5000.0, 5.0));

        // Both inputs point at the join.
        dag.add_edge(scan_id, join_id).unwrap();
        dag.add_edge(filter_id, join_id).unwrap();

        let attention = CriticalPathAttention::with_defaults();
        let scores = attention.forward(&dag).unwrap();

        // The scores must form a distribution.
        let score_sum: f32 = scores.values().copied().sum();
        assert!((score_sum - 1.0).abs() < 1e-5);

        // Every node on the critical path must have a positive score.
        for node_id in attention.compute_critical_path(&dag) {
            let score = *scores.get(&node_id).unwrap();
            assert!(score > 0.0);
        }
    }
}