1use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
7use crate::dag::QueryDag;
8use std::collections::HashMap;
9
/// Tuning parameters for [`AttentionSelector`]'s UCB-based choice between mechanisms.
#[derive(Debug, Clone, PartialEq)]
pub struct SelectorConfig {
    /// Multiplier `c` on the exploration bonus `c * sqrt(ln(total_updates) / count)`.
    pub exploration_factor: f32,
    /// Value every mechanism's accumulated reward starts from (and is reset to).
    pub initial_value: f32,
    /// Number of updates each mechanism must receive before UCB scoring is used
    /// and before it is eligible as the best mechanism.
    pub min_samples: usize,
}
19
20impl Default for SelectorConfig {
21 fn default() -> Self {
22 Self {
23 exploration_factor: (2.0_f32).sqrt(),
24 initial_value: 1.0,
25 min_samples: 5,
26 }
27 }
28}
29
30pub struct AttentionSelector {
31 config: SelectorConfig,
32 mechanisms: Vec<Box<dyn DagAttentionMechanism>>,
33 rewards: Vec<f32>,
35 counts: Vec<usize>,
37 total_count: usize,
39}
40
41impl AttentionSelector {
42 pub fn new(mechanisms: Vec<Box<dyn DagAttentionMechanism>>, config: SelectorConfig) -> Self {
43 let n = mechanisms.len();
44 let initial_value = config.initial_value;
45 Self {
46 config,
47 mechanisms,
48 rewards: vec![initial_value; n],
49 counts: vec![0; n],
50 total_count: 0,
51 }
52 }
53
54 pub fn select(&self) -> usize {
56 if self.mechanisms.is_empty() {
57 return 0;
58 }
59
60 for (i, &count) in self.counts.iter().enumerate() {
62 if count < self.config.min_samples {
63 return i;
64 }
65 }
66
67 let ln_total = (self.total_count as f32).ln().max(1.0);
69
70 let ucb_values: Vec<f32> = self
71 .mechanisms
72 .iter()
73 .enumerate()
74 .map(|(i, _)| {
75 let count = self.counts[i] as f32;
76 if count == 0.0 {
77 return f32::INFINITY;
78 }
79
80 let exploitation = self.rewards[i] / count;
81 let exploration = self.config.exploration_factor * (ln_total / count).sqrt();
82
83 exploitation + exploration
84 })
85 .collect();
86
87 ucb_values
88 .iter()
89 .enumerate()
90 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
91 .map(|(i, _)| i)
92 .unwrap_or(0)
93 }
94
95 pub fn update(&mut self, mechanism_idx: usize, reward: f32) {
97 if mechanism_idx < self.rewards.len() {
98 self.rewards[mechanism_idx] += reward;
99 self.counts[mechanism_idx] += 1;
100 self.total_count += 1;
101 }
102 }
103
104 pub fn get_mechanism(&self, idx: usize) -> Option<&dyn DagAttentionMechanism> {
106 self.mechanisms.get(idx).map(|m| m.as_ref())
107 }
108
109 pub fn get_mechanism_mut(&mut self, idx: usize) -> Option<&mut Box<dyn DagAttentionMechanism>> {
111 self.mechanisms.get_mut(idx)
112 }
113
114 pub fn stats(&self) -> HashMap<&'static str, MechanismStats> {
116 self.mechanisms
117 .iter()
118 .enumerate()
119 .map(|(i, m)| {
120 let stats = MechanismStats {
121 total_reward: self.rewards[i],
122 count: self.counts[i],
123 avg_reward: if self.counts[i] > 0 {
124 self.rewards[i] / self.counts[i] as f32
125 } else {
126 0.0
127 },
128 };
129 (m.name(), stats)
130 })
131 .collect()
132 }
133
134 pub fn best_mechanism(&self) -> Option<usize> {
136 self.mechanisms
137 .iter()
138 .enumerate()
139 .filter(|(i, _)| self.counts[*i] >= self.config.min_samples)
140 .max_by(|(i, _), (j, _)| {
141 let avg_i = self.rewards[*i] / self.counts[*i] as f32;
142 let avg_j = self.rewards[*j] / self.counts[*j] as f32;
143 avg_i
144 .partial_cmp(&avg_j)
145 .unwrap_or(std::cmp::Ordering::Equal)
146 })
147 .map(|(i, _)| i)
148 }
149
150 pub fn reset(&mut self) {
152 for i in 0..self.rewards.len() {
153 self.rewards[i] = self.config.initial_value;
154 self.counts[i] = 0;
155 }
156 self.total_count = 0;
157 }
158
159 pub fn forward(&mut self, dag: &QueryDag) -> Result<(AttentionScores, usize), AttentionError> {
161 let selected = self.select();
162 let mechanism = self
163 .get_mechanism(selected)
164 .ok_or_else(|| AttentionError::ConfigError("No mechanisms available".to_string()))?;
165
166 let scores = mechanism.forward(dag)?;
167 Ok((scores, selected))
168 }
169}
170
/// Snapshot of one mechanism's selection statistics, as returned by
/// `AttentionSelector::stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct MechanismStats {
    /// Accumulated reward, including the selector's initial seed value.
    pub total_reward: f32,
    /// Number of rewards recorded for the mechanism.
    pub count: usize,
    /// `total_reward / count`, or `0.0` when `count` is zero.
    pub avg_reward: f32,
}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType, QueryDag};

    /// Mechanism stub that assigns the same score to every node of the DAG.
    struct MockMechanism {
        name: &'static str,
        score_value: f32,
    }

    impl DagAttentionMechanism for MockMechanism {
        fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
            let scores = vec![self.score_value; dag.nodes.len()];
            Ok(AttentionScores::new(scores))
        }

        fn name(&self) -> &'static str {
            self.name
        }

        fn complexity(&self) -> &'static str {
            "O(1)"
        }
    }

    /// Boxes a `MockMechanism` so it can be passed to `AttentionSelector::new`.
    fn mock(name: &'static str, score_value: f32) -> Box<dyn DagAttentionMechanism> {
        Box::new(MockMechanism { name, score_value })
    }

    /// Builds a DAG holding a single scan node.
    fn single_node_dag() -> QueryDag {
        let mut dag = QueryDag::new();
        let node = OperatorNode::new(0, OperatorType::Scan);
        dag.add_node(node);
        dag
    }

    #[test]
    fn test_ucb_selection() {
        let mechanisms = vec![mock("mech1", 0.5), mock("mech2", 0.7), mock("mech3", 0.3)];
        let mut selector = AttentionSelector::new(mechanisms, SelectorConfig::default());

        let mut picks = Vec::new();
        for _ in 0..15 {
            let selected = selector.select();
            picks.push(selected);
            selector.update(selected, 0.5);
        }

        // Warm-up picks the lowest under-sampled index until it reaches
        // `min_samples` (5 by default), so the order is fully deterministic.
        assert_eq!(picks, vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]);
        assert_eq!(selector.total_count, 15);
        assert_eq!(selector.counts, vec![5, 5, 5]);
    }

    #[test]
    fn test_best_mechanism() {
        let mechanisms = vec![mock("poor", 0.3), mock("good", 0.8)];
        let mut selector = AttentionSelector::new(
            mechanisms,
            SelectorConfig {
                min_samples: 2,
                ..Default::default()
            },
        );

        selector.update(0, 0.3);
        selector.update(0, 0.4);
        selector.update(1, 0.8);
        selector.update(1, 0.9);

        let best = selector.best_mechanism().unwrap();
        assert_eq!(best, 1);
    }

    #[test]
    fn test_selector_forward() {
        let mut selector =
            AttentionSelector::new(vec![mock("test", 0.5)], SelectorConfig::default());

        let dag = single_node_dag();
        let (scores, idx) = selector.forward(&dag).unwrap();
        assert_eq!(scores.scores.len(), 1);
        assert_eq!(idx, 0);
    }

    #[test]
    fn test_empty_selector() {
        let mut selector = AttentionSelector::new(Vec::new(), SelectorConfig::default());

        assert_eq!(selector.select(), 0);
        assert!(selector.best_mechanism().is_none());
        assert!(matches!(
            selector.forward(&single_node_dag()),
            Err(AttentionError::ConfigError(_))
        ));
    }

    #[test]
    fn test_stats() {
        let config = SelectorConfig {
            initial_value: 0.0,
            ..Default::default()
        };
        let mut selector = AttentionSelector::new(vec![mock("mech1", 0.5)], config);
        selector.update(0, 1.0);
        selector.update(0, 2.0);

        let stats = selector.stats();
        let mech1_stats = stats.get("mech1").unwrap();

        assert_eq!(mech1_stats.count, 2);
        assert_eq!(mech1_stats.total_reward, 3.0);
        assert_eq!(mech1_stats.avg_reward, 1.5);
    }

    #[test]
    fn test_update_out_of_range_and_reset() {
        let mut selector =
            AttentionSelector::new(vec![mock("a", 0.1), mock("b", 0.2)], SelectorConfig::default());
        selector.update(0, 0.5);
        selector.update(1, 0.25);
        // Out-of-range updates are ignored.
        selector.update(7, 1.0);
        assert_eq!(selector.total_count, 2);
        assert_eq!(selector.counts, vec![1, 1]);

        selector.reset();
        assert_eq!(selector.total_count, 0);
        assert_eq!(selector.counts, vec![0, 0]);
        // Reward sums return to the default `initial_value` seed.
        assert_eq!(selector.rewards, vec![1.0, 1.0]);
    }
}