swarm_engine_core/exploration/
provider.rs1use std::fmt::Debug;
32
33use super::map::{ExplorationMap, GraphMap, MapNodeState, MapState};
34use super::mutation::ActionNodeData;
35use super::node_rules::Rules;
36use super::operator::{ConfigurableOperator, Operator, RulesBasedMutation};
37use super::selection::{AnySelection, SelectionKind};
38use crate::online_stats::SwarmStats;
39
40#[derive(Debug, Clone)]
46pub struct OperatorConfig {
47 pub selection: SelectionKind,
49 pub ucb1_c: f64,
51}
52
53impl Default for OperatorConfig {
54 fn default() -> Self {
55 Self {
56 selection: SelectionKind::Fifo,
57 ucb1_c: std::f64::consts::SQRT_2,
58 }
59 }
60}
61
62impl OperatorConfig {
63 pub fn ucb1(c: f64) -> Self {
65 Self {
66 selection: SelectionKind::Ucb1,
67 ucb1_c: c,
68 }
69 }
70
71 pub fn greedy() -> Self {
73 Self {
74 selection: SelectionKind::Greedy,
75 ..Default::default()
76 }
77 }
78
79 pub fn thompson() -> Self {
81 Self {
82 selection: SelectionKind::Thompson,
83 ..Default::default()
84 }
85 }
86}
87
88pub struct ProviderContext<'a, N, E, S>
96where
97 N: Debug + Clone,
98 E: Debug + Clone,
99 S: MapState,
100{
101 pub map: &'a GraphMap<N, E, S>,
103 pub stats: &'a SwarmStats,
105}
106
107impl<'a, N, E, S> ProviderContext<'a, N, E, S>
108where
109 N: Debug + Clone,
110 E: Debug + Clone,
111 S: MapState,
112{
113 pub fn new(map: &'a GraphMap<N, E, S>, stats: &'a SwarmStats) -> Self {
114 Self { map, stats }
115 }
116
117 pub fn frontier_count(&self) -> usize {
119 self.map.frontiers().len()
120 }
121
122 pub fn total_visits(&self) -> u32 {
124 self.stats.total_visits()
125 }
126
127 pub fn is_exploration_mature(&self, threshold: u32) -> bool {
129 self.stats.total_visits() >= threshold
130 }
131}
132
133pub trait OperatorProvider<R>: Send + Sync
141where
142 R: Rules,
143{
144 fn provide(
146 &self,
147 rules: R,
148 context: Option<&ProviderContext<'_, ActionNodeData, String, MapNodeState>>,
149 ) -> ConfigurableOperator<R>;
150
151 fn reevaluate(
153 &self,
154 _operator: &mut ConfigurableOperator<R>,
155 _context: &ProviderContext<'_, ActionNodeData, String, MapNodeState>,
156 ) {
157 }
159
160 fn name(&self) -> &str;
162}
163
164#[derive(Debug, Clone)]
170pub struct ConfigBasedOperatorProvider {
171 config: OperatorConfig,
172}
173
174impl ConfigBasedOperatorProvider {
175 pub fn new(config: OperatorConfig) -> Self {
176 Self { config }
177 }
178
179 pub fn fifo() -> Self {
181 Self::new(OperatorConfig::default())
182 }
183
184 pub fn ucb1(c: f64) -> Self {
186 Self::new(OperatorConfig::ucb1(c))
187 }
188
189 pub fn config(&self) -> &OperatorConfig {
191 &self.config
192 }
193}
194
195impl<R> OperatorProvider<R> for ConfigBasedOperatorProvider
196where
197 R: Rules + 'static,
198{
199 fn provide(
200 &self,
201 rules: R,
202 _context: Option<&ProviderContext<'_, ActionNodeData, String, MapNodeState>>,
203 ) -> ConfigurableOperator<R> {
204 let selection = AnySelection::from_kind(self.config.selection, self.config.ucb1_c);
205 Operator::new(RulesBasedMutation::new(), selection, rules)
206 }
207
208 fn name(&self) -> &str {
209 "ConfigBased"
210 }
211}
212
/// Provider that picks the selection strategy from live swarm statistics.
///
/// The rule applied is:
/// * fewer than `maturity_threshold` total visits: UCB1;
/// * otherwise, failure rate above `error_rate_threshold`: Thompson;
/// * otherwise: Greedy.
#[derive(Debug, Clone)]
pub struct AdaptiveOperatorProvider {
    /// Visit count below which exploration counts as immature.
    maturity_threshold: u32,
    /// Failure rate above which Thompson selection is chosen.
    error_rate_threshold: f64,
    /// Constant forwarded to `AnySelection::from_kind` when building a
    /// selection.
    ucb1_c: f64,
}
233
234impl Default for AdaptiveOperatorProvider {
235 fn default() -> Self {
236 Self {
237 maturity_threshold: 10,
238 error_rate_threshold: 0.3,
239 ucb1_c: std::f64::consts::SQRT_2,
240 }
241 }
242}
243
244impl AdaptiveOperatorProvider {
245 pub fn new(maturity_threshold: u32, error_rate_threshold: f64, ucb1_c: f64) -> Self {
247 Self {
248 maturity_threshold,
249 error_rate_threshold,
250 ucb1_c,
251 }
252 }
253
254 pub fn with_maturity_threshold(mut self, threshold: u32) -> Self {
256 self.maturity_threshold = threshold;
257 self
258 }
259
260 pub fn with_error_rate_threshold(mut self, threshold: f64) -> Self {
262 self.error_rate_threshold = threshold.clamp(0.0, 1.0);
263 self
264 }
265
266 pub fn with_ucb1_c(mut self, c: f64) -> Self {
268 self.ucb1_c = c;
269 self
270 }
271
272 fn select_strategy(&self, stats: &SwarmStats) -> SelectionKind {
274 let visits = stats.total_visits();
275 let error_rate = stats.failure_rate();
276
277 if visits < self.maturity_threshold {
278 SelectionKind::Ucb1
279 } else if error_rate > self.error_rate_threshold {
280 SelectionKind::Thompson
281 } else {
282 SelectionKind::Greedy
283 }
284 }
285
286 pub fn current_selection(&self, stats: &SwarmStats) -> SelectionKind {
288 self.select_strategy(stats)
289 }
290}
291
292impl<R> OperatorProvider<R> for AdaptiveOperatorProvider
293where
294 R: Rules + 'static,
295{
296 fn provide(
297 &self,
298 rules: R,
299 context: Option<&ProviderContext<'_, ActionNodeData, String, MapNodeState>>,
300 ) -> ConfigurableOperator<R> {
301 let selection_kind = context
302 .map(|ctx| self.select_strategy(ctx.stats))
303 .unwrap_or(SelectionKind::Ucb1);
304
305 let selection = AnySelection::from_kind(selection_kind, self.ucb1_c);
306 Operator::new(RulesBasedMutation::new(), selection, rules)
307 }
308
309 fn reevaluate(
310 &self,
311 operator: &mut ConfigurableOperator<R>,
312 context: &ProviderContext<'_, ActionNodeData, String, MapNodeState>,
313 ) {
314 let current_kind = operator.selection.kind();
315 let new_kind = self.select_strategy(context.stats);
316
317 if current_kind != new_kind {
318 operator.selection = AnySelection::from_kind(new_kind, self.ucb1_c);
319 }
320 }
321
322 fn name(&self) -> &str {
323 "Adaptive"
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult};
    use crate::exploration::NodeRules;
    use crate::types::WorkerId;

    /// Map type shared by every context built in these tests.
    type TestMap = GraphMap<ActionNodeData, String, MapNodeState>;

    /// Records one successful event for `action`.
    fn record_success(stats: &mut SwarmStats, action: &str) {
        let builder = ActionEventBuilder::new(0, WorkerId(0), action);
        let event = builder.result(ActionEventResult::success()).build();
        stats.record(&event);
    }

    /// Records one failed event for `action`.
    fn record_failure(stats: &mut SwarmStats, action: &str) {
        let builder = ActionEventBuilder::new(0, WorkerId(0), action);
        let event = builder.result(ActionEventResult::failure("error")).build();
        stats.record(&event);
    }

    /// Fresh statistics with `successes` successful `grep` events followed by
    /// `failures` failed ones.
    fn stats_with(successes: usize, failures: usize) -> SwarmStats {
        let mut stats = SwarmStats::new();
        (0..successes).for_each(|_| record_success(&mut stats, "grep"));
        (0..failures).for_each(|_| record_failure(&mut stats, "grep"));
        stats
    }

    // ---- OperatorConfig -------------------------------------------------

    #[test]
    fn test_operator_config_default() {
        let cfg = OperatorConfig::default();
        assert_eq!(cfg.selection, SelectionKind::Fifo);
        assert!((cfg.ucb1_c - std::f64::consts::SQRT_2).abs() < 1e-10);
    }

    #[test]
    fn test_operator_config_ucb1() {
        let cfg = OperatorConfig::ucb1(2.0);
        assert_eq!(cfg.selection, SelectionKind::Ucb1);
        assert_eq!(cfg.ucb1_c, 2.0);
    }

    #[test]
    fn test_operator_config_greedy() {
        assert_eq!(OperatorConfig::greedy().selection, SelectionKind::Greedy);
    }

    #[test]
    fn test_operator_config_thompson() {
        assert_eq!(
            OperatorConfig::thompson().selection,
            SelectionKind::Thompson
        );
    }

    // ---- ConfigBasedOperatorProvider ------------------------------------

    #[test]
    fn test_config_based_provider_fifo() {
        let operator = ConfigBasedOperatorProvider::fifo().provide(NodeRules::for_testing(), None);

        assert_eq!(operator.name(), "RulesBased+FIFO");
    }

    #[test]
    fn test_config_based_provider_ucb1() {
        let operator =
            ConfigBasedOperatorProvider::ucb1(1.41).provide(NodeRules::for_testing(), None);

        assert_eq!(operator.name(), "RulesBased+UCB1");
    }

    #[test]
    fn test_config_based_provider_with_context() {
        let provider = ConfigBasedOperatorProvider::new(OperatorConfig::greedy());
        let rules = NodeRules::for_testing();

        // Without a context.
        let without_ctx = provider.provide(rules.clone(), None);
        assert_eq!(without_ctx.name(), "RulesBased+Greedy");

        // With a context: the configured strategy must not change.
        let map: TestMap = GraphMap::new();
        let stats = SwarmStats::new();
        let ctx = ProviderContext::new(&map, &stats);
        let with_ctx = provider.provide(rules, Some(&ctx));
        assert_eq!(with_ctx.name(), "RulesBased+Greedy");
    }

    // ---- ProviderContext ------------------------------------------------

    #[test]
    fn test_provider_context_queries() {
        let mut map: TestMap = GraphMap::new();
        let _root = map.create_root(ActionNodeData::new("root"), MapNodeState::Open);

        let stats = SwarmStats::new();
        let ctx = ProviderContext::new(&map, &stats);

        assert_eq!(ctx.frontier_count(), 1);
        assert_eq!(ctx.total_visits(), 0);
        assert!(!ctx.is_exploration_mature(10));
    }

    // ---- AdaptiveOperatorProvider ---------------------------------------

    #[test]
    fn test_adaptive_provider_initial_ucb1() {
        let provider = AdaptiveOperatorProvider::default();
        let stats = SwarmStats::new();

        assert_eq!(provider.current_selection(&stats), SelectionKind::Ucb1);
    }

    #[test]
    fn test_adaptive_provider_mature_low_error_greedy() {
        let provider = AdaptiveOperatorProvider::default().with_maturity_threshold(5);
        let stats = stats_with(10, 0);

        assert_eq!(stats.failure_rate(), 0.0);
        assert_eq!(provider.current_selection(&stats), SelectionKind::Greedy);
    }

    #[test]
    fn test_adaptive_provider_mature_high_error_thompson() {
        let provider = AdaptiveOperatorProvider::default()
            .with_maturity_threshold(5)
            .with_error_rate_threshold(0.3);
        let stats = stats_with(5, 5);

        assert_eq!(stats.failure_rate(), 0.5);
        assert_eq!(provider.current_selection(&stats), SelectionKind::Thompson);
    }

    #[test]
    fn test_adaptive_provider_provide() {
        let provider = AdaptiveOperatorProvider::default();
        let rules = NodeRules::for_testing();

        // No context: falls back to UCB1.
        let cold = provider.provide(rules.clone(), None);
        assert_eq!(cold.name(), "RulesBased+UCB1");

        // Mature, failure-free statistics: Greedy.
        let map: TestMap = GraphMap::new();
        let stats = stats_with(20, 0);
        let ctx = ProviderContext::new(&map, &stats);

        let warm = provider.provide(rules, Some(&ctx));
        assert_eq!(warm.name(), "RulesBased+Greedy");
    }

    #[test]
    fn test_adaptive_provider_reevaluate() {
        let provider = AdaptiveOperatorProvider::default().with_maturity_threshold(5);

        let mut operator = provider.provide(NodeRules::for_testing(), None);
        assert_eq!(operator.selection.kind(), SelectionKind::Ucb1);

        let map: TestMap = GraphMap::new();
        let stats = stats_with(10, 0);
        let ctx = ProviderContext::new(&map, &stats);

        provider.reevaluate(&mut operator, &ctx);
        assert_eq!(operator.selection.kind(), SelectionKind::Greedy);
    }

    #[test]
    fn test_adaptive_provider_reevaluate_to_thompson() {
        let provider = AdaptiveOperatorProvider::default()
            .with_maturity_threshold(5)
            .with_error_rate_threshold(0.3);

        let mut operator = provider.provide(NodeRules::for_testing(), None);
        assert_eq!(operator.selection.kind(), SelectionKind::Ucb1);

        let map: TestMap = GraphMap::new();
        let stats = stats_with(3, 7);
        let ctx = ProviderContext::new(&map, &stats);

        provider.reevaluate(&mut operator, &ctx);
        assert_eq!(operator.selection.kind(), SelectionKind::Thompson);
    }
}