Skip to main content

ruvector_dag/attention/
temporal_btsp.rs

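//! Temporal BTSP Attention: Behavioral Timescale Synaptic Plasticity
//!
//! A biologically inspired attention mechanism built on eligibility traces and
//! plateau potentials. Execution feedback reinforces per-node traces, and nodes
//! that run notably faster than their estimated cost enter a time-limited
//! plateau state that boosts their attention, letting the system learn from
//! temporal patterns in query execution.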
6
7use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
8use crate::dag::QueryDag;
9use std::collections::HashMap;
10use std::time::Instant;
11
/// Tuning parameters for the temporal BTSP attention mechanism.
#[derive(Debug, Clone, PartialEq)]
pub struct TemporalBTSPConfig {
    /// How long, in milliseconds, a node stays in the plateau (boosted) state
    /// after a plateau is triggered.
    pub plateau_duration_ms: u64,
    /// Multiplicative decay applied to an eligibility trace each time the
    /// trace is updated (0.0 to 1.0).
    pub eligibility_decay: f32,
    /// Scale applied to a feedback signal before it is added to a trace.
    pub learning_rate: f32,
    /// Softmax temperature used to normalize the final scores; lower values
    /// produce a sharper distribution.
    pub temperature: f32,
    /// Initial score for node slots that the cost/row heuristic does not
    /// overwrite, i.e. indices below the node count with no matching node id.
    pub baseline_attention: f32,
}
25
26impl Default for TemporalBTSPConfig {
27    fn default() -> Self {
28        Self {
29            plateau_duration_ms: 500,
30            eligibility_decay: 0.95,
31            learning_rate: 0.1,
32            temperature: 0.1,
33            baseline_attention: 0.5,
34        }
35    }
36}
37
38pub struct TemporalBTSPAttention {
39    config: TemporalBTSPConfig,
40    /// Eligibility traces for each node
41    eligibility_traces: HashMap<usize, f32>,
42    /// Timestamp of last plateau for each node
43    last_plateau: HashMap<usize, Instant>,
44    /// Total updates counter
45    update_count: usize,
46}
47
48impl TemporalBTSPAttention {
49    pub fn new(config: TemporalBTSPConfig) -> Self {
50        Self {
51            config,
52            eligibility_traces: HashMap::new(),
53            last_plateau: HashMap::new(),
54            update_count: 0,
55        }
56    }
57
58    /// Update eligibility trace for a node
59    fn update_eligibility(&mut self, node_id: usize, signal: f32) {
60        let trace = self.eligibility_traces.entry(node_id).or_insert(0.0);
61        *trace = *trace * self.config.eligibility_decay + signal * self.config.learning_rate;
62
63        // Clamp to [0, 1]
64        *trace = trace.max(0.0).min(1.0);
65    }
66
67    /// Check if node is in plateau state
68    fn is_plateau(&self, node_id: usize) -> bool {
69        self.last_plateau
70            .get(&node_id)
71            .map(|t| t.elapsed().as_millis() < self.config.plateau_duration_ms as u128)
72            .unwrap_or(false)
73    }
74
75    /// Trigger plateau for a node
76    fn trigger_plateau(&mut self, node_id: usize) {
77        self.last_plateau.insert(node_id, Instant::now());
78    }
79
80    /// Compute base attention from topology
81    fn compute_topology_attention(&self, dag: &QueryDag) -> Vec<f32> {
82        let n = dag.node_count();
83        let mut scores = vec![self.config.baseline_attention; n];
84
85        // Simple heuristic: nodes with higher cost get more attention
86        for node in dag.nodes() {
87            if node.id < n {
88                let cost_factor = (node.estimated_cost as f32 / 100.0).min(1.0);
89                let rows_factor = (node.estimated_rows as f32 / 1000.0).min(1.0);
90                scores[node.id] = 0.5 * cost_factor + 0.5 * rows_factor;
91            }
92        }
93
94        scores
95    }
96
97    /// Apply eligibility trace modulation
98    fn apply_eligibility_modulation(&self, base_scores: &mut [f32]) {
99        for (node_id, &trace) in &self.eligibility_traces {
100            if *node_id < base_scores.len() {
101                // Boost attention based on eligibility trace
102                base_scores[*node_id] *= 1.0 + trace;
103            }
104        }
105    }
106
107    /// Apply plateau boosting
108    fn apply_plateau_boost(&self, scores: &mut [f32]) {
109        for (node_id, _) in &self.last_plateau {
110            if *node_id < scores.len() && self.is_plateau(*node_id) {
111                // Strong boost for nodes in plateau state
112                scores[*node_id] *= 1.5;
113            }
114        }
115    }
116
117    /// Normalize scores using softmax
118    fn normalize_scores(&self, scores: &mut [f32]) {
119        if scores.is_empty() {
120            return;
121        }
122
123        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
124        let exp_sum: f32 = scores
125            .iter()
126            .map(|&s| ((s - max_score) / self.config.temperature).exp())
127            .sum();
128
129        if exp_sum > 0.0 {
130            for score in scores.iter_mut() {
131                *score = ((*score - max_score) / self.config.temperature).exp() / exp_sum;
132            }
133        } else {
134            let uniform = 1.0 / scores.len() as f32;
135            scores.fill(uniform);
136        }
137    }
138}
139
140impl DagAttentionMechanism for TemporalBTSPAttention {
141    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
142        if dag.nodes.is_empty() {
143            return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
144        }
145
146        // Step 1: Compute base attention from topology
147        let mut scores = self.compute_topology_attention(dag);
148
149        // Step 2: Modulate by eligibility traces
150        self.apply_eligibility_modulation(&mut scores);
151
152        // Step 3: Apply plateau boosting for recently active nodes
153        self.apply_plateau_boost(&mut scores);
154
155        // Step 4: Normalize
156        self.normalize_scores(&mut scores);
157
158        // Build result with metadata
159        let mut result = AttentionScores::new(scores)
160            .with_metadata("mechanism".to_string(), "temporal_btsp".to_string())
161            .with_metadata("update_count".to_string(), self.update_count.to_string());
162
163        let active_traces = self
164            .eligibility_traces
165            .values()
166            .filter(|&&t| t > 0.01)
167            .count();
168        result
169            .metadata
170            .insert("active_traces".to_string(), active_traces.to_string());
171
172        let active_plateaus = self
173            .last_plateau
174            .keys()
175            .filter(|k| self.is_plateau(**k))
176            .count();
177        result
178            .metadata
179            .insert("active_plateaus".to_string(), active_plateaus.to_string());
180
181        Ok(result)
182    }
183
184    fn name(&self) -> &'static str {
185        "temporal_btsp"
186    }
187
188    fn complexity(&self) -> &'static str {
189        "O(n + t)"
190    }
191
192    fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>) {
193        self.update_count += 1;
194
195        // Update eligibility traces based on execution feedback
196        for (node_id, &exec_time) in execution_times {
197            let node = match dag.get_node(*node_id) {
198                Some(n) => n,
199                None => continue,
200            };
201
202            let expected_time = node.estimated_cost;
203
204            // Compute reward signal: positive if faster than expected, negative if slower
205            let time_ratio = exec_time / expected_time.max(0.001);
206            let reward = if time_ratio < 1.0 {
207                // Faster than expected - positive signal
208                1.0 - time_ratio as f32
209            } else {
210                // Slower than expected - negative signal
211                -(time_ratio as f32 - 1.0).min(1.0)
212            };
213
214            // Update eligibility trace
215            self.update_eligibility(*node_id, reward);
216
217            // Trigger plateau for nodes that significantly exceeded expectations
218            if reward > 0.3 {
219                self.trigger_plateau(*node_id);
220            }
221        }
222
223        // Decay traces for nodes that weren't executed
224        let executed_nodes: std::collections::HashSet<_> = execution_times.keys().collect();
225        for node_id in 0..dag.node_count() {
226            if !executed_nodes.contains(&node_id) {
227                self.update_eligibility(node_id, 0.0);
228            }
229        }
230    }
231
232    fn reset(&mut self) {
233        self.eligibility_traces.clear();
234        self.last_plateau.clear();
235        self.update_count = 0;
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};
    use std::thread::sleep;
    use std::time::Duration;

    #[test]
    fn test_eligibility_update() {
        let config = TemporalBTSPConfig::default();
        let mut attention = TemporalBTSPAttention::new(config);

        attention.update_eligibility(0, 0.5);
        let first = attention.eligibility_traces[&0];
        assert!(first > 0.0);

        // A repeated positive signal accumulates on top of the decayed trace.
        attention.update_eligibility(0, 0.5);
        let second = attention.eligibility_traces[&0];
        assert!(second > first);
        assert!(second <= 1.0);

        // Negative signals can pull a trace down, but never below zero.
        attention.update_eligibility(0, -100.0);
        assert_eq!(attention.eligibility_traces[&0], 0.0);
    }

    #[test]
    fn test_plateau_state() {
        let mut config = TemporalBTSPConfig::default();
        config.plateau_duration_ms = 100;
        let mut attention = TemporalBTSPAttention::new(config);

        attention.trigger_plateau(0);
        assert!(attention.is_plateau(0));

        sleep(Duration::from_millis(150));
        assert!(!attention.is_plateau(0));
    }

    #[test]
    fn test_temporal_attention() {
        let config = TemporalBTSPConfig::default();
        let mut attention = TemporalBTSPAttention::new(config);

        let mut dag = QueryDag::new();
        for i in 0..3 {
            let mut node = OperatorNode::new(i, OperatorType::Scan);
            node.estimated_cost = 10.0;
            dag.add_node(node);
        }

        // Initial forward pass
        let result1 = attention.forward(&dag).unwrap();
        assert_eq!(result1.scores.len(), 3);

        // Simulate execution feedback
        let mut exec_times = HashMap::new();
        exec_times.insert(0, 5.0); // Faster than expected
        exec_times.insert(1, 15.0); // Slower than expected

        attention.update(&dag, &exec_times);
        assert_eq!(attention.update_count, 1);

        let result2 = attention.forward(&dag).unwrap();
        assert_eq!(result2.scores.len(), 3);

        // Node 0 ran in half its estimated cost: its trace grows and it enters
        // a plateau, both of which boost its attention.
        assert!(attention.eligibility_traces[&0] > 0.0);
        assert!(attention.is_plateau(0));

        // Node 1 ran slower than estimated: its trace is clamped at zero and
        // no plateau is triggered.
        assert_eq!(attention.eligibility_traces.get(&1).copied(), Some(0.0));
        assert!(!attention.is_plateau(1));
    }

    #[test]
    fn test_forward_rejects_empty_dag() {
        let attention = TemporalBTSPAttention::new(TemporalBTSPConfig::default());
        assert!(attention.forward(&QueryDag::new()).is_err());
    }

    #[test]
    fn test_reset_clears_state() {
        let mut attention = TemporalBTSPAttention::new(TemporalBTSPConfig::default());
        attention.update_eligibility(0, 1.0);
        attention.trigger_plateau(0);
        attention.update_count = 5;

        attention.reset();

        assert!(attention.eligibility_traces.is_empty());
        assert!(attention.last_plateau.is_empty());
        assert_eq!(attention.update_count, 0);
    }
}