Skip to main content

ruvector_dag/attention/
selector.rs

1//! Attention Selector: UCB Bandit for mechanism selection
2//!
3//! Implements Upper Confidence Bound (UCB1) algorithm to dynamically select
4//! the best attention mechanism based on observed performance.
5
6use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
7use crate::dag::QueryDag;
8use std::collections::HashMap;
9
/// Tunable parameters for [`AttentionSelector`]'s UCB1 bandit.
#[derive(Debug, Clone)]
pub struct SelectorConfig {
    /// Multiplier applied to the UCB exploration bonus; `sqrt(2)` is the
    /// customary choice and the default.
    pub exploration_factor: f32,
    /// Optimistic starting value: every mechanism's cumulative reward begins
    /// at (and is reset to) this number.
    pub initial_value: f32,
    /// How many recorded updates a mechanism needs before it leaves the
    /// forced-exploration phase and before `best_mechanism` will consider it.
    pub min_samples: usize,
}
19
20impl Default for SelectorConfig {
21    fn default() -> Self {
22        Self {
23            exploration_factor: (2.0_f32).sqrt(),
24            initial_value: 1.0,
25            min_samples: 5,
26        }
27    }
28}
29
30pub struct AttentionSelector {
31    config: SelectorConfig,
32    mechanisms: Vec<Box<dyn DagAttentionMechanism>>,
33    /// Cumulative rewards for each mechanism
34    rewards: Vec<f32>,
35    /// Number of times each mechanism was selected
36    counts: Vec<usize>,
37    /// Total number of selections
38    total_count: usize,
39}
40
41impl AttentionSelector {
42    pub fn new(mechanisms: Vec<Box<dyn DagAttentionMechanism>>, config: SelectorConfig) -> Self {
43        let n = mechanisms.len();
44        let initial_value = config.initial_value;
45        Self {
46            config,
47            mechanisms,
48            rewards: vec![initial_value; n],
49            counts: vec![0; n],
50            total_count: 0,
51        }
52    }
53
54    /// Select mechanism using UCB1 algorithm
55    pub fn select(&self) -> usize {
56        if self.mechanisms.is_empty() {
57            return 0;
58        }
59
60        // If any mechanism hasn't been tried min_samples times, try it
61        for (i, &count) in self.counts.iter().enumerate() {
62            if count < self.config.min_samples {
63                return i;
64            }
65        }
66
67        // UCB1 selection: exploitation + exploration
68        let ln_total = (self.total_count as f32).ln().max(1.0);
69
70        let ucb_values: Vec<f32> = self
71            .mechanisms
72            .iter()
73            .enumerate()
74            .map(|(i, _)| {
75                let count = self.counts[i] as f32;
76                if count == 0.0 {
77                    return f32::INFINITY;
78                }
79
80                let exploitation = self.rewards[i] / count;
81                let exploration = self.config.exploration_factor * (ln_total / count).sqrt();
82
83                exploitation + exploration
84            })
85            .collect();
86
87        ucb_values
88            .iter()
89            .enumerate()
90            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
91            .map(|(i, _)| i)
92            .unwrap_or(0)
93    }
94
95    /// Update rewards after execution
96    pub fn update(&mut self, mechanism_idx: usize, reward: f32) {
97        if mechanism_idx < self.rewards.len() {
98            self.rewards[mechanism_idx] += reward;
99            self.counts[mechanism_idx] += 1;
100            self.total_count += 1;
101        }
102    }
103
104    /// Get the selected mechanism
105    pub fn get_mechanism(&self, idx: usize) -> Option<&dyn DagAttentionMechanism> {
106        self.mechanisms.get(idx).map(|m| m.as_ref())
107    }
108
109    /// Get mutable reference to mechanism for updates
110    pub fn get_mechanism_mut(&mut self, idx: usize) -> Option<&mut Box<dyn DagAttentionMechanism>> {
111        self.mechanisms.get_mut(idx)
112    }
113
114    /// Get statistics for all mechanisms
115    pub fn stats(&self) -> HashMap<&'static str, MechanismStats> {
116        self.mechanisms
117            .iter()
118            .enumerate()
119            .map(|(i, m)| {
120                let stats = MechanismStats {
121                    total_reward: self.rewards[i],
122                    count: self.counts[i],
123                    avg_reward: if self.counts[i] > 0 {
124                        self.rewards[i] / self.counts[i] as f32
125                    } else {
126                        0.0
127                    },
128                };
129                (m.name(), stats)
130            })
131            .collect()
132    }
133
134    /// Get the best performing mechanism based on average reward
135    pub fn best_mechanism(&self) -> Option<usize> {
136        self.mechanisms
137            .iter()
138            .enumerate()
139            .filter(|(i, _)| self.counts[*i] >= self.config.min_samples)
140            .max_by(|(i, _), (j, _)| {
141                let avg_i = self.rewards[*i] / self.counts[*i] as f32;
142                let avg_j = self.rewards[*j] / self.counts[*j] as f32;
143                avg_i
144                    .partial_cmp(&avg_j)
145                    .unwrap_or(std::cmp::Ordering::Equal)
146            })
147            .map(|(i, _)| i)
148    }
149
150    /// Reset all statistics
151    pub fn reset(&mut self) {
152        for i in 0..self.rewards.len() {
153            self.rewards[i] = self.config.initial_value;
154            self.counts[i] = 0;
155        }
156        self.total_count = 0;
157    }
158
159    /// Forward pass using selected mechanism
160    pub fn forward(&mut self, dag: &QueryDag) -> Result<(AttentionScores, usize), AttentionError> {
161        let selected = self.select();
162        let mechanism = self
163            .get_mechanism(selected)
164            .ok_or_else(|| AttentionError::ConfigError("No mechanisms available".to_string()))?;
165
166        let scores = mechanism.forward(dag)?;
167        Ok((scores, selected))
168    }
169}
170
/// Point-in-time statistics for one mechanism, as produced by
/// `AttentionSelector::stats`.
#[derive(Debug, Clone)]
pub struct MechanismStats {
    /// Cumulative reward held for the mechanism (includes the selector's
    /// initial optimistic value).
    pub total_reward: f32,
    /// Number of rewards recorded for the mechanism.
    pub count: usize,
    /// `total_reward / count`, or `0.0` when `count` is zero.
    pub avg_reward: f32,
}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType, QueryDag};

    /// Test double that scores every node of the DAG with one fixed value.
    struct MockMechanism {
        name: &'static str,
        score_value: f32,
    }

    impl DagAttentionMechanism for MockMechanism {
        fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
            let node_count = dag.nodes.len();
            Ok(AttentionScores::new(vec![self.score_value; node_count]))
        }

        fn name(&self) -> &'static str {
            self.name
        }

        fn complexity(&self) -> &'static str {
            "O(1)"
        }
    }

    /// Boxes a `MockMechanism` so the tests can build mechanism lists tersely.
    fn mock(name: &'static str, score_value: f32) -> Box<dyn DagAttentionMechanism> {
        Box::new(MockMechanism { name, score_value })
    }

    #[test]
    fn test_ucb_selection() {
        let mut selector = AttentionSelector::new(
            vec![mock("mech1", 0.5), mock("mech2", 0.7), mock("mech3", 0.3)],
            SelectorConfig::default(),
        );

        // Three mechanisms with the default `min_samples` of 5: fifteen
        // rounds must touch every mechanism at least once.
        for _ in 0..15 {
            let chosen = selector.select();
            selector.update(chosen, 0.5);
        }

        assert!(selector.total_count > 0);
        assert!(selector.counts.iter().all(|&c| c > 0));
    }

    #[test]
    fn test_best_mechanism() {
        let config = SelectorConfig {
            min_samples: 2,
            ..Default::default()
        };
        let mut selector = AttentionSelector::new(vec![mock("poor", 0.3), mock("good", 0.8)], config);

        // Feed each mechanism two rewards; the second one earns more.
        for &(idx, reward) in &[(0, 0.3), (0, 0.4), (1, 0.8), (1, 0.9)] {
            selector.update(idx, reward);
        }

        let best = selector.best_mechanism().unwrap();
        assert_eq!(best, 1);
    }

    #[test]
    fn test_selector_forward() {
        let mut selector =
            AttentionSelector::new(vec![mock("test", 0.5)], SelectorConfig::default());

        // Single-node DAG: expect exactly one score from the only mechanism.
        let mut dag = QueryDag::new();
        dag.add_node(OperatorNode::new(0, OperatorType::Scan));

        let (scores, idx) = selector.forward(&dag).unwrap();
        assert_eq!(scores.scores.len(), 1);
        assert_eq!(idx, 0);
    }

    #[test]
    fn test_stats() {
        // A zero initial value lets the totals reflect the updates alone.
        let config = SelectorConfig {
            initial_value: 0.0,
            ..Default::default()
        };
        let mut selector = AttentionSelector::new(vec![mock("mech1", 0.5)], config);
        selector.update(0, 1.0);
        selector.update(0, 2.0);

        let stats = selector.stats();
        let mech1_stats = stats.get("mech1").unwrap();

        assert_eq!(mech1_stats.count, 2);
        assert_eq!(mech1_stats.total_reward, 3.0);
        assert_eq!(mech1_stats.avg_reward, 1.5);
    }
}