swarm_engine_core/learn/daemon/
subscriber.rs1use tokio::sync::{broadcast, mpsc};
78
79use crate::events::{ActionEvent, LearningEvent};
80use crate::learn::record::Record;
81
/// Batching settings shared by the event subscribers in this module.
#[derive(Clone, Debug)]
pub struct EventSubscriberConfig {
    /// Number of buffered records at which a subscriber flushes its batch.
    pub batch_size: usize,
    /// Period, in milliseconds, of the time-based flush.
    ///
    /// `None` turns time-based flushing off, so a batch is only sent when it
    /// reaches `batch_size` or when the subscriber shuts down.
    pub flush_interval_ms: Option<u64>,
}
97
98impl Default for EventSubscriberConfig {
99 fn default() -> Self {
100 Self {
101 batch_size: 100,
102 flush_interval_ms: Some(1000), }
104 }
105}
106
107impl EventSubscriberConfig {
108 pub fn new() -> Self {
110 Self::default()
111 }
112
113 pub fn batch_size(mut self, size: usize) -> Self {
115 self.batch_size = size;
116 self
117 }
118
119 pub fn flush_interval_ms(mut self, ms: u64) -> Self {
121 self.flush_interval_ms = Some(ms);
122 self
123 }
124
125 pub fn no_flush_interval(mut self) -> Self {
127 self.flush_interval_ms = None;
128 self
129 }
130}
131
132pub struct ActionEventSubscriber {
145 rx: broadcast::Receiver<ActionEvent>,
147 record_tx: mpsc::Sender<Vec<Record>>,
149 config: EventSubscriberConfig,
151 buffer: Vec<Record>,
153}
154
155impl ActionEventSubscriber {
156 pub fn new(rx: broadcast::Receiver<ActionEvent>, record_tx: mpsc::Sender<Vec<Record>>) -> Self {
158 Self::with_config(rx, record_tx, EventSubscriberConfig::default())
159 }
160
161 pub fn with_config(
163 rx: broadcast::Receiver<ActionEvent>,
164 record_tx: mpsc::Sender<Vec<Record>>,
165 config: EventSubscriberConfig,
166 ) -> Self {
167 let batch_size = config.batch_size;
168 Self {
169 rx,
170 record_tx,
171 config,
172 buffer: Vec::with_capacity(batch_size),
173 }
174 }
175
176 pub async fn run(mut self) {
181 tracing::info!(
182 batch_size = self.config.batch_size,
183 flush_interval_ms = ?self.config.flush_interval_ms,
184 "ActionEventSubscriber started"
185 );
186
187 if let Some(interval_ms) = self.config.flush_interval_ms {
188 self.run_with_flush_interval(interval_ms).await;
189 } else {
190 self.run_batch_only().await;
191 }
192
193 self.flush().await;
195
196 tracing::info!("ActionEventSubscriber stopped");
197 }
198
199 async fn run_with_flush_interval(&mut self, interval_ms: u64) {
201 use std::time::Duration;
202 use tokio::time::{interval, Instant};
203
204 let mut flush_interval = interval(Duration::from_millis(interval_ms));
205 let mut last_flush = Instant::now();
206
207 loop {
208 tokio::select! {
209 result = self.rx.recv() => {
211 match result {
212 Ok(event) => {
213 self.buffer.push(Record::from(&event));
214
215 if self.buffer.len() >= self.config.batch_size {
216 if !self.flush().await {
217 return;
218 }
219 last_flush = Instant::now();
220 }
221 }
222 Err(broadcast::error::RecvError::Closed) => {
223 tracing::debug!("ActionEvent channel closed");
224 return;
225 }
226 Err(broadcast::error::RecvError::Lagged(n)) => {
227 tracing::warn!(lagged = n, "ActionEventSubscriber lagged behind");
228 }
229 }
230 }
231
232 _ = flush_interval.tick() => {
234 if !self.buffer.is_empty() && last_flush.elapsed().as_millis() as u64 >= interval_ms {
235 if !self.flush().await {
236 return;
237 }
238 last_flush = Instant::now();
239 }
240 }
241 }
242 }
243 }
244
245 async fn run_batch_only(&mut self) {
247 loop {
248 match self.rx.recv().await {
249 Ok(event) => {
250 self.buffer.push(Record::from(&event));
251
252 if self.buffer.len() >= self.config.batch_size && !self.flush().await {
253 return;
254 }
255 }
256 Err(broadcast::error::RecvError::Closed) => {
257 tracing::debug!("ActionEvent channel closed");
258 return;
259 }
260 Err(broadcast::error::RecvError::Lagged(n)) => {
261 tracing::warn!(lagged = n, "ActionEventSubscriber lagged behind");
262 }
263 }
264 }
265 }
266
267 async fn flush(&mut self) -> bool {
271 if self.buffer.is_empty() {
272 return true;
273 }
274
275 let records = std::mem::take(&mut self.buffer);
276 let count = records.len();
277
278 match self.record_tx.send(records).await {
279 Ok(()) => {
280 tracing::debug!(count, "Flushed ActionEvent records to LearningDaemon");
281 true
282 }
283 Err(_) => {
284 tracing::warn!("LearningDaemon channel closed");
285 false
286 }
287 }
288 }
289}
290
291pub struct LearningEventSubscriber {
319 rx: broadcast::Receiver<LearningEvent>,
321 record_tx: mpsc::Sender<Vec<Record>>,
323 config: EventSubscriberConfig,
325 buffer: Vec<Record>,
327}
328
329impl LearningEventSubscriber {
330 pub fn new(
332 rx: broadcast::Receiver<LearningEvent>,
333 record_tx: mpsc::Sender<Vec<Record>>,
334 ) -> Self {
335 Self::with_config(rx, record_tx, EventSubscriberConfig::default())
336 }
337
338 pub fn with_config(
340 rx: broadcast::Receiver<LearningEvent>,
341 record_tx: mpsc::Sender<Vec<Record>>,
342 config: EventSubscriberConfig,
343 ) -> Self {
344 let batch_size = config.batch_size;
345 Self {
346 rx,
347 record_tx,
348 config,
349 buffer: Vec::with_capacity(batch_size),
350 }
351 }
352
353 pub async fn run(mut self) {
358 tracing::info!(
359 batch_size = self.config.batch_size,
360 flush_interval_ms = ?self.config.flush_interval_ms,
361 "LearningEventSubscriber started"
362 );
363
364 if let Some(interval_ms) = self.config.flush_interval_ms {
365 self.run_with_flush_interval(interval_ms).await;
366 } else {
367 self.run_batch_only().await;
368 }
369
370 self.flush().await;
372
373 tracing::info!("LearningEventSubscriber stopped");
374 }
375
376 async fn run_with_flush_interval(&mut self, interval_ms: u64) {
378 use std::time::Duration;
379 use tokio::time::{interval, Instant};
380
381 let mut flush_interval = interval(Duration::from_millis(interval_ms));
382 let mut last_flush = Instant::now();
383
384 loop {
385 tokio::select! {
386 result = self.rx.recv() => {
388 match result {
389 Ok(event) => {
390 self.buffer.push(Record::from(&event));
391
392 if self.buffer.len() >= self.config.batch_size {
393 if !self.flush().await {
394 return;
395 }
396 last_flush = Instant::now();
397 }
398 }
399 Err(broadcast::error::RecvError::Closed) => {
400 tracing::debug!("LearningEvent channel closed");
401 return;
402 }
403 Err(broadcast::error::RecvError::Lagged(n)) => {
404 tracing::warn!(lagged = n, "LearningEventSubscriber lagged behind");
405 }
406 }
407 }
408
409 _ = flush_interval.tick() => {
411 if !self.buffer.is_empty() && last_flush.elapsed().as_millis() as u64 >= interval_ms {
412 if !self.flush().await {
413 return;
414 }
415 last_flush = Instant::now();
416 }
417 }
418 }
419 }
420 }
421
422 async fn run_batch_only(&mut self) {
424 loop {
425 match self.rx.recv().await {
426 Ok(event) => {
427 self.buffer.push(Record::from(&event));
428
429 if self.buffer.len() >= self.config.batch_size && !self.flush().await {
430 return;
431 }
432 }
433 Err(broadcast::error::RecvError::Closed) => {
434 tracing::debug!("LearningEvent channel closed");
435 return;
436 }
437 Err(broadcast::error::RecvError::Lagged(n)) => {
438 tracing::warn!(lagged = n, "LearningEventSubscriber lagged behind");
439 }
440 }
441 }
442 }
443
444 async fn flush(&mut self) -> bool {
448 if self.buffer.is_empty() {
449 return true;
450 }
451
452 let records = std::mem::take(&mut self.buffer);
453 let count = records.len();
454
455 match self.record_tx.send(records).await {
456 Ok(()) => {
457 tracing::debug!(count, "Flushed LearningEvent records to LearningDaemon");
458 true
459 }
460 Err(_) => {
461 tracing::warn!("LearningDaemon channel closed");
462 false
463 }
464 }
465 }
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult, LearningEvent};
    use crate::types::WorkerId;
    use std::time::Duration;

    /// Upper bound on any single wait in these tests. It only matters when a
    /// test is about to fail; passing tests return as soon as data arrives.
    const WAIT_LIMIT: Duration = Duration::from_secs(2);

    /// Builds a successful action event for worker 0 with a 10 ms duration.
    fn make_action_event(tick: u64, action: &str) -> ActionEvent {
        ActionEventBuilder::new(tick, WorkerId(0), action)
            .result(ActionEventResult::success())
            .duration(Duration::from_millis(10))
            .build()
    }

    /// Builds a successful dependency-graph inference event for `model`.
    fn make_learning_event(model: &str) -> LearningEvent {
        LearningEvent::dependency_graph_inference(model)
            .prompt("test prompt")
            .response("test response")
            .discover_order(vec!["A".into(), "B".into()])
            .success()
            .build()
    }

    /// Waits for the next batch, failing the test instead of hanging or
    /// depending on a fixed sleep.
    async fn next_batch(record_rx: &mut mpsc::Receiver<Vec<Record>>) -> Vec<Record> {
        tokio::time::timeout(WAIT_LIMIT, record_rx.recv())
            .await
            .expect("timed out waiting for a record batch")
            .expect("record channel closed before a batch arrived")
    }

    /// Waits for the subscriber task to finish and surfaces a panic inside it.
    async fn join_subscriber(handle: tokio::task::JoinHandle<()>) {
        tokio::time::timeout(WAIT_LIMIT, handle)
            .await
            .expect("subscriber did not stop after its event channel closed")
            .expect("subscriber task panicked");
    }

    #[tokio::test]
    async fn test_action_subscriber_batch() {
        let (tx, rx) = broadcast::channel::<ActionEvent>(16);
        let (record_tx, mut record_rx) = mpsc::channel::<Vec<Record>>(16);

        let config = EventSubscriberConfig::new()
            .batch_size(3)
            .no_flush_interval();

        let subscriber = ActionEventSubscriber::with_config(rx, record_tx, config);

        let handle = tokio::spawn(async move {
            subscriber.run().await;
        });

        for i in 0..5 {
            tx.send(make_action_event(i, &format!("Action{}", i)))
                .unwrap();
        }

        // The first three events fill one batch.
        let batch = next_batch(&mut record_rx).await;
        assert_eq!(batch.len(), 3);

        // Closing the event channel stops the subscriber, which flushes the
        // remaining two records on its way out.
        drop(tx);
        join_subscriber(handle).await;

        let batch = next_batch(&mut record_rx).await;
        assert_eq!(batch.len(), 2);
    }

    #[tokio::test]
    async fn test_action_subscriber_flush_interval() {
        let (tx, rx) = broadcast::channel::<ActionEvent>(16);
        let (record_tx, mut record_rx) = mpsc::channel::<Vec<Record>>(16);

        // The batch size is never reached, so only the timer can flush.
        let config = EventSubscriberConfig::new()
            .batch_size(100)
            .flush_interval_ms(50);

        let subscriber = ActionEventSubscriber::with_config(rx, record_tx, config);

        let handle = tokio::spawn(async move {
            subscriber.run().await;
        });

        tx.send(make_action_event(0, "Action0")).unwrap();
        tx.send(make_action_event(1, "Action1")).unwrap();

        let batch = next_batch(&mut record_rx).await;
        assert_eq!(batch.len(), 2);

        drop(tx);
        join_subscriber(handle).await;
    }

    #[tokio::test]
    async fn test_action_subscriber_channel_closed() {
        let (tx, rx) = broadcast::channel::<ActionEvent>(16);
        let (record_tx, record_rx) = mpsc::channel::<Vec<Record>>(16);

        let config = EventSubscriberConfig::new()
            .batch_size(100)
            .no_flush_interval();

        let subscriber = ActionEventSubscriber::with_config(rx, record_tx, config);

        let handle = tokio::spawn(async move {
            subscriber.run().await;
        });

        tx.send(make_action_event(0, "Action0")).unwrap();

        // Close the consumer side while the subscriber still has work to do.
        drop(record_rx);

        tx.send(make_action_event(1, "Action1")).unwrap();

        // The final flush hits the closed record channel; the subscriber must
        // still shut down cleanly rather than panic or hang.
        drop(tx);
        join_subscriber(handle).await;
    }

    #[tokio::test]
    async fn test_learning_subscriber_batch() {
        let (tx, rx) = broadcast::channel::<LearningEvent>(16);
        let (record_tx, mut record_rx) = mpsc::channel::<Vec<Record>>(16);

        let config = EventSubscriberConfig::new()
            .batch_size(2)
            .no_flush_interval();

        let subscriber = LearningEventSubscriber::with_config(rx, record_tx, config);

        let handle = tokio::spawn(async move {
            subscriber.run().await;
        });

        for i in 0..3 {
            tx.send(make_learning_event(&format!("model{}", i)))
                .unwrap();
        }

        let batch = next_batch(&mut record_rx).await;
        assert_eq!(batch.len(), 2);

        for record in &batch {
            assert!(record.is_dependency_graph());
        }

        // The third record is delivered by the shutdown flush.
        drop(tx);
        join_subscriber(handle).await;

        let batch = next_batch(&mut record_rx).await;
        assert_eq!(batch.len(), 1);
    }

    #[tokio::test]
    async fn test_learning_subscriber_flush_interval() {
        let (tx, rx) = broadcast::channel::<LearningEvent>(16);
        let (record_tx, mut record_rx) = mpsc::channel::<Vec<Record>>(16);

        // The batch size is never reached, so only the timer can flush.
        let config = EventSubscriberConfig::new()
            .batch_size(100)
            .flush_interval_ms(50);

        let subscriber = LearningEventSubscriber::with_config(rx, record_tx, config);

        let handle = tokio::spawn(async move {
            subscriber.run().await;
        });

        tx.send(make_learning_event("model")).unwrap();

        let batch = next_batch(&mut record_rx).await;
        assert_eq!(batch.len(), 1);
        assert!(batch[0].is_dependency_graph());

        drop(tx);
        join_subscriber(handle).await;
    }

    #[tokio::test]
    async fn test_learning_subscriber_converts_to_dependency_graph_record() {
        let (tx, rx) = broadcast::channel::<LearningEvent>(16);
        let (record_tx, mut record_rx) = mpsc::channel::<Vec<Record>>(16);

        let config = EventSubscriberConfig::new()
            .batch_size(1)
            .no_flush_interval();

        let subscriber = LearningEventSubscriber::with_config(rx, record_tx, config);

        let handle = tokio::spawn(async move {
            subscriber.run().await;
        });

        tx.send(make_learning_event("test-model")).unwrap();

        let batch = next_batch(&mut record_rx).await;
        assert_eq!(batch.len(), 1);

        let record = batch[0].as_dependency_graph().unwrap();
        assert_eq!(record.model, "test-model");
        assert_eq!(record.prompt, "test prompt");
        assert_eq!(record.discover_order, vec!["A", "B"]);

        drop(tx);
        join_subscriber(handle).await;
    }

    // Pure builder check: no async runtime is needed.
    #[test]
    fn test_subscriber_config() {
        let config = EventSubscriberConfig::new()
            .batch_size(50)
            .flush_interval_ms(500);

        assert_eq!(config.batch_size, 50);
        assert_eq!(config.flush_interval_ms, Some(500));

        let config2 = EventSubscriberConfig::new().no_flush_interval();
        assert_eq!(config2.flush_interval_ms, None);
    }
}