Skip to main content

ruvector_dag/attention/
critical_path.rs

1//! Critical Path Attention: Focuses on bottleneck nodes
2
3use super::{AttentionError, AttentionScores, DagAttention};
4use crate::dag::QueryDag;
5use std::collections::HashMap;
6
/// Tuning parameters for [`CriticalPathAttention`].
#[derive(Clone, Debug)]
pub struct CriticalPathConfig {
    /// Raw (pre-normalization) score given to a node that lies on the
    /// critical path; nodes off the path start from `1.0`.
    pub path_weight: f32,
    /// Per-extra-child multiplier increment: a node with `k > 1` children has
    /// its raw score multiplied by `1.0 + (k - 1) * branch_penalty`.
    ///
    /// NOTE(review): despite the name, a positive value *raises* the score of
    /// branching nodes.
    pub branch_penalty: f32,
}

impl Default for CriticalPathConfig {
    /// Returns the stock configuration: `path_weight = 2.0`,
    /// `branch_penalty = 0.5`.
    fn default() -> Self {
        CriticalPathConfig {
            branch_penalty: 0.5,
            path_weight: 2.0,
        }
    }
}
21
/// Attention mechanism that scores DAG nodes according to whether they lie on
/// the critical (highest estimated cost) path.
pub struct CriticalPathAttention {
    /// Scoring parameters; `path_weight` may be adjusted by `update`.
    config: CriticalPathConfig,
    /// Critical path stored by the most recent `update` call; empty until then.
    /// NOTE(review): `forward` recomputes the path and does not read this cache.
    critical_path: Vec<usize>,
}
26
27impl CriticalPathAttention {
28    pub fn new(config: CriticalPathConfig) -> Self {
29        Self {
30            config,
31            critical_path: Vec::new(),
32        }
33    }
34
35    pub fn with_defaults() -> Self {
36        Self::new(CriticalPathConfig::default())
37    }
38
39    /// Compute the critical path (longest path by cost)
40    fn compute_critical_path(&self, dag: &QueryDag) -> Vec<usize> {
41        let mut longest_path: HashMap<usize, (f64, Vec<usize>)> = HashMap::new();
42
43        // Initialize leaves
44        for &leaf in &dag.leaves() {
45            if let Some(node) = dag.get_node(leaf) {
46                longest_path.insert(leaf, (node.estimated_cost, vec![leaf]));
47            }
48        }
49
50        // Process nodes in reverse topological order
51        if let Ok(topo_order) = dag.topological_sort() {
52            for &node_id in topo_order.iter().rev() {
53                let node = match dag.get_node(node_id) {
54                    Some(n) => n,
55                    None => continue,
56                };
57
58                let mut max_cost = node.estimated_cost;
59                let mut max_path = vec![node_id];
60
61                // Check all children
62                for &child in dag.children(node_id) {
63                    if let Some(&(child_cost, ref child_path)) = longest_path.get(&child) {
64                        let total_cost = node.estimated_cost + child_cost;
65                        if total_cost > max_cost {
66                            max_cost = total_cost;
67                            max_path = vec![node_id];
68                            max_path.extend(child_path);
69                        }
70                    }
71                }
72
73                longest_path.insert(node_id, (max_cost, max_path));
74            }
75        }
76
77        // Find the path with maximum cost
78        longest_path
79            .into_iter()
80            .max_by(|a, b| {
81                a.1 .0
82                    .partial_cmp(&b.1 .0)
83                    .unwrap_or(std::cmp::Ordering::Equal)
84            })
85            .map(|(_, (_, path))| path)
86            .unwrap_or_default()
87    }
88}
89
90impl DagAttention for CriticalPathAttention {
91    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
92        if dag.node_count() == 0 {
93            return Err(AttentionError::EmptyDag);
94        }
95
96        let critical = self.compute_critical_path(dag);
97        let mut scores = HashMap::new();
98        let mut total = 0.0f32;
99
100        // Assign higher attention to nodes on critical path
101        let node_ids: Vec<usize> = (0..dag.node_count()).collect();
102        for node_id in node_ids {
103            if dag.get_node(node_id).is_none() {
104                continue;
105            }
106
107            let is_on_critical_path = critical.contains(&node_id);
108            let num_children = dag.children(node_id).len();
109
110            let mut score = if is_on_critical_path {
111                self.config.path_weight
112            } else {
113                1.0
114            };
115
116            // Apply branch penalty for nodes with many children (potential bottlenecks)
117            if num_children > 1 {
118                score *= 1.0 + (num_children as f32 - 1.0) * self.config.branch_penalty;
119            }
120
121            scores.insert(node_id, score);
122            total += score;
123        }
124
125        // Normalize to sum to 1
126        if total > 0.0 {
127            for score in scores.values_mut() {
128                *score /= total;
129            }
130        }
131
132        Ok(scores)
133    }
134
135    fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>) {
136        // Recompute critical path based on actual execution times
137        // For now, we use the static cost-based approach
138        self.critical_path = self.compute_critical_path(dag);
139
140        // Could adjust path_weight based on execution time variance
141        if !execution_times.is_empty() {
142            let max_time = execution_times.values().fold(0.0f64, |a, &b| a.max(b));
143            let avg_time: f64 =
144                execution_times.values().sum::<f64>() / execution_times.len() as f64;
145
146            if max_time > 0.0 && avg_time > 0.0 {
147                // Increase path weight if there's high variance
148                let variance_ratio = max_time / avg_time;
149                if variance_ratio > 2.0 {
150                    self.config.path_weight = (self.config.path_weight * 1.1).min(5.0);
151                }
152            }
153        }
154    }
155
156    fn name(&self) -> &'static str {
157        "critical_path"
158    }
159
160    fn complexity(&self) -> &'static str {
161        "O(n + e)"
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::OperatorNode;

    /// Builds a three-node DAG in which two nodes feed a join, then checks
    /// that the scores are normalized and that every node on the critical
    /// path outranks every node off it.
    #[test]
    fn test_critical_path_attention() {
        let mut dag = QueryDag::new();

        // Create a DAG with different costs
        let id0 =
            dag.add_node(OperatorNode::seq_scan(0, "large_table").with_estimates(10000.0, 10.0));
        let id1 =
            dag.add_node(OperatorNode::filter(0, "status = 'active'").with_estimates(1000.0, 1.0));
        let id2 = dag.add_node(OperatorNode::hash_join(0, "user_id").with_estimates(5000.0, 5.0));

        dag.add_edge(id0, id2).unwrap();
        dag.add_edge(id1, id2).unwrap();

        let attention = CriticalPathAttention::with_defaults();
        let scores = attention.forward(&dag).unwrap();

        // Check normalization
        let sum: f32 = scores.values().sum();
        assert!((sum - 1.0).abs() < 1e-5);

        let critical = attention.compute_critical_path(&dag);
        assert!(!critical.is_empty());

        // Best score among nodes that are NOT on the critical path
        // (0.0 if every node happens to be on it).
        let best_off_path = scores
            .iter()
            .filter(|&(id, _)| !critical.contains(id))
            .map(|(_, &score)| score)
            .fold(0.0f32, f32::max);

        // Nodes on critical path should have higher attention
        for &node_id in &critical {
            let score = *scores.get(&node_id).unwrap();
            assert!(score > 0.0);
            assert!(score > best_off_path);
        }
    }
}