Skip to main content

ruvector_dag/attention/
topological.rs

1//! Topological Attention: Respects DAG ordering with depth-based decay
2
3use super::{AttentionError, AttentionScores, DagAttention};
4use crate::dag::QueryDag;
5use std::collections::HashMap;
6
/// Configuration for [`TopologicalAttention`].
#[derive(Debug, Clone, PartialEq)]
pub struct TopologicalConfig {
    /// Base of the depth-based decay applied to raw node scores (default `0.9`).
    ///
    /// A node's raw score is `decay_factor` raised to a power in `[0, 1]`, so values in
    /// `(0, 1]` give nodes with smaller depth values a somewhat lower weight.
    pub decay_factor: f32,
    /// Intended upper bound on the node depth used when scoring (default `10`).
    pub max_depth: usize,
}

impl Default for TopologicalConfig {
    /// Returns `decay_factor = 0.9` and `max_depth = 10`.
    fn default() -> Self {
        Self {
            decay_factor: 0.9,
            max_depth: 10,
        }
    }
}
21
22pub struct TopologicalAttention {
23    config: TopologicalConfig,
24}
25
26impl TopologicalAttention {
27    pub fn new(config: TopologicalConfig) -> Self {
28        Self { config }
29    }
30
31    pub fn with_defaults() -> Self {
32        Self::new(TopologicalConfig::default())
33    }
34}
35
36impl DagAttention for TopologicalAttention {
37    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
38        if dag.node_count() == 0 {
39            return Err(AttentionError::EmptyDag);
40        }
41
42        let depths = dag.compute_depths();
43        let max_depth = depths.values().max().copied().unwrap_or(0);
44
45        let mut scores = HashMap::new();
46        let mut total = 0.0f32;
47
48        for (&node_id, &depth) in &depths {
49            // Higher attention for nodes closer to root (higher depth from leaves)
50            let normalized_depth = depth as f32 / (max_depth.max(1) as f32);
51            let score = self.config.decay_factor.powf(1.0 - normalized_depth);
52            scores.insert(node_id, score);
53            total += score;
54        }
55
56        // Normalize to sum to 1
57        if total > 0.0 {
58            for score in scores.values_mut() {
59                *score /= total;
60            }
61        }
62
63        Ok(scores)
64    }
65
66    fn update(&mut self, _dag: &QueryDag, _times: &HashMap<usize, f64>) {
67        // Topological attention is static, no updates needed
68    }
69
70    fn name(&self) -> &'static str {
71        "topological"
72    }
73
74    fn complexity(&self) -> &'static str {
75        "O(n)"
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::OperatorNode;

    #[test]
    fn test_topological_attention() {
        let mut dag = QueryDag::new();

        // Create a simple DAG: 0 -> 1 -> 2
        let id0 = dag.add_node(OperatorNode::seq_scan(0, "users").with_estimates(100.0, 1.0));
        let id1 = dag.add_node(OperatorNode::filter(0, "age > 18").with_estimates(50.0, 1.0));
        let id2 = dag
            .add_node(OperatorNode::project(0, vec!["name".to_string()]).with_estimates(50.0, 1.0));

        dag.add_edge(id0, id1).unwrap();
        dag.add_edge(id1, id2).unwrap();

        let attention = TopologicalAttention::with_defaults();
        let scores = attention.forward(&dag).unwrap();

        // Check that scores sum to ~1.0
        let sum: f32 = scores.values().sum();
        assert!((sum - 1.0).abs() < 1e-5);

        // All scores should be in [0, 1]
        for &score in scores.values() {
            assert!((0.0..=1.0).contains(&score));
        }
    }

    #[test]
    fn empty_dag_is_rejected() {
        let dag = QueryDag::new();
        let attention = TopologicalAttention::with_defaults();
        assert!(matches!(
            attention.forward(&dag),
            Err(AttentionError::EmptyDag)
        ));
    }

    #[test]
    fn single_node_gets_all_attention() {
        let mut dag = QueryDag::new();
        let id = dag.add_node(OperatorNode::seq_scan(0, "users").with_estimates(100.0, 1.0));

        let scores = TopologicalAttention::with_defaults().forward(&dag).unwrap();

        assert_eq!(scores.len(), 1);
        let score = scores
            .get(&id)
            .copied()
            .expect("the only node must receive a score");
        assert!((score - 1.0).abs() < 1e-6);
    }
}