ruvector_dag/attention/
temporal_btsp.rs1use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
8use crate::dag::QueryDag;
9use std::collections::HashMap;
10use std::time::Instant;
11
/// Tuning parameters for [`TemporalBTSPAttention`].
///
/// Every field is public, so callers can start from
/// `TemporalBTSPConfig::default()` and override individual values.
#[derive(Clone, Debug)]
pub struct TemporalBTSPConfig {
    /// Length of a plateau window in milliseconds, measured from the moment
    /// the plateau is triggered for a node.
    pub plateau_duration_ms: u64,
    /// Factor an existing eligibility trace is multiplied by on every update,
    /// before the new (learning-rate-scaled) signal is added.
    pub eligibility_decay: f32,
    /// Scale applied to a reward signal before it is added to a node's
    /// eligibility trace.
    pub learning_rate: f32,
    /// Softmax temperature used when normalizing the final scores; smaller
    /// positive values give a sharper distribution.
    pub temperature: f32,
    /// Raw score given to every index in `0..node_count` that no DAG node
    /// overwrites during topology scoring.
    pub baseline_attention: f32,
}

impl Default for TemporalBTSPConfig {
    /// Stock settings: 500 ms plateaus, 0.95 trace decay, 0.1 learning rate,
    /// 0.1 softmax temperature and 0.5 baseline attention.
    fn default() -> Self {
        TemporalBTSPConfig {
            baseline_attention: 0.5,
            temperature: 0.1,
            learning_rate: 0.1,
            eligibility_decay: 0.95,
            plateau_duration_ms: 500,
        }
    }
}
37
38pub struct TemporalBTSPAttention {
39 config: TemporalBTSPConfig,
40 eligibility_traces: HashMap<usize, f32>,
42 last_plateau: HashMap<usize, Instant>,
44 update_count: usize,
46}
47
48impl TemporalBTSPAttention {
49 pub fn new(config: TemporalBTSPConfig) -> Self {
50 Self {
51 config,
52 eligibility_traces: HashMap::new(),
53 last_plateau: HashMap::new(),
54 update_count: 0,
55 }
56 }
57
58 fn update_eligibility(&mut self, node_id: usize, signal: f32) {
60 let trace = self.eligibility_traces.entry(node_id).or_insert(0.0);
61 *trace = *trace * self.config.eligibility_decay + signal * self.config.learning_rate;
62
63 *trace = trace.max(0.0).min(1.0);
65 }
66
67 fn is_plateau(&self, node_id: usize) -> bool {
69 self.last_plateau
70 .get(&node_id)
71 .map(|t| t.elapsed().as_millis() < self.config.plateau_duration_ms as u128)
72 .unwrap_or(false)
73 }
74
75 fn trigger_plateau(&mut self, node_id: usize) {
77 self.last_plateau.insert(node_id, Instant::now());
78 }
79
80 fn compute_topology_attention(&self, dag: &QueryDag) -> Vec<f32> {
82 let n = dag.node_count();
83 let mut scores = vec![self.config.baseline_attention; n];
84
85 for node in dag.nodes() {
87 if node.id < n {
88 let cost_factor = (node.estimated_cost as f32 / 100.0).min(1.0);
89 let rows_factor = (node.estimated_rows as f32 / 1000.0).min(1.0);
90 scores[node.id] = 0.5 * cost_factor + 0.5 * rows_factor;
91 }
92 }
93
94 scores
95 }
96
97 fn apply_eligibility_modulation(&self, base_scores: &mut [f32]) {
99 for (node_id, &trace) in &self.eligibility_traces {
100 if *node_id < base_scores.len() {
101 base_scores[*node_id] *= 1.0 + trace;
103 }
104 }
105 }
106
107 fn apply_plateau_boost(&self, scores: &mut [f32]) {
109 for (node_id, _) in &self.last_plateau {
110 if *node_id < scores.len() && self.is_plateau(*node_id) {
111 scores[*node_id] *= 1.5;
113 }
114 }
115 }
116
117 fn normalize_scores(&self, scores: &mut [f32]) {
119 if scores.is_empty() {
120 return;
121 }
122
123 let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
124 let exp_sum: f32 = scores
125 .iter()
126 .map(|&s| ((s - max_score) / self.config.temperature).exp())
127 .sum();
128
129 if exp_sum > 0.0 {
130 for score in scores.iter_mut() {
131 *score = ((*score - max_score) / self.config.temperature).exp() / exp_sum;
132 }
133 } else {
134 let uniform = 1.0 / scores.len() as f32;
135 scores.fill(uniform);
136 }
137 }
138}
139
140impl DagAttentionMechanism for TemporalBTSPAttention {
141 fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
142 if dag.nodes.is_empty() {
143 return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
144 }
145
146 let mut scores = self.compute_topology_attention(dag);
148
149 self.apply_eligibility_modulation(&mut scores);
151
152 self.apply_plateau_boost(&mut scores);
154
155 self.normalize_scores(&mut scores);
157
158 let mut result = AttentionScores::new(scores)
160 .with_metadata("mechanism".to_string(), "temporal_btsp".to_string())
161 .with_metadata("update_count".to_string(), self.update_count.to_string());
162
163 let active_traces = self
164 .eligibility_traces
165 .values()
166 .filter(|&&t| t > 0.01)
167 .count();
168 result
169 .metadata
170 .insert("active_traces".to_string(), active_traces.to_string());
171
172 let active_plateaus = self
173 .last_plateau
174 .keys()
175 .filter(|k| self.is_plateau(**k))
176 .count();
177 result
178 .metadata
179 .insert("active_plateaus".to_string(), active_plateaus.to_string());
180
181 Ok(result)
182 }
183
184 fn name(&self) -> &'static str {
185 "temporal_btsp"
186 }
187
188 fn complexity(&self) -> &'static str {
189 "O(n + t)"
190 }
191
192 fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>) {
193 self.update_count += 1;
194
195 for (node_id, &exec_time) in execution_times {
197 let node = match dag.get_node(*node_id) {
198 Some(n) => n,
199 None => continue,
200 };
201
202 let expected_time = node.estimated_cost;
203
204 let time_ratio = exec_time / expected_time.max(0.001);
206 let reward = if time_ratio < 1.0 {
207 1.0 - time_ratio as f32
209 } else {
210 -(time_ratio as f32 - 1.0).min(1.0)
212 };
213
214 self.update_eligibility(*node_id, reward);
216
217 if reward > 0.3 {
219 self.trigger_plateau(*node_id);
220 }
221 }
222
223 let executed_nodes: std::collections::HashSet<_> = execution_times.keys().collect();
225 for node_id in 0..dag.node_count() {
226 if !executed_nodes.contains(&node_id) {
227 self.update_eligibility(node_id, 0.0);
228 }
229 }
230 }
231
232 fn reset(&mut self) {
233 self.eligibility_traces.clear();
234 self.last_plateau.clear();
235 self.update_count = 0;
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};
    use std::thread::sleep;
    use std::time::Duration;

    /// Tolerance for f32 values that went through a few multiply-adds.
    const EPS: f32 = 1e-6;

    #[test]
    fn test_eligibility_update() {
        let config = TemporalBTSPConfig::default();
        let mut attention = TemporalBTSPAttention::new(config);

        // 0.0 * 0.95 + 0.5 * 0.1 = 0.05
        attention.update_eligibility(0, 0.5);
        let first = attention.eligibility_traces[&0];
        assert!((first - 0.05).abs() < EPS, "got {}", first);

        // 0.05 * 0.95 + 0.5 * 0.1 = 0.0975
        attention.update_eligibility(0, 0.5);
        let second = attention.eligibility_traces[&0];
        assert!((second - 0.0975).abs() < EPS, "got {}", second);
    }

    #[test]
    fn test_eligibility_is_clamped() {
        let mut attention = TemporalBTSPAttention::new(TemporalBTSPConfig::default());

        // -5.0 * 0.1 = -0.5, clamped up to 0.0.
        attention.update_eligibility(0, -5.0);
        assert_eq!(attention.eligibility_traces[&0], 0.0);

        // 50.0 * 0.1 = 5.0, clamped down to 1.0.
        attention.update_eligibility(1, 50.0);
        assert_eq!(attention.eligibility_traces[&1], 1.0);
    }

    #[test]
    fn test_normalize_scores() {
        let attention = TemporalBTSPAttention::new(TemporalBTSPConfig::default());
        let mut scores = vec![0.2, 0.5, 0.1];
        attention.normalize_scores(&mut scores);
        let total: f32 = scores.iter().sum();
        assert!((total - 1.0).abs() < 1e-5, "sum was {}", total);
        assert!(scores[1] > scores[0] && scores[0] > scores[2]);

        // A zero temperature makes the sum NaN, which falls back to uniform.
        let attention = TemporalBTSPAttention::new(TemporalBTSPConfig {
            temperature: 0.0,
            ..TemporalBTSPConfig::default()
        });
        let mut scores = vec![0.2, 0.5, 0.1, 0.4];
        attention.normalize_scores(&mut scores);
        assert!(scores.iter().all(|&s| s == 0.25), "got {:?}", scores);

        let mut empty: Vec<f32> = Vec::new();
        attention.normalize_scores(&mut empty);
        assert!(empty.is_empty());
    }

    #[test]
    fn test_plateau_state() {
        let attention_config = TemporalBTSPConfig {
            plateau_duration_ms: 100,
            ..TemporalBTSPConfig::default()
        };
        let mut attention = TemporalBTSPAttention::new(attention_config);

        attention.trigger_plateau(0);
        assert!(attention.is_plateau(0));
        assert!(!attention.is_plateau(1));

        sleep(Duration::from_millis(150));
        assert!(!attention.is_plateau(0));
    }

    #[test]
    fn test_temporal_attention() {
        let config = TemporalBTSPConfig::default();
        let mut attention = TemporalBTSPAttention::new(config);

        let mut dag = QueryDag::new();
        for i in 0..3 {
            let mut node = OperatorNode::new(i, OperatorType::Scan);
            node.estimated_cost = 10.0;
            dag.add_node(node);
        }

        let result1 = attention.forward(&dag).unwrap();
        assert_eq!(result1.scores.len(), 3);

        // Node 0 ran in half its estimate (reward 0.5); node 1 overran by 50%
        // (reward -0.5); node 2 did not run at all.
        let mut exec_times = HashMap::new();
        exec_times.insert(0, 5.0);
        exec_times.insert(1, 15.0);
        attention.update(&dag, &exec_times);

        let result2 = attention.forward(&dag).unwrap();
        assert_eq!(result2.scores.len(), 3);

        assert!(attention.eligibility_traces[&0] > 0.0);
        // A negative reward is clamped to zero; an unexecuted node only decays.
        assert_eq!(attention.eligibility_traces[&1], 0.0);
        assert_eq!(attention.eligibility_traces[&2], 0.0);
        // Reward 0.5 > 0.3 starts a plateau for node 0 only.
        assert!(attention.is_plateau(0));
        assert!(!attention.is_plateau(1));
        assert_eq!(attention.update_count, 1);

        // Node 0 gets both the eligibility and plateau boosts over node 1.
        assert!(result2.scores[0] > result2.scores[1]);
    }
}