ruvector_dag/attention/
topological.rs1use super::{AttentionError, AttentionScores, DagAttention};
4use crate::dag::QueryDag;
5use std::collections::HashMap;
6
/// Configuration for [`TopologicalAttention`].
#[derive(Debug, Clone)]
pub struct TopologicalConfig {
    /// Base of the exponential depth weighting: `TopologicalAttention::forward`
    /// raises this value to `1 - normalized_depth` for every node.
    pub decay_factor: f32,
    /// Depth limit.
    ///
    /// NOTE(review): not read anywhere in this file; confirm whether
    /// `forward` is meant to clamp node depths to this value.
    pub max_depth: usize,
}
12
13impl Default for TopologicalConfig {
14 fn default() -> Self {
15 Self {
16 decay_factor: 0.9,
17 max_depth: 10,
18 }
19 }
20}
21
22pub struct TopologicalAttention {
23 config: TopologicalConfig,
24}
25
26impl TopologicalAttention {
27 pub fn new(config: TopologicalConfig) -> Self {
28 Self { config }
29 }
30
31 pub fn with_defaults() -> Self {
32 Self::new(TopologicalConfig::default())
33 }
34}
35
36impl DagAttention for TopologicalAttention {
37 fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
38 if dag.node_count() == 0 {
39 return Err(AttentionError::EmptyDag);
40 }
41
42 let depths = dag.compute_depths();
43 let max_depth = depths.values().max().copied().unwrap_or(0);
44
45 let mut scores = HashMap::new();
46 let mut total = 0.0f32;
47
48 for (&node_id, &depth) in &depths {
49 let normalized_depth = depth as f32 / (max_depth.max(1) as f32);
51 let score = self.config.decay_factor.powf(1.0 - normalized_depth);
52 scores.insert(node_id, score);
53 total += score;
54 }
55
56 if total > 0.0 {
58 for score in scores.values_mut() {
59 *score /= total;
60 }
61 }
62
63 Ok(scores)
64 }
65
66 fn update(&mut self, _dag: &QueryDag, _times: &HashMap<usize, f64>) {
67 }
69
70 fn name(&self) -> &'static str {
71 "topological"
72 }
73
74 fn complexity(&self) -> &'static str {
75 "O(n)"
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::OperatorNode;

    /// A three-node chain should yield a normalized distribution over all nodes.
    #[test]
    fn test_topological_attention() {
        let mut dag = QueryDag::new();

        let scan = dag.add_node(OperatorNode::seq_scan(0, "users").with_estimates(100.0, 1.0));
        let filter = dag.add_node(OperatorNode::filter(0, "age > 18").with_estimates(50.0, 1.0));
        let project = dag.add_node(
            OperatorNode::project(0, vec!["name".to_string()]).with_estimates(50.0, 1.0),
        );

        dag.add_edge(scan, filter).unwrap();
        dag.add_edge(filter, project).unwrap();

        let scores = TopologicalAttention::with_defaults().forward(&dag).unwrap();

        assert_eq!(scores.len(), 3);
        let sum: f32 = scores.values().sum();
        assert!((sum - 1.0).abs() < 1e-5);

        for &score in scores.values() {
            assert!((0.0..=1.0).contains(&score));
        }
    }

    /// `forward` must reject a DAG with no nodes.
    #[test]
    fn test_empty_dag_is_rejected() {
        let dag = QueryDag::new();
        let result = TopologicalAttention::with_defaults().forward(&dag);
        assert!(matches!(result, Err(AttentionError::EmptyDag)));
    }

    /// A lone node receives all of the attention.
    #[test]
    fn test_single_node_gets_all_attention() {
        let mut dag = QueryDag::new();
        let _scan = dag.add_node(OperatorNode::seq_scan(0, "users").with_estimates(100.0, 1.0));

        let scores = TopologicalAttention::with_defaults().forward(&dag).unwrap();

        assert_eq!(scores.len(), 1);
        let only = scores.values().copied().next().unwrap();
        assert!((only - 1.0).abs() < 1e-5);
    }
}