swarm_engine_core/exploration/selection/
ucb1.rs1use std::fmt::Debug;
7
8use super::SelectionLogic;
9use crate::exploration::map::{ExplorationMap, GraphMap, MapNodeId, MapState};
10use crate::exploration::mutation::ActionExtractor;
11use crate::learn::{LearnedProvider, LearningQuery};
12use crate::online_stats::SwarmStats;
13
/// UCB1 (Upper Confidence Bound) frontier-selection strategy.
///
/// An action is scored as its observed success rate, plus an exploration
/// term `c * sqrt(ln(N) / n)`, plus a learned-confidence bonus multiplied by
/// `learning_weight` (see [`Ucb1::compute_score_with_context`]).
#[derive(Debug, Clone)]
pub struct Ucb1 {
    /// Exploration constant scaling the `sqrt(ln(N) / n)` term.
    c: f64,

    /// Multiplier applied to the learned provider's confidence score.
    learning_weight: f64,
}
28
29impl Default for Ucb1 {
30 fn default() -> Self {
31 Self {
32 c: std::f64::consts::SQRT_2,
33 learning_weight: 0.3,
34 }
35 }
36}
37
38impl Ucb1 {
39 pub fn new(c: f64) -> Self {
41 Self {
42 c,
43 learning_weight: 0.3,
44 }
45 }
46
47 pub fn with_weights(c: f64, learning_weight: f64) -> Self {
49 Self { c, learning_weight }
50 }
51
52 pub fn with_default_c() -> Self {
54 Self::default()
55 }
56
57 pub fn c(&self) -> f64 {
59 self.c
60 }
61
62 pub fn learning_weight(&self) -> f64 {
64 self.learning_weight
65 }
66
67 pub fn compute_score(
69 &self,
70 stats: &SwarmStats,
71 action: &str,
72 target: Option<&str>,
73 provider: &dyn LearnedProvider,
74 ) -> f64 {
75 self.compute_score_with_context(stats, action, target, provider, None, None)
76 }
77
78 pub fn compute_score_with_context(
82 &self,
83 stats: &SwarmStats,
84 action: &str,
85 target: Option<&str>,
86 provider: &dyn LearnedProvider,
87 prev_action: Option<&str>,
88 prev_prev_action: Option<&str>,
89 ) -> f64 {
90 let action_stats = match target {
91 Some(t) => stats.get_action_target_stats(action, t),
92 None => stats.get_action_stats(action),
93 };
94 let total = stats.total_visits().max(1);
95
96 if action_stats.visits == 0 {
97 return f64::INFINITY; }
99
100 let success_rate = action_stats.success_rate();
101 let exploration = self.c * ((total as f64).ln() / action_stats.visits as f64).sqrt();
102
103 let learning_bonus = provider
105 .query(LearningQuery::confidence_with_context(
106 action,
107 target,
108 prev_action,
109 prev_prev_action,
110 ))
111 .score();
112
113 success_rate + exploration + learning_bonus * self.learning_weight
114 }
115}
116
117impl<N, E, S> SelectionLogic<N, E, S> for Ucb1
118where
119 N: Debug + Clone + ActionExtractor,
120 E: Debug + Clone,
121 S: MapState,
122{
123 fn next(
124 &self,
125 map: &GraphMap<N, E, S>,
126 stats: &SwarmStats,
127 provider: &dyn LearnedProvider,
128 ) -> Option<MapNodeId> {
129 self.select(map, 1, stats, provider).into_iter().next()
130 }
131
132 fn select(
133 &self,
134 map: &GraphMap<N, E, S>,
135 count: usize,
136 stats: &SwarmStats,
137 provider: &dyn LearnedProvider,
138 ) -> Vec<MapNodeId> {
139 let frontiers = map.frontiers();
140 if frontiers.is_empty() || count == 0 {
141 return vec![];
142 }
143
144 let mut scored: Vec<_> = frontiers
146 .iter()
147 .filter_map(|&id| {
148 map.get(id).map(|node| {
149 let (action, target) = node.data.extract();
150
151 let parent_id = map.parent(id);
153 let prev_action = parent_id.and_then(|pid| {
154 map.get(pid)
155 .map(|parent_node| parent_node.data.action_name().to_string())
156 });
157
158 let prev_prev_action =
160 parent_id.and_then(|pid| map.parent(pid)).and_then(|ppid| {
161 map.get(ppid).map(|grandparent_node| {
162 grandparent_node.data.action_name().to_string()
163 })
164 });
165
166 let score = self.compute_score_with_context(
167 stats,
168 action,
169 target,
170 provider,
171 prev_action.as_deref(),
172 prev_prev_action.as_deref(),
173 );
174 (id, score)
175 })
176 })
177 .collect();
178
179 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
181
182 scored.into_iter().take(count).map(|(id, _)| id).collect()
183 }
184
185 fn score(
186 &self,
187 action: &str,
188 target: Option<&str>,
189 stats: &SwarmStats,
190 provider: &dyn LearnedProvider,
191 ) -> f64 {
192 self.compute_score(stats, action, target, provider)
193 }
194
195 fn name(&self) -> &str {
196 "UCB1"
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::ActionEventBuilder;
    use crate::learn::NullProvider;

    /// Records one successful event for `action` (and `target`, if given).
    fn record_success(stats: &mut SwarmStats, action: &str, target: Option<&str>) {
        use crate::events::ActionEventResult;
        use crate::types::WorkerId;

        let mut builder = ActionEventBuilder::new(0, WorkerId(0), action);
        if let Some(t) = target {
            builder = builder.target(t);
        }
        let event = builder.result(ActionEventResult::success()).build();
        stats.record(&event);
    }

    #[test]
    fn test_ucb1_unvisited_is_infinity() {
        let ucb1 = Ucb1::new(1.41);
        let stats = SwarmStats::new();
        let provider = NullProvider;

        assert!(ucb1
            .compute_score(&stats, "grep", None, &provider)
            .is_infinite());
    }

    #[test]
    fn test_ucb1_score_changes_with_visits() {
        let ucb1 = Ucb1::new(1.41);
        let mut stats = SwarmStats::new();
        let provider = NullProvider;

        let score1 = ucb1.compute_score(&stats, "grep", None, &provider);
        assert!(score1.is_infinite());

        // One visit overall: ln(1) = 0, so only the success rate remains.
        record_success(&mut stats, "grep", None);
        let score2 = ucb1.compute_score(&stats, "grep", None, &provider);
        assert!(score2.is_finite());
        assert!((score2 - 1.0).abs() < 0.01);

        // Visiting another action raises N, so grep's exploration term grows.
        record_success(&mut stats, "glob", None);
        let score3 = ucb1.compute_score(&stats, "grep", None, &provider);
        assert!(score3 > score2);
    }

    #[test]
    fn test_ucb1_default_c() {
        let ucb1 = Ucb1::default();
        assert!((ucb1.c() - std::f64::consts::SQRT_2).abs() < 1e-10);
        assert!((ucb1.learning_weight() - 0.3).abs() < 1e-10);
    }

    #[test]
    fn test_ucb1_with_weights() {
        let ucb1 = Ucb1::with_weights(2.0, 0.5);
        assert!((ucb1.c() - 2.0).abs() < 1e-10);
        assert!((ucb1.learning_weight() - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_ucb1_new_keeps_default_learning_weight() {
        // `new` only overrides `c`; the learning weight must match `Default`.
        let ucb1 = Ucb1::new(2.0);
        assert!((ucb1.c() - 2.0).abs() < 1e-10);
        assert!((ucb1.learning_weight() - Ucb1::default().learning_weight()).abs() < 1e-10);

        let defaulted = Ucb1::with_default_c();
        assert!((defaulted.c() - std::f64::consts::SQRT_2).abs() < 1e-10);
        assert!((defaulted.learning_weight() - 0.3).abs() < 1e-10);
    }
}