1use std::sync::atomic::{AtomicU32, Ordering};
35use std::sync::RwLock;
36use std::time::Instant;
37
38use super::map::MapNodeState;
39use super::mutation::ActionNodeData;
40use super::node_rules::Rules;
41use super::operator::{ConfigurableOperator, Operator, RulesBasedMutation};
42use super::provider::{AdaptiveOperatorProvider, OperatorProvider, ProviderContext};
43use super::selection::{AnySelection, SelectionKind};
44use crate::events::{LearningEvent, LearningEventChannel};
45use crate::online_stats::SwarmStats;
46
47#[derive(Debug, Clone)]
53pub struct StrategyContext {
54 pub frontier_count: usize,
56 pub total_visits: u32,
58 pub failure_rate: f64,
60 pub success_rate: f64,
62 pub current_strategy: SelectionKind,
64 pub avg_depth: Option<f32>,
66}
67
68impl StrategyContext {
69 pub fn new(
71 frontier_count: usize,
72 total_visits: u32,
73 failure_rate: f64,
74 current_strategy: SelectionKind,
75 ) -> Self {
76 Self {
77 frontier_count,
78 total_visits,
79 failure_rate,
80 success_rate: 1.0 - failure_rate,
81 current_strategy,
82 avg_depth: None,
83 }
84 }
85
86 pub fn from_provider_context(
88 ctx: &ProviderContext<'_, ActionNodeData, String, MapNodeState>,
89 current: SelectionKind,
90 ) -> Self {
91 Self {
92 frontier_count: ctx.frontier_count(),
93 total_visits: ctx.total_visits(),
94 failure_rate: ctx.stats.failure_rate(),
95 success_rate: ctx.stats.success_rate(),
96 current_strategy: current,
97 avg_depth: None,
98 }
99 }
100
101 pub fn from_stats(stats: &SwarmStats, frontier_count: usize, current: SelectionKind) -> Self {
103 Self {
104 frontier_count,
105 total_visits: stats.total_visits(),
106 failure_rate: stats.failure_rate(),
107 success_rate: stats.success_rate(),
108 current_strategy: current,
109 avg_depth: None,
110 }
111 }
112
113 pub fn with_avg_depth(mut self, depth: f32) -> Self {
115 self.avg_depth = Some(depth);
116 self
117 }
118}
119
120#[derive(Debug, Clone)]
126pub struct StrategyAdvice {
127 pub recommended: SelectionKind,
129 pub should_change: bool,
131 pub reason: String,
133 pub confidence: f64,
135}
136
137impl StrategyAdvice {
138 pub fn no_change(current: SelectionKind, reason: impl Into<String>) -> Self {
140 Self {
141 recommended: current,
142 should_change: false,
143 reason: reason.into(),
144 confidence: 1.0,
145 }
146 }
147
148 pub fn change_to(new: SelectionKind, reason: impl Into<String>, confidence: f64) -> Self {
150 Self {
151 recommended: new,
152 should_change: true,
153 reason: reason.into(),
154 confidence,
155 }
156 }
157}
158
/// Errors a [`StrategyAdvisor`] can produce while advising.
#[derive(Debug, Clone, thiserror::Error)]
pub enum StrategyAdviceError {
    /// The underlying LLM request failed (transport or model error).
    #[error("LLM call failed: {0}")]
    LlmError(String),
    /// The LLM replied, but the response could not be parsed into advice.
    #[error("Failed to parse response: {0}")]
    ParseError(String),
    /// No advisor is configured or reachable.
    #[error("Advisor not available")]
    Unavailable,
}
176
/// Source of selection-strategy advice (e.g. an LLM-backed reviewer).
///
/// Implementations must be `Send + Sync`: providers may consult them from
/// multiple workers concurrently.
pub trait StrategyAdvisor: Send + Sync {
    /// Produces advice for the given swarm context.
    ///
    /// # Errors
    /// Returns a [`StrategyAdviceError`] when the advisor is unavailable,
    /// the backing call fails, or its response cannot be parsed.
    fn advise(&self, context: &StrategyContext) -> Result<StrategyAdvice, StrategyAdviceError>;

    /// Short identifier used in emitted events and log lines.
    fn name(&self) -> &str;
}
192
/// Controls how often the advisor is consulted for a strategy review.
#[derive(Debug, Clone)]
pub struct ReviewPolicy {
    /// Visit count between scheduled reviews.
    pub interval: u32,
    /// Hard floor on visits between two consecutive reviews.
    pub min_interval: u32,
    /// Failure-rate delta (clamped to `[0.0, 1.0]`) that triggers an early review.
    pub state_change_threshold: f64,
}

impl Default for ReviewPolicy {
    /// Moderate cadence: review every 20 visits, never closer than 5 apart,
    /// or early on a 0.15 failure-rate swing.
    fn default() -> Self {
        Self::new(20, 5, 0.15)
    }
}

impl ReviewPolicy {
    /// Builds a policy, clamping `state_change_threshold` into `[0.0, 1.0]`.
    pub fn new(interval: u32, min_interval: u32, state_change_threshold: f64) -> Self {
        Self {
            interval,
            min_interval,
            state_change_threshold: state_change_threshold.clamp(0.0, 1.0),
        }
    }

    /// Aggressive cadence for fast-moving runs.
    pub fn frequent() -> Self {
        Self::new(10, 3, 0.1)
    }

    /// Relaxed cadence that conserves advisor calls.
    pub fn conservative() -> Self {
        Self::new(50, 20, 0.25)
    }
}
246
/// Operator provider that layers LLM strategy reviews on top of the
/// heuristic [`AdaptiveOperatorProvider`].
pub struct AdaptiveLlmOperatorProvider {
    /// Heuristic fallback used until (and unless) the advisor pins an override.
    adaptive: AdaptiveOperatorProvider,
    /// Source of strategy advice.
    advisor: Box<dyn StrategyAdvisor>,
    /// When/how often to consult the advisor.
    policy: ReviewPolicy,
    /// Exploration constant forwarded to `AnySelection::from_kind`.
    ucb1_c: f64,
    /// Total-visit count recorded at the last successful review.
    last_review_visits: AtomicU32,
    /// Failure rate at the last review, fixed-point encoded as rate * 1000 in a u32.
    last_failure_rate: AtomicU32,
    /// Strategy pinned by the advisor, if it ever recommended a change.
    llm_override: RwLock<Option<SelectionKind>>,
}
277
impl AdaptiveLlmOperatorProvider {
    /// Creates a provider with the given advisor, default heuristics,
    /// default review policy, and `sqrt(2)` as the UCB1 exploration constant.
    pub fn new(advisor: Box<dyn StrategyAdvisor>) -> Self {
        Self {
            adaptive: AdaptiveOperatorProvider::default(),
            advisor,
            policy: ReviewPolicy::default(),
            ucb1_c: std::f64::consts::SQRT_2,
            last_review_visits: AtomicU32::new(0),
            last_failure_rate: AtomicU32::new(0),
            llm_override: RwLock::new(None),
        }
    }

    /// Builder: replaces the review policy.
    pub fn with_policy(mut self, policy: ReviewPolicy) -> Self {
        self.policy = policy;
        self
    }

    /// Builder: replaces the heuristic fallback provider.
    pub fn with_adaptive(mut self, adaptive: AdaptiveOperatorProvider) -> Self {
        self.adaptive = adaptive;
        self
    }

    /// Builder: sets the UCB1 exploration constant.
    pub fn with_ucb1_c(mut self, c: f64) -> Self {
        self.ucb1_c = c;
        self
    }

    /// Returns the strategy currently pinned by the advisor, if any.
    pub fn llm_override(&self) -> Option<SelectionKind> {
        *self.llm_override.read().unwrap()
    }

    /// Decides whether a review is due: never before `min_interval` visits
    /// have elapsed since the last review; always once `interval` visits
    /// have elapsed; otherwise only if the failure rate moved by at least
    /// `state_change_threshold`.
    fn should_review(&self, stats: &SwarmStats) -> bool {
        let current_visits = stats.total_visits();
        let last_visits = self.last_review_visits.load(Ordering::Relaxed);

        // Hard floor: too soon after the previous review.
        if current_visits < last_visits + self.policy.min_interval {
            return false;
        }

        // Scheduled review: enough visits have accumulated.
        if current_visits >= last_visits + self.policy.interval {
            return true;
        }

        // Early trigger: compare fixed-point (x1000) failure rates; the
        // signed subtraction avoids u32 underflow when the rate dropped.
        let current_rate = (stats.failure_rate() * 1000.0) as u32;
        let last_rate = self.last_failure_rate.load(Ordering::Relaxed);
        let rate_diff = (current_rate as i32 - last_rate as i32).unsigned_abs() as f64 / 1000.0;

        rate_diff >= self.policy.state_change_threshold
    }

    /// Runs one advisor review, emits a learning event for both the success
    /// and failure paths, and returns `Some(new_kind)` only when the advisor
    /// recommends an actual change.
    ///
    /// On advisor failure the review bookkeeping is NOT updated, so the next
    /// eligible tick retries (subject to `min_interval`) and the caller falls
    /// back to the current strategy.
    fn do_review(
        &self,
        stats: &SwarmStats,
        frontier_count: usize,
        current: SelectionKind,
    ) -> Option<SelectionKind> {
        let context = StrategyContext::from_stats(stats, frontier_count, current);

        // Time the advisor call so latency is reported in events/logs.
        let start_time = Instant::now();

        let result = self.advisor.advise(&context);

        let elapsed = start_time.elapsed();

        match result {
            Ok(advice) => {
                let latency_ms = elapsed.as_millis() as u64;

                // Emit a success event describing the advice received.
                let tick = LearningEventChannel::global().current_tick();
                LearningEventChannel::global().emit(
                    LearningEvent::strategy_advice(tick, self.advisor.name())
                        .current_strategy(current.to_string())
                        .recommended(advice.recommended.to_string())
                        .should_change(advice.should_change)
                        .confidence(advice.confidence)
                        .reason(&advice.reason)
                        .frontier_count(frontier_count)
                        .total_visits(stats.total_visits())
                        .failure_rate(stats.failure_rate())
                        .latency_ms(latency_ms)
                        .success()
                        .build(),
                );

                tracing::debug!(
                    target: "swarm_engine::learning",
                    advisor = %self.advisor.name(),
                    current_strategy = %current,
                    recommended = %advice.recommended,
                    should_change = advice.should_change,
                    confidence = advice.confidence,
                    reason = %advice.reason,
                    latency_ms = latency_ms,
                    "Strategy advice completed"
                );

                // Record review checkpoint (visits + failure rate x1000)
                // so should_review measures deltas from this point.
                self.last_review_visits
                    .store(stats.total_visits(), Ordering::Relaxed);
                self.last_failure_rate
                    .store((stats.failure_rate() * 1000.0) as u32, Ordering::Relaxed);

                if advice.should_change {
                    Some(advice.recommended)
                } else {
                    None
                }
            }
            Err(e) => {
                let latency_ms = elapsed.as_millis() as u64;

                // Emit a failure event; "recommended" mirrors the current
                // strategy since no advice was produced.
                let tick = LearningEventChannel::global().current_tick();
                LearningEventChannel::global().emit(
                    LearningEvent::strategy_advice(tick, self.advisor.name())
                        .current_strategy(current.to_string())
                        .recommended(current.to_string())
                        .frontier_count(frontier_count)
                        .total_visits(stats.total_visits())
                        .failure_rate(stats.failure_rate())
                        .latency_ms(latency_ms)
                        .failure(e.to_string())
                        .build(),
                );

                tracing::warn!(
                    advisor = %self.advisor.name(),
                    error = %e,
                    latency_ms = latency_ms,
                    "Strategy advisor failed, falling back to Adaptive"
                );
                None
            }
        }
    }

    /// Resolves the effective strategy: an advisor override wins; otherwise
    /// defer to the heuristic provider's choice for these stats.
    fn effective_selection(&self, stats: &SwarmStats) -> SelectionKind {
        if let Some(kind) = *self.llm_override.read().unwrap() {
            return kind;
        }
        self.adaptive.current_selection(stats)
    }
}
441
impl<R> OperatorProvider<R> for AdaptiveLlmOperatorProvider
where
    R: Rules + 'static,
{
    /// Builds an operator using the effective strategy, possibly running an
    /// advisor review first. Without context, defaults to UCB1.
    fn provide(
        &self,
        rules: R,
        context: Option<&ProviderContext<'_, ActionNodeData, String, MapNodeState>>,
    ) -> ConfigurableOperator<R> {
        let selection_kind = match context {
            Some(ctx) => {
                let current = self.effective_selection(ctx.stats);

                if self.should_review(ctx.stats) {
                    if let Some(new_kind) = self.do_review(ctx.stats, ctx.frontier_count(), current)
                    {
                        // Pin the advisor's choice so later calls keep it.
                        *self.llm_override.write().unwrap() = Some(new_kind);
                        new_kind
                    } else {
                        current
                    }
                } else {
                    current
                }
            }
            // No stats available — fall back to the UCB1 default.
            None => SelectionKind::Ucb1,
        };

        let selection = AnySelection::from_kind(selection_kind, self.ucb1_c);
        Operator::new(RulesBasedMutation::new(), selection, rules)
    }

    /// Re-runs the review flow against an existing operator. If the advisor
    /// recommends a change, the operator's selection is swapped in place and
    /// the override pinned; otherwise the heuristic provider only gets to
    /// reevaluate while no advisor override is active.
    fn reevaluate(
        &self,
        operator: &mut ConfigurableOperator<R>,
        context: &ProviderContext<'_, ActionNodeData, String, MapNodeState>,
    ) {
        let current = operator.selection().kind();

        if self.should_review(context.stats) {
            if let Some(new_kind) = self.do_review(context.stats, context.frontier_count(), current)
            {
                if new_kind != current {
                    tracing::info!(
                        from = %current,
                        to = %new_kind,
                        "Strategy changed by LLM advisor"
                    );
                    operator.set_selection(AnySelection::from_kind(new_kind, self.ucb1_c));
                    *self.llm_override.write().unwrap() = Some(new_kind);
                }
                // Advisor spoke this tick — skip the heuristic fallback.
                return;
            }
        }

        // Only let the heuristic provider adjust while no advisor override
        // is pinned; an override sticks until the advisor changes it.
        if self.llm_override.read().unwrap().is_none() {
            self.adaptive.reevaluate(operator, context);
        }
    }

    fn name(&self) -> &str {
        "HybridLlm"
    }
}
510
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult};
    use crate::exploration::{GraphMap, NodeRules};
    use crate::types::WorkerId;
    use std::sync::atomic::AtomicUsize;
    use std::sync::Arc;

    /// Records one successful action event into `stats`.
    fn record_success(stats: &mut SwarmStats, action: &str) {
        let event = ActionEventBuilder::new(0, WorkerId(0), action)
            .result(ActionEventResult::success())
            .build();
        stats.record(&event);
    }

    /// Records one failed action event into `stats`.
    fn record_failure(stats: &mut SwarmStats, action: &str) {
        let event = ActionEventBuilder::new(0, WorkerId(0), action)
            .result(ActionEventResult::failure("error"))
            .build();
        stats.record(&event);
    }

    /// Test advisor that always returns a fixed advice and counts calls.
    ///
    /// The call counter is shared through an `Arc` so a test can keep a
    /// handle to it after the advisor is boxed and moved into the provider.
    /// This replaces a previous `unsafe` pointer cast from
    /// `&dyn StrategyAdvisor` back to the concrete type, which had no
    /// safety justification and would be UB if another advisor type were
    /// ever boxed in its place.
    struct MockAdvisor {
        advice: StrategyAdvice,
        call_count: Arc<AtomicUsize>,
    }

    impl MockAdvisor {
        fn new(advice: StrategyAdvice) -> Self {
            Self {
                advice,
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Shared handle to the call counter; clone it before the advisor
        /// is moved into a provider.
        fn call_counter(&self) -> Arc<AtomicUsize> {
            Arc::clone(&self.call_count)
        }
    }

    impl StrategyAdvisor for MockAdvisor {
        fn advise(
            &self,
            _context: &StrategyContext,
        ) -> Result<StrategyAdvice, StrategyAdviceError> {
            self.call_count.fetch_add(1, Ordering::Relaxed);
            Ok(self.advice.clone())
        }

        fn name(&self) -> &str {
            "MockAdvisor"
        }
    }

    /// Advisor that always fails, for exercising the fallback path.
    struct FailingAdvisor;

    impl StrategyAdvisor for FailingAdvisor {
        fn advise(
            &self,
            _context: &StrategyContext,
        ) -> Result<StrategyAdvice, StrategyAdviceError> {
            Err(StrategyAdviceError::LlmError("Mock error".into()))
        }

        fn name(&self) -> &str {
            "FailingAdvisor"
        }
    }

    #[test]
    fn test_strategy_context_new() {
        let ctx = StrategyContext::new(15, 47, 0.23, SelectionKind::Ucb1);
        assert_eq!(ctx.frontier_count, 15);
        assert_eq!(ctx.total_visits, 47);
        assert!((ctx.failure_rate - 0.23).abs() < 0.001);
        // success_rate must be derived as the complement of failure_rate.
        assert!((ctx.success_rate - 0.77).abs() < 0.001);
        assert_eq!(ctx.current_strategy, SelectionKind::Ucb1);
    }

    #[test]
    fn test_strategy_context_from_stats() {
        let mut stats = SwarmStats::new();
        for _ in 0..7 {
            record_success(&mut stats, "action");
        }
        for _ in 0..3 {
            record_failure(&mut stats, "action");
        }

        let ctx = StrategyContext::from_stats(&stats, 10, SelectionKind::Greedy);
        assert_eq!(ctx.frontier_count, 10);
        assert_eq!(ctx.total_visits, 10);
        // 3 failures out of 10 recorded events.
        assert!((ctx.failure_rate - 0.3).abs() < 0.01);
    }

    #[test]
    fn test_review_policy_default() {
        let policy = ReviewPolicy::default();
        assert_eq!(policy.interval, 20);
        assert_eq!(policy.min_interval, 5);
        assert!((policy.state_change_threshold - 0.15).abs() < 0.001);
    }

    #[test]
    fn test_review_policy_frequent() {
        let policy = ReviewPolicy::frequent();
        assert_eq!(policy.interval, 10);
        assert_eq!(policy.min_interval, 3);
    }

    #[test]
    fn test_review_policy_conservative() {
        let policy = ReviewPolicy::conservative();
        assert_eq!(policy.interval, 50);
        assert_eq!(policy.min_interval, 20);
    }

    #[test]
    fn test_hybrid_provider_initial_ucb1() {
        let advice = StrategyAdvice::no_change(SelectionKind::Ucb1, "test");
        let advisor = MockAdvisor::new(advice);
        let provider = AdaptiveLlmOperatorProvider::new(Box::new(advisor));
        let rules = NodeRules::for_testing();

        // Without context the provider must default to UCB1.
        let operator = provider.provide(rules, None);
        assert_eq!(operator.name(), "RulesBased+UCB1");
    }

    #[test]
    fn test_hybrid_provider_review_at_interval() {
        let advice = StrategyAdvice::change_to(SelectionKind::Greedy, "test", 0.9);
        let advisor = MockAdvisor::new(advice);
        // Keep a shared handle to the call counter before the advisor is
        // moved into the provider (no unsafe downcast needed).
        let calls = advisor.call_counter();
        let provider =
            AdaptiveLlmOperatorProvider::new(Box::new(advisor)).with_policy(ReviewPolicy {
                interval: 10,
                min_interval: 5,
                state_change_threshold: 0.5,
            });
        let rules = NodeRules::for_testing();

        // 20 visits exceeds the review interval, so a review must fire.
        let mut stats = SwarmStats::new();
        for _ in 0..20 {
            record_success(&mut stats, "action");
        }

        let map: GraphMap<ActionNodeData, String, MapNodeState> = GraphMap::new();
        let ctx = ProviderContext::new(&map, &stats);

        // The advisor recommends Greedy, so the operator should use it.
        let operator = provider.provide(rules, Some(&ctx));
        assert_eq!(operator.name(), "RulesBased+Greedy");

        // Exactly one advisor call should have happened.
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_hybrid_provider_no_review_before_min_interval() {
        let advice = StrategyAdvice::change_to(SelectionKind::Greedy, "test", 0.9);
        let advisor = MockAdvisor::new(advice);
        let provider =
            AdaptiveLlmOperatorProvider::new(Box::new(advisor)).with_policy(ReviewPolicy {
                interval: 10,
                min_interval: 5,
                state_change_threshold: 0.5,
            });
        let rules = NodeRules::for_testing();

        // Only 3 visits — below min_interval, so no review may run.
        let mut stats = SwarmStats::new();
        for _ in 0..3 {
            record_success(&mut stats, "action");
        }

        let map: GraphMap<ActionNodeData, String, MapNodeState> = GraphMap::new();
        let ctx = ProviderContext::new(&map, &stats);

        // Strategy stays at the UCB1 default.
        let operator = provider.provide(rules, Some(&ctx));
        assert_eq!(operator.name(), "RulesBased+UCB1");
    }

    #[test]
    fn test_hybrid_provider_fallback_on_error() {
        let provider =
            AdaptiveLlmOperatorProvider::new(Box::new(FailingAdvisor)).with_policy(ReviewPolicy {
                interval: 1,
                min_interval: 1,
                state_change_threshold: 0.0,
            });
        let rules = NodeRules::for_testing();

        let mut stats = SwarmStats::new();
        for _ in 0..10 {
            record_success(&mut stats, "action");
        }

        let map: GraphMap<ActionNodeData, String, MapNodeState> = GraphMap::new();
        let ctx = ProviderContext::new(&map, &stats);

        // Advisor failure must not panic; the provider still yields a
        // rules-based operator via the adaptive fallback.
        let operator = provider.provide(rules, Some(&ctx));
        assert!(operator.name().contains("RulesBased"));
    }

    #[test]
    fn test_hybrid_provider_reevaluate() {
        let advice = StrategyAdvice::change_to(SelectionKind::Thompson, "high variance", 0.85);
        let advisor = MockAdvisor::new(advice);
        let provider =
            AdaptiveLlmOperatorProvider::new(Box::new(advisor)).with_policy(ReviewPolicy {
                interval: 5,
                min_interval: 1,
                state_change_threshold: 0.5,
            });
        let rules = NodeRules::for_testing();

        // Start from the no-context UCB1 default.
        let mut operator = provider.provide(rules, None);
        assert_eq!(operator.selection().kind(), SelectionKind::Ucb1);

        // Accumulate enough visits to cross the review interval.
        let mut stats = SwarmStats::new();
        for _ in 0..10 {
            record_success(&mut stats, "action");
        }
        let map: GraphMap<ActionNodeData, String, MapNodeState> = GraphMap::new();
        let ctx = ProviderContext::new(&map, &stats);

        // Reevaluate must swap the selection in place per the advice.
        provider.reevaluate(&mut operator, &ctx);
        assert_eq!(operator.selection().kind(), SelectionKind::Thompson);
    }

    #[test]
    fn test_hybrid_provider_state_change_trigger() {
        let advice = StrategyAdvice::change_to(SelectionKind::Thompson, "high variance", 0.8);
        let advisor = MockAdvisor::new(advice);
        let provider =
            AdaptiveLlmOperatorProvider::new(Box::new(advisor)).with_policy(ReviewPolicy {
                // Interval too large to trigger a scheduled review; only a
                // failure-rate swing can trigger one.
                interval: 100,
                min_interval: 1,
                state_change_threshold: 0.1,
            });
        let rules = NodeRules::for_testing();

        // First call establishes the baseline failure rate (0.0).
        let mut stats = SwarmStats::new();
        for _ in 0..5 {
            record_success(&mut stats, "action");
        }
        let map: GraphMap<ActionNodeData, String, MapNodeState> = GraphMap::new();
        let ctx = ProviderContext::new(&map, &stats);
        let _ = provider.provide(rules.clone(), Some(&ctx));

        // Failure rate jumps from 0.0 to 0.5 — well past the threshold.
        for _ in 0..5 {
            record_failure(&mut stats, "action");
        }
        let ctx2 = ProviderContext::new(&map, &stats);

        // The state-change trigger should run a review and apply the advice.
        let operator = provider.provide(rules, Some(&ctx2));
        assert_eq!(operator.selection().kind(), SelectionKind::Thompson);
    }
}