swarm_engine_core/exploration/selection/
thompson.rs1use std::fmt::Debug;
7
8use super::SelectionLogic;
9use crate::exploration::map::{ExplorationMap, GraphMap, MapNodeId, MapState};
10use crate::exploration::mutation::ActionExtractor;
11use crate::learn::{LearnedProvider, LearningQuery};
12use crate::online_stats::SwarmStats;
13
/// Selection strategy that scores candidate actions from observed
/// success/failure counts plus a weighted bonus from a learned provider.
#[derive(Clone, Debug)]
pub struct Thompson {
    /// Multiplier applied to the learned-provider bonus before it is added
    /// to the statistics-based score.
    learning_weight: f64,
}
25
26impl Default for Thompson {
27 fn default() -> Self {
28 Self {
29 learning_weight: 0.3,
30 }
31 }
32}
33
34impl Thompson {
35 pub fn new() -> Self {
36 Self::default()
37 }
38
39 pub fn with_learning_weight(learning_weight: f64) -> Self {
41 Self { learning_weight }
42 }
43
44 pub fn learning_weight(&self) -> f64 {
46 self.learning_weight
47 }
48
49 pub fn compute_score(
51 &self,
52 stats: &SwarmStats,
53 action: &str,
54 target: Option<&str>,
55 provider: &dyn LearnedProvider,
56 ) -> f64 {
57 self.compute_score_with_context(stats, action, target, provider, None, None)
58 }
59
60 pub fn compute_score_with_context(
64 &self,
65 stats: &SwarmStats,
66 action: &str,
67 target: Option<&str>,
68 provider: &dyn LearnedProvider,
69 prev_action: Option<&str>,
70 prev_prev_action: Option<&str>,
71 ) -> f64 {
72 let action_stats = match target {
73 Some(t) => stats.get_action_target_stats(action, t),
74 None => stats.get_action_stats(action),
75 };
76
77 let alpha = action_stats.successes as f64 + 1.0;
79 let beta = action_stats.failures as f64 + 1.0;
80 let base_score = alpha / (alpha + beta);
81
82 let learning_bonus = provider
84 .query(LearningQuery::confidence_with_context(
85 action,
86 target,
87 prev_action,
88 prev_prev_action,
89 ))
90 .score();
91
92 base_score + learning_bonus * self.learning_weight
93 }
94}
95
96impl<N, E, S> SelectionLogic<N, E, S> for Thompson
97where
98 N: Debug + Clone + ActionExtractor,
99 E: Debug + Clone,
100 S: MapState,
101{
102 fn next(
103 &self,
104 map: &GraphMap<N, E, S>,
105 stats: &SwarmStats,
106 provider: &dyn LearnedProvider,
107 ) -> Option<MapNodeId> {
108 self.select(map, 1, stats, provider).into_iter().next()
109 }
110
111 fn select(
112 &self,
113 map: &GraphMap<N, E, S>,
114 count: usize,
115 stats: &SwarmStats,
116 provider: &dyn LearnedProvider,
117 ) -> Vec<MapNodeId> {
118 let frontiers = map.frontiers();
119 if frontiers.is_empty() || count == 0 {
120 return vec![];
121 }
122
123 let mut scored: Vec<_> = frontiers
125 .iter()
126 .filter_map(|&id| {
127 map.get(id).map(|node| {
128 let (action, target) = node.data.extract();
129
130 let parent_id = map.parent(id);
132 let prev_action = parent_id.and_then(|pid| {
133 map.get(pid)
134 .map(|parent_node| parent_node.data.action_name().to_string())
135 });
136
137 let prev_prev_action =
139 parent_id.and_then(|pid| map.parent(pid)).and_then(|ppid| {
140 map.get(ppid).map(|grandparent_node| {
141 grandparent_node.data.action_name().to_string()
142 })
143 });
144
145 let score = self.compute_score_with_context(
146 stats,
147 action,
148 target,
149 provider,
150 prev_action.as_deref(),
151 prev_prev_action.as_deref(),
152 );
153 (id, score)
154 })
155 })
156 .collect();
157
158 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
160
161 scored.into_iter().take(count).map(|(id, _)| id).collect()
162 }
163
164 fn score(
165 &self,
166 action: &str,
167 target: Option<&str>,
168 stats: &SwarmStats,
169 provider: &dyn LearnedProvider,
170 ) -> f64 {
171 self.compute_score(stats, action, target, provider)
172 }
173
174 fn name(&self) -> &str {
175 "Thompson"
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult};
    use crate::learn::NullProvider;
    use crate::types::WorkerId;

    /// Records one successful execution of `action` into `stats`.
    fn log_success(stats: &mut SwarmStats, action: &str) {
        let builder =
            ActionEventBuilder::new(0, WorkerId(0), action).result(ActionEventResult::success());
        stats.record(&builder.build());
    }

    /// Records one failed execution of `action` into `stats`.
    fn log_failure(stats: &mut SwarmStats, action: &str) {
        let builder = ActionEventBuilder::new(0, WorkerId(0), action)
            .result(ActionEventResult::failure("error"));
        stats.record(&builder.build());
    }

    /// With no recorded events and a null provider the score is exactly 0.5.
    #[test]
    fn test_thompson_initial_score() {
        let selector = Thompson::new();
        let empty_stats = SwarmStats::new();

        let score = selector.compute_score(&empty_stats, "grep", None, &NullProvider);

        assert_eq!(score, 0.5);
    }

    /// A recorded success raises the score above the initial value.
    #[test]
    fn test_thompson_score_increases_with_success() {
        let selector = Thompson::new();
        let mut stats = SwarmStats::new();
        let provider = NullProvider;

        let before = selector.compute_score(&stats, "grep", None, &provider);
        log_success(&mut stats, "grep");
        let after = selector.compute_score(&stats, "grep", None, &provider);

        assert!(after > before);
    }

    /// A recorded failure lowers the score below the initial value.
    #[test]
    fn test_thompson_score_decreases_with_failure() {
        let selector = Thompson::new();
        let mut stats = SwarmStats::new();
        let provider = NullProvider;

        let before = selector.compute_score(&stats, "grep", None, &provider);
        log_failure(&mut stats, "grep");
        let after = selector.compute_score(&stats, "grep", None, &provider);

        assert!(after < before);
    }

    /// The custom-weight constructor stores the weight it was given.
    #[test]
    fn test_thompson_with_learning_weight() {
        let selector = Thompson::with_learning_weight(0.5);
        let deviation = (selector.learning_weight() - 0.5).abs();
        assert!(deviation < 1e-10);
    }
}