1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
49use std::sync::{Mutex, OnceLock};
50
51use serde::{Deserialize, Serialize};
52use tokio::sync::broadcast;
53
54use crate::util::epoch_millis;
55
/// A structured record of one learning-related occurrence, published through
/// [`LearningEventChannel`].
///
/// Serialized with serde's internally-tagged representation: the variant's
/// `rename` value (e.g. `"llm_strategy_advice"`) is written to an `event_type`
/// field alongside the variant's own fields. Field order here therefore also
/// determines the serialized field order.
///
/// Instances are normally built via [`LearningEvent::strategy_advice`],
/// [`LearningEvent::dependency_graph_inference`] and
/// [`LearningEvent::learn_stats_snapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event_type")]
pub enum LearningEvent {
    /// Advice returned by a strategy advisor for a given tick
    /// (the serde tag suggests an LLM-backed advisor — not enforced here).
    #[serde(rename = "llm_strategy_advice")]
    StrategyAdvice {
        /// Creation time, captured by the builder via `epoch_millis()`.
        timestamp_ms: u64,
        /// Tick the advice refers to, as passed to the builder.
        tick: u64,
        /// Name of the advisor that produced the advice.
        advisor: String,
        /// Strategy in use when advice was requested (empty unless set).
        current_strategy: String,
        /// Strategy the advisor recommends (empty unless set).
        recommended: String,
        /// Whether the advisor recommends switching strategies.
        should_change: bool,
        /// Advisor-reported confidence. NOTE(review): scale/range is not validated here.
        confidence: f64,
        /// Free-form justification text.
        reason: String,
        /// Frontier size reported by the caller.
        frontier_count: usize,
        /// Visit count reported by the caller.
        total_visits: u32,
        /// Failure rate reported by the caller. NOTE(review): presumably a 0..=1 fraction — confirm.
        failure_rate: f64,
        /// Advisor call latency in milliseconds, as reported by the caller.
        latency_ms: u64,
        /// `true` unless the builder's `failure(..)` was the last outcome setter called.
        success: bool,
        /// Error message recorded by `failure(..)`; `None` on success.
        error: Option<String>,
    },

    /// A dependency-graph inference request/response pair.
    #[serde(rename = "dependency_graph_inference")]
    DependencyGraphInference {
        /// Creation time, captured by the builder via `epoch_millis()`.
        timestamp_ms: u64,
        /// Prompt text supplied by the caller.
        prompt: String,
        /// Raw response text supplied by the caller.
        response: String,
        /// Actions that were available, as supplied by the caller.
        available_actions: Vec<String>,
        /// Ordering reported for the "discover" case (semantics defined by the caller).
        discover_order: Vec<String>,
        /// Ordering reported for the "not discover" case (semantics defined by the caller).
        not_discover_order: Vec<String>,
        /// Model identifier passed to the builder.
        model: String,
        /// Endpoint string (empty unless set).
        endpoint: String,
        /// LoRA adapter identifier; `None` unless the builder's `lora(..)` was called.
        lora: Option<String>,
        /// Inference latency in milliseconds, as reported by the caller.
        latency_ms: u64,
        /// `true` unless the builder's `failure(..)` was the last outcome setter called.
        success: bool,
        /// Error message recorded by `failure(..)`; `None` on success.
        error: Option<String>,
    },

    /// A snapshot of aggregate learning statistics for one scenario/session.
    #[serde(rename = "learn_stats_snapshot")]
    LearnStatsSnapshot {
        /// Creation time, captured by the builder via `epoch_millis()`.
        timestamp_ms: u64,
        /// Scenario name passed to the builder.
        scenario: String,
        /// Session identifier (empty unless set).
        session_id: String,
        /// Stats payload as a pre-serialized string; by name a JSON document, but it is
        /// not parsed or validated here.
        stats_json: String,
        /// Outcome of the run; the builder defaults to `Success { score: 0.0 }`.
        outcome: LearnStatsOutcome,
        /// Total tick count reported by the caller.
        total_ticks: u64,
        /// Total action count reported by the caller.
        total_actions: u64,
    },
}
145
/// Outcome recorded in a [`LearningEvent::LearnStatsSnapshot`].
///
/// No `#[serde(tag)]` attribute is applied, so serde's default (externally
/// tagged) enum representation is used when this is serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LearnStatsOutcome {
    /// Successful outcome carrying a caller-supplied score.
    Success { score: f64 },
    /// Failed outcome with caller-supplied reason text.
    Failure { reason: String },
    /// Timed-out outcome, optionally carrying a partial score from the caller.
    Timeout { partial_score: Option<f64> },
}
156
157impl LearningEvent {
158 pub fn strategy_advice(tick: u64, advisor: impl Into<String>) -> StrategyAdviceBuilder {
160 StrategyAdviceBuilder {
161 timestamp_ms: epoch_millis(),
162 tick,
163 advisor: advisor.into(),
164 current_strategy: String::new(),
165 recommended: String::new(),
166 should_change: false,
167 confidence: 0.0,
168 reason: String::new(),
169 frontier_count: 0,
170 total_visits: 0,
171 failure_rate: 0.0,
172 latency_ms: 0,
173 success: true,
174 error: None,
175 }
176 }
177
178 pub fn dependency_graph_inference(model: impl Into<String>) -> DependencyGraphInferenceBuilder {
180 DependencyGraphInferenceBuilder {
181 timestamp_ms: epoch_millis(),
182 prompt: String::new(),
183 response: String::new(),
184 available_actions: Vec::new(),
185 discover_order: Vec::new(),
186 not_discover_order: Vec::new(),
187 model: model.into(),
188 endpoint: String::new(),
189 lora: None,
190 latency_ms: 0,
191 success: true,
192 error: None,
193 }
194 }
195
196 pub fn learn_stats_snapshot(scenario: impl Into<String>) -> LearnStatsSnapshotBuilder {
198 LearnStatsSnapshotBuilder {
199 timestamp_ms: epoch_millis(),
200 scenario: scenario.into(),
201 session_id: String::new(),
202 stats_json: String::new(),
203 outcome: LearnStatsOutcome::Success { score: 0.0 },
204 total_ticks: 0,
205 total_actions: 0,
206 }
207 }
208}
209
/// Builder for [`LearningEvent::StrategyAdvice`], obtained from
/// [`LearningEvent::strategy_advice`].
///
/// Marked `#[must_use]` because a builder that is never `.build()`-ed silently
/// discards the event it was describing.
#[must_use = "a builder does nothing until `.build()` is called"]
#[derive(Debug, Clone)]
pub struct StrategyAdviceBuilder {
    // Captured when the builder is created, not when `build()` runs.
    timestamp_ms: u64,
    tick: u64,
    advisor: String,
    current_strategy: String,
    recommended: String,
    should_change: bool,
    confidence: f64,
    reason: String,
    frontier_count: usize,
    total_visits: u32,
    failure_rate: f64,
    latency_ms: u64,
    // `success` and `error` are always updated together by `success()` / `failure(..)`.
    success: bool,
    error: Option<String>,
}
227
228impl StrategyAdviceBuilder {
229 pub fn current_strategy(mut self, strategy: impl Into<String>) -> Self {
230 self.current_strategy = strategy.into();
231 self
232 }
233
234 pub fn recommended(mut self, strategy: impl Into<String>) -> Self {
235 self.recommended = strategy.into();
236 self
237 }
238
239 pub fn should_change(mut self, should: bool) -> Self {
240 self.should_change = should;
241 self
242 }
243
244 pub fn confidence(mut self, conf: f64) -> Self {
245 self.confidence = conf;
246 self
247 }
248
249 pub fn reason(mut self, reason: impl Into<String>) -> Self {
250 self.reason = reason.into();
251 self
252 }
253
254 pub fn frontier_count(mut self, count: usize) -> Self {
255 self.frontier_count = count;
256 self
257 }
258
259 pub fn total_visits(mut self, visits: u32) -> Self {
260 self.total_visits = visits;
261 self
262 }
263
264 pub fn failure_rate(mut self, rate: f64) -> Self {
265 self.failure_rate = rate;
266 self
267 }
268
269 pub fn latency_ms(mut self, ms: u64) -> Self {
270 self.latency_ms = ms;
271 self
272 }
273
274 pub fn success(mut self) -> Self {
275 self.success = true;
276 self.error = None;
277 self
278 }
279
280 pub fn failure(mut self, error: impl Into<String>) -> Self {
281 self.success = false;
282 self.error = Some(error.into());
283 self
284 }
285
286 pub fn build(self) -> LearningEvent {
287 LearningEvent::StrategyAdvice {
288 timestamp_ms: self.timestamp_ms,
289 tick: self.tick,
290 advisor: self.advisor,
291 current_strategy: self.current_strategy,
292 recommended: self.recommended,
293 should_change: self.should_change,
294 confidence: self.confidence,
295 reason: self.reason,
296 frontier_count: self.frontier_count,
297 total_visits: self.total_visits,
298 failure_rate: self.failure_rate,
299 latency_ms: self.latency_ms,
300 success: self.success,
301 error: self.error,
302 }
303 }
304}
305
/// Builder for [`LearningEvent::DependencyGraphInference`], obtained from
/// [`LearningEvent::dependency_graph_inference`].
///
/// Marked `#[must_use]` because a builder that is never `.build()`-ed silently
/// discards the event it was describing.
#[must_use = "a builder does nothing until `.build()` is called"]
#[derive(Debug, Clone)]
pub struct DependencyGraphInferenceBuilder {
    // Captured when the builder is created, not when `build()` runs.
    timestamp_ms: u64,
    prompt: String,
    response: String,
    available_actions: Vec<String>,
    discover_order: Vec<String>,
    not_discover_order: Vec<String>,
    model: String,
    endpoint: String,
    // `None` unless `lora(..)` is called.
    lora: Option<String>,
    latency_ms: u64,
    // `success` and `error` are always updated together by `success()` / `failure(..)`.
    success: bool,
    error: Option<String>,
}
321
322impl DependencyGraphInferenceBuilder {
323 pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
324 self.prompt = prompt.into();
325 self
326 }
327
328 pub fn response(mut self, response: impl Into<String>) -> Self {
329 self.response = response.into();
330 self
331 }
332
333 pub fn available_actions(mut self, actions: Vec<String>) -> Self {
334 self.available_actions = actions;
335 self
336 }
337
338 pub fn discover_order(mut self, order: Vec<String>) -> Self {
339 self.discover_order = order;
340 self
341 }
342
343 pub fn not_discover_order(mut self, order: Vec<String>) -> Self {
344 self.not_discover_order = order;
345 self
346 }
347
348 pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
349 self.endpoint = endpoint.into();
350 self
351 }
352
353 pub fn lora(mut self, lora: impl Into<String>) -> Self {
354 self.lora = Some(lora.into());
355 self
356 }
357
358 pub fn latency_ms(mut self, ms: u64) -> Self {
359 self.latency_ms = ms;
360 self
361 }
362
363 pub fn success(mut self) -> Self {
364 self.success = true;
365 self.error = None;
366 self
367 }
368
369 pub fn failure(mut self, error: impl Into<String>) -> Self {
370 self.success = false;
371 self.error = Some(error.into());
372 self
373 }
374
375 pub fn build(self) -> LearningEvent {
376 LearningEvent::DependencyGraphInference {
377 timestamp_ms: self.timestamp_ms,
378 prompt: self.prompt,
379 response: self.response,
380 available_actions: self.available_actions,
381 discover_order: self.discover_order,
382 not_discover_order: self.not_discover_order,
383 model: self.model,
384 endpoint: self.endpoint,
385 lora: self.lora,
386 latency_ms: self.latency_ms,
387 success: self.success,
388 error: self.error,
389 }
390 }
391}
392
393pub struct LearnStatsSnapshotBuilder {
395 timestamp_ms: u64,
396 scenario: String,
397 session_id: String,
398 stats_json: String,
399 outcome: LearnStatsOutcome,
400 total_ticks: u64,
401 total_actions: u64,
402}
403
404impl LearnStatsSnapshotBuilder {
405 pub fn session_id(mut self, id: impl Into<String>) -> Self {
406 self.session_id = id.into();
407 self
408 }
409
410 pub fn stats_json(mut self, json: impl Into<String>) -> Self {
411 self.stats_json = json.into();
412 self
413 }
414
415 pub fn success(mut self, score: f64) -> Self {
416 self.outcome = LearnStatsOutcome::Success { score };
417 self
418 }
419
420 pub fn failure(mut self, reason: impl Into<String>) -> Self {
421 self.outcome = LearnStatsOutcome::Failure {
422 reason: reason.into(),
423 };
424 self
425 }
426
427 pub fn timeout(mut self, partial_score: Option<f64>) -> Self {
428 self.outcome = LearnStatsOutcome::Timeout { partial_score };
429 self
430 }
431
432 pub fn total_ticks(mut self, ticks: u64) -> Self {
433 self.total_ticks = ticks;
434 self
435 }
436
437 pub fn total_actions(mut self, actions: u64) -> Self {
438 self.total_actions = actions;
439 self
440 }
441
442 pub fn build(self) -> LearningEvent {
443 LearningEvent::LearnStatsSnapshot {
444 timestamp_ms: self.timestamp_ms,
445 scenario: self.scenario,
446 session_id: self.session_id,
447 stats_json: self.stats_json,
448 outcome: self.outcome,
449 total_ticks: self.total_ticks,
450 total_actions: self.total_actions,
451 }
452 }
453}
454
455pub struct LearningEventChannel {
464 tx: broadcast::Sender<LearningEvent>,
466 enabled: AtomicBool,
468 current_tick: AtomicU64,
470 sync_buffer: Mutex<Vec<LearningEvent>>,
472}
473
474impl LearningEventChannel {
475 pub fn new(capacity: usize) -> Self {
477 let (tx, _) = broadcast::channel(capacity);
478 Self {
479 tx,
480 enabled: AtomicBool::new(false),
481 current_tick: AtomicU64::new(0),
482 sync_buffer: Mutex::new(Vec::new()),
483 }
484 }
485
486 pub fn global() -> &'static Self {
488 static INSTANCE: OnceLock<LearningEventChannel> = OnceLock::new();
489 INSTANCE.get_or_init(|| Self::new(256))
490 }
491
492 pub fn enable(&self) {
494 self.enabled.store(true, Ordering::Relaxed);
495 }
496
497 pub fn disable(&self) {
499 self.enabled.store(false, Ordering::Relaxed);
500 }
501
502 pub fn is_enabled(&self) -> bool {
504 self.enabled.load(Ordering::Relaxed)
505 }
506
507 pub fn set_tick(&self, tick: u64) {
509 self.current_tick.store(tick, Ordering::Relaxed);
510 }
511
512 pub fn current_tick(&self) -> u64 {
514 self.current_tick.load(Ordering::Relaxed)
515 }
516
517 pub fn emit(&self, event: LearningEvent) {
522 if self.enabled.load(Ordering::Relaxed) {
523 if let Ok(mut buffer) = self.sync_buffer.lock() {
525 buffer.push(event.clone());
526 }
527 let _ = self.tx.send(event);
529 }
530 }
531
532 pub fn drain_sync(&self) -> Vec<LearningEvent> {
537 if let Ok(mut buffer) = self.sync_buffer.lock() {
538 std::mem::take(&mut *buffer)
539 } else {
540 Vec::new()
541 }
542 }
543
544 pub fn subscribe(&self) -> broadcast::Receiver<LearningEvent> {
546 self.tx.subscribe()
547 }
548
549 pub fn receiver_count(&self) -> usize {
551 self.tx.receiver_count()
552 }
553}
554
555impl Default for LearningEventChannel {
556 fn default() -> Self {
557 Self::new(256)
558 }
559}
560
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_channel_disabled_by_default() {
        let channel = LearningEventChannel::new(16);
        assert!(!channel.is_enabled());
    }

    #[test]
    fn test_channel_enable_disable() {
        let channel = LearningEventChannel::new(16);
        channel.enable();
        assert!(channel.is_enabled());
        channel.disable();
        assert!(!channel.is_enabled());
    }

    #[tokio::test]
    async fn test_channel_emit_when_enabled() {
        let channel = LearningEventChannel::new(16);
        channel.enable();

        let mut rx = channel.subscribe();

        let event = LearningEvent::strategy_advice(42, "TestAdvisor")
            .current_strategy("ucb1")
            .recommended("greedy")
            .should_change(true)
            .confidence(0.9)
            .reason("test reason")
            .frontier_count(10)
            .total_visits(100)
            .failure_rate(0.1)
            .latency_ms(50)
            .success()
            .build();

        channel.emit(event);

        let received = rx.recv().await.unwrap();
        match received {
            LearningEvent::StrategyAdvice {
                tick,
                advisor,
                should_change,
                ..
            } => {
                assert_eq!(tick, 42);
                assert_eq!(advisor, "TestAdvisor");
                assert!(should_change);
            }
            _ => panic!("Expected StrategyAdvice"),
        }
    }

    #[test]
    fn test_channel_no_emit_when_disabled() {
        // The channel starts disabled; subscribe anyway so a leaked event would be visible.
        let channel = LearningEventChannel::new(16);
        let mut rx = channel.subscribe();

        let event = LearningEvent::strategy_advice(0, "Test")
            .current_strategy("ucb1")
            .recommended("ucb1")
            .build();

        channel.emit(event);

        // `try_recv` inspects the queue immediately, so this check is deterministic
        // rather than depending on a wall-clock timeout.
        assert!(matches!(
            rx.try_recv(),
            Err(tokio::sync::broadcast::error::TryRecvError::Empty)
        ));
    }

    #[test]
    fn test_tick_management() {
        let channel = LearningEventChannel::new(16);
        assert_eq!(channel.current_tick(), 0);

        channel.set_tick(42);
        assert_eq!(channel.current_tick(), 42);

        channel.set_tick(100);
        assert_eq!(channel.current_tick(), 100);
    }

    #[test]
    fn test_drain_sync() {
        let channel = LearningEventChannel::new(16);
        channel.enable();

        // No subscribers here: events must still reach the sync buffer.
        channel.emit(
            LearningEvent::strategy_advice(1, "Advisor1")
                .current_strategy("ucb1")
                .recommended("greedy")
                .build(),
        );
        channel.emit(
            LearningEvent::strategy_advice(2, "Advisor2")
                .current_strategy("greedy")
                .recommended("thompson")
                .build(),
        );

        let events = channel.drain_sync();
        assert_eq!(events.len(), 2);

        let t1 = match &events[0] {
            LearningEvent::StrategyAdvice { tick, .. } => *tick,
            _ => panic!("Expected StrategyAdvice"),
        };
        let t2 = match &events[1] {
            LearningEvent::StrategyAdvice { tick, .. } => *tick,
            _ => panic!("Expected StrategyAdvice"),
        };
        assert_eq!(t1, 1);
        assert_eq!(t2, 2);

        // Draining empties the buffer.
        let events2 = channel.drain_sync();
        assert!(events2.is_empty());
    }

    #[test]
    fn test_drain_sync_disabled() {
        // Disabled channels must not buffer events either.
        let channel = LearningEventChannel::new(16);
        channel.emit(
            LearningEvent::strategy_advice(1, "Advisor")
                .current_strategy("ucb1")
                .recommended("ucb1")
                .build(),
        );

        let events = channel.drain_sync();
        assert!(events.is_empty());
    }

    #[test]
    fn test_strategy_advice_failure_then_success_clears_error() {
        let event = LearningEvent::strategy_advice(3, "Advisor")
            .failure("timeout")
            .success()
            .build();
        match event {
            LearningEvent::StrategyAdvice { success, error, .. } => {
                assert!(success);
                assert!(error.is_none());
            }
            other => panic!("Expected StrategyAdvice, got {other:?}"),
        }
    }

    #[test]
    fn test_dependency_graph_inference_builder() {
        let event = LearningEvent::dependency_graph_inference("model-x")
            .prompt("p")
            .response("r")
            .available_actions(vec!["a".to_string(), "b".to_string()])
            .discover_order(vec!["a".to_string()])
            .not_discover_order(vec!["b".to_string()])
            .endpoint("http://localhost")
            .lora("adapter")
            .latency_ms(7)
            .failure("boom")
            .build();
        match event {
            LearningEvent::DependencyGraphInference {
                prompt,
                response,
                available_actions,
                discover_order,
                not_discover_order,
                model,
                endpoint,
                lora,
                latency_ms,
                success,
                error,
                ..
            } => {
                assert_eq!(prompt, "p");
                assert_eq!(response, "r");
                assert_eq!(available_actions, vec!["a", "b"]);
                assert_eq!(discover_order, vec!["a"]);
                assert_eq!(not_discover_order, vec!["b"]);
                assert_eq!(model, "model-x");
                assert_eq!(endpoint, "http://localhost");
                assert_eq!(lora.as_deref(), Some("adapter"));
                assert_eq!(latency_ms, 7);
                assert!(!success);
                assert_eq!(error.as_deref(), Some("boom"));
            }
            other => panic!("Expected DependencyGraphInference, got {other:?}"),
        }
    }

    #[test]
    fn test_learn_stats_snapshot_builder() {
        let event = LearningEvent::learn_stats_snapshot("scenario-a")
            .session_id("s1")
            .stats_json("{}")
            .timeout(Some(0.5))
            .total_ticks(10)
            .total_actions(3)
            .build();
        match event {
            LearningEvent::LearnStatsSnapshot {
                scenario,
                session_id,
                stats_json,
                outcome,
                total_ticks,
                total_actions,
                ..
            } => {
                assert_eq!(scenario, "scenario-a");
                assert_eq!(session_id, "s1");
                assert_eq!(stats_json, "{}");
                assert!(matches!(
                    outcome,
                    LearnStatsOutcome::Timeout { partial_score: Some(s) }
                        if (s - 0.5).abs() < f64::EPSILON
                ));
                assert_eq!(total_ticks, 10);
                assert_eq!(total_actions, 3);
            }
            other => panic!("Expected LearnStatsSnapshot, got {other:?}"),
        }
    }
}