Skip to main content

ruvector_dag/attention/
causal_cone.rs

1//! Causal Cone Attention: Focuses on ancestors with temporal discount
2
3use super::{AttentionError, AttentionScores, DagAttention};
4use crate::dag::QueryDag;
5use std::collections::HashMap;
6
/// Tunable parameters for [`CausalConeAttention`].
///
/// `PartialEq` is derived so configurations can be compared directly
/// (for example in tests or when detecting configuration changes).
#[derive(Debug, Clone, PartialEq)]
pub struct CausalConeConfig {
    /// Time window in milliseconds.
    ///
    /// NOTE(review): this field is not read by the scoring code in this file;
    /// confirm whether another module consumes it.
    pub time_window_ms: u64,
    /// Multiplicative discount applied once per level of node depth:
    /// a node's score is scaled by `future_discount.powi(depth)`.
    pub future_discount: f32,
    /// Weight added to a node's base score for each of its ancestors:
    /// the base score is `1.0 + ancestor_count * ancestor_weight`.
    pub ancestor_weight: f32,
}
13
14impl Default for CausalConeConfig {
15    fn default() -> Self {
16        Self {
17            time_window_ms: 1000,
18            future_discount: 0.8,
19            ancestor_weight: 0.9,
20        }
21    }
22}
23
24pub struct CausalConeAttention {
25    config: CausalConeConfig,
26}
27
28impl CausalConeAttention {
29    pub fn new(config: CausalConeConfig) -> Self {
30        Self { config }
31    }
32
33    pub fn with_defaults() -> Self {
34        Self::new(CausalConeConfig::default())
35    }
36}
37
38impl DagAttention for CausalConeAttention {
39    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
40        if dag.node_count() == 0 {
41            return Err(AttentionError::EmptyDag);
42        }
43
44        let mut scores = HashMap::new();
45        let mut total = 0.0f32;
46
47        // For each node, compute attention based on:
48        // 1. Number of ancestors (causal influence)
49        // 2. Distance from node (temporal decay)
50        let node_ids: Vec<usize> = (0..dag.node_count()).collect();
51        for node_id in node_ids {
52            if dag.get_node(node_id).is_none() {
53                continue;
54            }
55
56            let ancestors = dag.ancestors(node_id);
57            let ancestor_count = ancestors.len();
58
59            // Base score is proportional to causal influence (number of ancestors)
60            let mut score = 1.0 + (ancestor_count as f32 * self.config.ancestor_weight);
61
62            // Apply temporal discount based on depth
63            let depths = dag.compute_depths();
64            if let Some(&depth) = depths.get(&node_id) {
65                score *= self.config.future_discount.powi(depth as i32);
66            }
67
68            scores.insert(node_id, score);
69            total += score;
70        }
71
72        // Normalize to sum to 1
73        if total > 0.0 {
74            for score in scores.values_mut() {
75                *score /= total;
76            }
77        }
78
79        Ok(scores)
80    }
81
82    fn update(&mut self, _dag: &QueryDag, _times: &HashMap<usize, f64>) {
83        // Could update temporal discount based on actual execution times
84        // For now, static configuration
85    }
86
87    fn name(&self) -> &'static str {
88        "causal_cone"
89    }
90
91    fn complexity(&self) -> &'static str {
92        "O(n^2)"
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::OperatorNode;

    /// Scores produced for a small join-shaped DAG must form a probability
    /// distribution: they sum to 1 and each lies within `[0, 1]`.
    #[test]
    fn test_causal_cone_attention() {
        let mut dag = QueryDag::new();

        // Create a DAG with multiple paths: two scans feed a join, which
        // feeds a projection.
        let id0 = dag.add_node(OperatorNode::seq_scan(0, "table1"));
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "table2"));
        let id2 = dag.add_node(OperatorNode::hash_join(0, "id"));
        let id3 = dag.add_node(OperatorNode::project(0, vec!["name".to_string()]));

        dag.add_edge(id0, id2).unwrap();
        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();

        let attention = CausalConeAttention::with_defaults();
        let scores = attention.forward(&dag).unwrap();

        // Check normalization
        let sum: f32 = scores.values().sum();
        assert!((sum - 1.0).abs() < 1e-5);

        // All scores should be in [0, 1]
        for &score in scores.values() {
            assert!((0.0..=1.0).contains(&score));
        }
    }
}