//! Background learning daemon: ingests action records, evaluates train
//! triggers, runs the learning processor, and optionally applies trained
//! models.

mod applier;
mod processor;
mod sink;
mod subscriber;

pub use applier::{Applier, ApplierConfig, ApplierError, ApplyMode, ApplyResult};
pub use processor::{ProcessResult, Processor, ProcessorConfig, ProcessorError, ProcessorMode};
pub use sink::{DataSink, DataSinkError, DataSinkStats};
pub use subscriber::{ActionEventSubscriber, EventSubscriberConfig, LearningEventSubscriber};

use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use tokio::sync::mpsc;
use tokio::time::interval;

use crate::learn::learn_model::{LearnModel, WorkerDecisionSequenceLearn};
use crate::learn::lora::{
    LoraTrainer, LoraTrainerConfig, ModelApplicator, NoOpApplicator, TrainedModel,
};
use crate::learn::record::{DependencyGraphRecord, LearnStatsRecord, Record};
use crate::learn::snapshot::LearningStore;
use crate::learn::store::{
    EpisodeStore, FileEpisodeStore, FileRecordStore, InMemoryEpisodeStore, InMemoryRecordStore,
    RecordStore, RecordStoreError, StoreError,
};
use crate::learn::trigger::{TrainTrigger, TriggerBuilder, TriggerContext};
use crate::learn::LearnStats;
use crate::util::epoch_millis;

#[derive(Debug)]
/// Errors produced by the learning daemon.
pub enum DaemonError {
    /// The data sink failed while ingesting records.
    Sink(DataSinkError),
    /// The processor failed during a learning run.
    Processor(ProcessorError),
    /// Applying a trained model failed.
    Applier(ApplierError),
    /// Underlying I/O failure.
    Io(std::io::Error),
    /// Invalid or missing configuration.
    Config(String),
    /// The daemon was shut down.
    Shutdown,
}

impl std::fmt::Display for DaemonError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Sink(e) => write!(f, "Sink error: {}", e),
            Self::Processor(e) => write!(f, "Processor error: {}", e),
            Self::Applier(e) => write!(f, "Applier error: {}", e),
            Self::Io(e) => write!(f, "IO error: {}", e),
            Self::Config(msg) => write!(f, "Config error: {}", msg),
            Self::Shutdown => write!(f, "Daemon shutdown"),
        }
    }
}

impl std::error::Error for DaemonError {}

impl From<DataSinkError> for DaemonError {
    fn from(e: DataSinkError) -> Self {
        Self::Sink(e)
    }
}

impl From<ProcessorError> for DaemonError {
    fn from(e: ProcessorError) -> Self {
        Self::Processor(e)
    }
}

impl From<ApplierError> for DaemonError {
    fn from(e: ApplierError) -> Self {
        Self::Applier(e)
    }
}

impl From<std::io::Error> for DaemonError {
    fn from(e: std::io::Error) -> Self {
        Self::Io(e)
    }
}

impl From<RecordStoreError> for DaemonError {
    fn from(e: RecordStoreError) -> Self {
        Self::Sink(DataSinkError::RecordStore(e))
    }
}

impl From<StoreError> for DaemonError {
    fn from(e: StoreError) -> Self {
        Self::Sink(DataSinkError::EpisodeStore(e))
    }
}

#[derive(Debug, Clone)]
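/// Configuration for [`LearningDaemon`].
///
/// A minimal sketch of typical usage, using only the builder methods
/// defined below (values are illustrative):
///
/// ```ignore
/// let config = DaemonConfig::new("my-scenario")
///     .data_dir("/tmp/learning")
///     .check_interval(Duration::from_secs(30))
///     .processor_mode(ProcessorMode::OfflineOnly)
///     .max_sessions(50)
///     .auto_apply(true);
/// ```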
pub struct DaemonConfig {
    /// Scenario name used to key stored sessions and models.
    pub scenario: String,
    /// Root directory for on-disk learning data.
    pub data_dir: PathBuf,
    /// How often the train trigger is evaluated.
    pub check_interval: Duration,
    /// Which processing stages run during a learning cycle.
    pub processor_mode: ProcessorMode,
    /// Maximum number of sessions the processor considers.
    pub max_sessions: usize,
    /// Whether trained models are applied automatically.
    pub auto_apply: bool,
    /// Optional LoRA training configuration; `None` disables LoRA training.
    pub lora_config: Option<LoraTrainerConfig>,
}

impl DaemonConfig {
    /// Creates a configuration with defaults for the given scenario.
    pub fn new(scenario: impl Into<String>) -> Self {
        Self {
            scenario: scenario.into(),
            data_dir: default_data_dir(),
            check_interval: Duration::from_secs(10),
            processor_mode: ProcessorMode::OfflineOnly,
            max_sessions: 20,
            auto_apply: false,
            lora_config: None,
        }
    }

    /// Sets the data directory.
    pub fn data_dir(mut self, path: impl Into<PathBuf>) -> Self {
        self.data_dir = path.into();
        self
    }

    /// Sets the trigger check interval.
    pub fn check_interval(mut self, interval: Duration) -> Self {
        self.check_interval = interval;
        self
    }

    /// Sets the processor mode.
    pub fn processor_mode(mut self, mode: ProcessorMode) -> Self {
        self.processor_mode = mode;
        self
    }

    /// Sets the maximum number of sessions to process.
    pub fn max_sessions(mut self, n: usize) -> Self {
        self.max_sessions = n;
        self
    }

    /// Enables or disables automatic model application.
    pub fn auto_apply(mut self, enabled: bool) -> Self {
        self.auto_apply = enabled;
        self
    }

    /// Enables LoRA training with the given configuration.
    pub fn with_lora(mut self, config: LoraTrainerConfig) -> Self {
        self.lora_config = Some(config);
        self
    }
}

/// Default data directory: the platform data dir (falling back to the
/// current directory), under `swarm-engine/learning`.
fn default_data_dir() -> PathBuf {
    dirs::data_dir()
        .unwrap_or_else(|| PathBuf::from("."))
        .join("swarm-engine")
        .join("learning")
}

#[derive(Debug, Clone, Default)]
/// Counters describing daemon activity since startup.
pub struct DaemonStats {
    /// Total records received.
    pub records_received: usize,
    /// Total episodes created by the sink.
    pub episodes_created: usize,
    /// Number of completed learning cycles.
    pub trainings_completed: usize,
    /// Number of trained models applied.
    pub models_applied: usize,
    /// Epoch milliseconds of the last training, if any.
    pub last_train_at: Option<u64>,
    /// Epoch milliseconds when the daemon started.
    pub started_at: u64,
}

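/// Long-running daemon that ingests records, fires train triggers, and runs
/// learning cycles.
///
/// A minimal lifecycle sketch (illustrative only; it assumes a Tokio runtime
/// and uses the `TriggerBuilder::never()` trigger from
/// `crate::learn::trigger`):
///
/// ```ignore
/// let config = DaemonConfig::new("my-scenario");
/// let mut daemon = LearningDaemon::new(config, TriggerBuilder::never())?;
///
/// let records = daemon.record_sender();
/// let shutdown = daemon.shutdown_sender();
///
/// tokio::spawn(async move {
///     // ... send Vec<Record> batches via `records`, then:
///     let _ = shutdown.send(()).await;
/// });
///
/// daemon.run().await?;
/// ```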
pub struct LearningDaemon {
    /// Daemon configuration.
    config: DaemonConfig,
    /// Sink that turns incoming records into episodes.
    sink: DataSink,
    /// Trigger that decides when a learning cycle should run.
    trigger: Arc<dyn TrainTrigger>,
    /// Processor that runs the learning pipeline.
    processor: Processor,
    /// Optional applier for trained models.
    applier: Option<Applier>,
    /// Store for learning snapshots and offline models.
    learning_store: LearningStore,
    /// Activity counters.
    stats: DaemonStats,
    /// Episode count at the time of the last training.
    last_train_count: usize,
    /// Receiving end of the record channel.
    record_rx: mpsc::Receiver<Vec<Record>>,
    /// Sending end handed out via [`Self::record_sender`].
    record_tx: mpsc::Sender<Vec<Record>>,
    /// Receiving end of the shutdown channel.
    shutdown_rx: mpsc::Receiver<()>,
    /// Sending end handed out via [`Self::shutdown_sender`].
    shutdown_tx: mpsc::Sender<()>,
}

impl LearningDaemon {
    /// Creates a daemon backed by in-memory stores.
    pub fn new(config: DaemonConfig, trigger: Arc<dyn TrainTrigger>) -> Result<Self, DaemonError> {
        let record_store: Arc<dyn RecordStore> = Arc::new(InMemoryRecordStore::new());
        let episode_store: Arc<dyn EpisodeStore> = Arc::new(InMemoryEpisodeStore::new());
        let learn_model: Arc<dyn LearnModel> = Arc::new(WorkerDecisionSequenceLearn::new());

        Self::with_stores(config, trigger, record_store, episode_store, learn_model)
    }

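    /// Creates a daemon backed by file stores under `config.data_dir`.
    ///
    /// A minimal sketch (illustrative; the path is hypothetical):
    ///
    /// ```ignore
    /// let config = DaemonConfig::new("my-scenario").data_dir("/tmp/learning");
    /// let daemon = LearningDaemon::with_file_stores(config, TriggerBuilder::default_watch())?;
    /// ```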
    pub fn with_file_stores(
        config: DaemonConfig,
        trigger: Arc<dyn TrainTrigger>,
    ) -> Result<Self, DaemonError> {
        std::fs::create_dir_all(&config.data_dir)?;

        let record_store: Arc<dyn RecordStore> =
            Arc::new(FileRecordStore::new(config.data_dir.join("records"))?);
        let episode_store: Arc<dyn EpisodeStore> =
            Arc::new(FileEpisodeStore::new(config.data_dir.join("episodes"))?);
        let learn_model: Arc<dyn LearnModel> = Arc::new(WorkerDecisionSequenceLearn::new());

        Self::with_stores(config, trigger, record_store, episode_store, learn_model)
    }

    /// Creates a daemon from explicitly provided stores and learn model.
    pub fn with_stores(
        config: DaemonConfig,
        trigger: Arc<dyn TrainTrigger>,
        record_store: Arc<dyn RecordStore>,
        episode_store: Arc<dyn EpisodeStore>,
        learn_model: Arc<dyn LearnModel>,
    ) -> Result<Self, DaemonError> {
        let sink = DataSink::new(
            record_store,
            Arc::clone(&episode_store),
            Arc::clone(&learn_model),
        );

        let processor_config = ProcessorConfig::new(&config.scenario)
            .mode(config.processor_mode)
            .max_sessions(config.max_sessions);

        let mut processor = Processor::new(processor_config);

        // The daemon and the processor each hold their own LearningStore handle.
        let learning_store = LearningStore::new(&config.data_dir)?;
        let learning_store_for_processor = LearningStore::new(&config.data_dir)?;
        processor = processor.with_learning_store(learning_store_for_processor);

        if let Some(lora_config) = &config.lora_config {
            let trainer = LoraTrainer::new(lora_config.clone(), episode_store);
            processor = processor
                .with_lora_trainer(trainer)
                .with_learn_model(learn_model);
        }

        let applier = if config.auto_apply {
            let applier_config = ApplierConfig::default().auto_apply();
            // No-op applicator by default; a custom one can be supplied via DaemonBuilder.
            let applicator: Arc<dyn ModelApplicator> = Arc::new(NoOpApplicator::new());
            Some(Applier::new(applier_config, applicator))
        } else {
            None
        };

        // Bounded channels for record ingestion and shutdown signaling.
        let (record_tx, record_rx) = mpsc::channel(1000);
        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);

        Ok(Self {
            config,
            sink,
            trigger,
            processor,
            applier,
            learning_store,
            stats: DaemonStats {
                started_at: epoch_millis(),
                ..Default::default()
            },
            last_train_count: 0,
            record_rx,
            record_tx,
            shutdown_rx,
            shutdown_tx,
        })
    }

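    /// Returns a clonable sender for feeding record batches to the daemon.
    ///
    /// A minimal sketch of producing records from another task
    /// (`collect_records` is a hypothetical helper, not part of this crate):
    ///
    /// ```ignore
    /// let tx = daemon.record_sender();
    /// tokio::spawn(async move {
    ///     let batch: Vec<Record> = collect_records(); // hypothetical helper
    ///     let _ = tx.send(batch).await;
    /// });
    /// ```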
    pub fn record_sender(&self) -> mpsc::Sender<Vec<Record>> {
        self.record_tx.clone()
    }

    /// Returns a sender that signals the daemon to shut down.
    pub fn shutdown_sender(&self) -> mpsc::Sender<()> {
        self.shutdown_tx.clone()
    }

    /// Returns the daemon configuration.
    pub fn config(&self) -> &DaemonConfig {
        &self.config
    }

    /// Returns activity counters.
    pub fn stats(&self) -> &DaemonStats {
        &self.stats
    }

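    /// Runs the daemon event loop until a shutdown signal arrives.
    ///
    /// Each iteration either ingests a record batch, evaluates the train
    /// trigger on a fixed interval, or drains the channel and exits on
    /// shutdown. A minimal sketch of driving the loop (illustrative; assumes
    /// a Tokio runtime):
    ///
    /// ```ignore
    /// let shutdown = daemon.shutdown_sender();
    /// tokio::spawn(async move {
    ///     tokio::time::sleep(Duration::from_secs(60)).await;
    ///     let _ = shutdown.send(()).await;
    /// });
    /// daemon.run().await?;
    /// ```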
    pub async fn run(&mut self) -> Result<(), DaemonError> {
        tracing::info!(
            scenario = %self.config.scenario,
            data_dir = %self.config.data_dir.display(),
            trigger = self.trigger.name(),
            "Learning daemon started"
        );

        let mut check_interval = interval(self.config.check_interval);

        loop {
            tokio::select! {
                // Shutdown requested: drain whatever is still queued, then exit.
                _ = self.shutdown_rx.recv() => {
                    tracing::info!("Shutdown signal received, draining remaining records...");

                    // Give in-flight sends a brief window to land in the channel.
                    tokio::time::sleep(Duration::from_millis(100)).await;

                    while let Ok(records) = self.record_rx.try_recv() {
                        if let Err(e) = self.handle_records(records).await {
                            tracing::warn!("Error processing records during shutdown: {}", e);
                        }
                    }

                    tracing::info!(
                        records_received = self.stats.records_received,
                        episodes_created = self.stats.episodes_created,
                        "Shutdown complete"
                    );
                    return Ok(());
                }

                // Ingest the next batch of records.
                Some(records) = self.record_rx.recv() => {
                    self.handle_records(records).await?;
                }

                // Periodically ask the trigger whether to train.
                _ = check_interval.tick() => {
                    self.check_and_train().await?;
                }
            }
        }
    }

    /// Ingests a batch of records into the sink, routing stats and
    /// dependency-graph records to the learning store first.
    async fn handle_records(&mut self, records: Vec<Record>) -> Result<(), DaemonError> {
        if records.is_empty() {
            return Ok(());
        }

        let count = records.len();

        // Persist learning-store records before the sink consumes the batch.
        for record in &records {
            match record {
                Record::LearnStats(stats_record) => {
                    self.save_stats_to_learning_store(stats_record);
                }
                Record::DependencyGraph(dep_graph_record) => {
                    self.save_dependency_graph_to_learning_store(dep_graph_record);
                }
                _ => {}
            }
        }

        let episode_ids = self.sink.ingest(records)?;

        self.stats.records_received += count;
        self.stats.episodes_created += episode_ids.len();

        tracing::debug!(
            records = count,
            episodes = episode_ids.len(),
            "Processed records"
        );

        Ok(())
    }

    /// Converts a [`LearnStatsRecord`] into a snapshot and saves it as a
    /// session in the learning store.
    fn save_stats_to_learning_store(&self, stats_record: &LearnStatsRecord) {
        use crate::learn::snapshot::{LearningSnapshot, SnapshotMetadata, SNAPSHOT_VERSION};
        use crate::learn::{EpisodeTransitions, NgramStats, SelectionPerformance};
        use crate::online_stats::ActionStats;
        use std::collections::HashMap;

        // The embedded stats JSON is best-effort; fall back to defaults on parse failure.
        let learn_stats: Option<LearnStats> = serde_json::from_str(&stats_record.stats_json).ok();

        let metadata = SnapshotMetadata {
            scenario_name: Some(stats_record.scenario.clone()),
            task_description: None,
            created_at: stats_record.timestamp_ms / 1000,
            session_count: 1,
            total_episodes: 1,
            total_actions: stats_record.total_actions as u32,
            phase: None,
            group_id: None,
        };

        let (
            episode_transitions,
            action_stats,
            ngram_stats,
            selection_performance,
            contextual_stats,
        ) = if let Some(ref stats) = learn_stats {
            let transitions = stats.episode_transitions.clone();
            let ngram = stats.ngram_stats.clone();
            let selection = stats.selection_performance.clone();

            // Rebuild contextual (prev-action, action) stats from the parsed record.
            let mut ctx_stats: HashMap<(String, String), ActionStats> = HashMap::new();
            for ((prev, action), ctx) in &stats.contextual_stats {
                ctx_stats.insert(
                    (prev.clone(), action.clone()),
                    ActionStats {
                        visits: ctx.visits,
                        successes: ctx.successes,
                        failures: ctx.failures,
                        ..Default::default()
                    },
                );
            }

            // The record carries no standalone per-action stats; start empty.
            let action_stats: HashMap<String, ActionStats> = HashMap::new();

            (transitions, action_stats, ngram, selection, ctx_stats)
        } else {
            (
                EpisodeTransitions::default(),
                HashMap::new(),
                NgramStats::default(),
                SelectionPerformance::default(),
                HashMap::new(),
            )
        };

        let snapshot = LearningSnapshot {
            version: SNAPSHOT_VERSION,
            metadata,
            episode_transitions,
            ngram_stats,
            selection_performance,
            contextual_stats,
            action_stats,
        };

        match self
            .learning_store
            .save_session(&stats_record.scenario, &snapshot)
        {
            Ok(session_id) => {
                tracing::info!(
                    scenario = %stats_record.scenario,
                    session_id = %session_id.0,
                    success = stats_record.is_success(),
                    "Saved session to LearningStore"
                );
            }
            Err(e) => {
                tracing::warn!(
                    scenario = %stats_record.scenario,
                    error = %e,
                    "Failed to save session to LearningStore"
                );
            }
        }
    }

    /// Merges a learned action order from a [`DependencyGraphRecord`] into
    /// the scenario's persisted OfflineModel.
    fn save_dependency_graph_to_learning_store(&self, record: &DependencyGraphRecord) {
        use crate::learn::{ActionOrderSource, LearnedActionOrder};

        // Hash over the union of both orderings identifies the action set.
        let all_actions: Vec<String> = record
            .discover_order
            .iter()
            .chain(record.not_discover_order.iter())
            .cloned()
            .collect();
        let action_set_hash = LearnedActionOrder::compute_hash(&all_actions);

        let action_order = LearnedActionOrder {
            discover: record.discover_order.clone(),
            not_discover: record.not_discover_order.clone(),
            action_set_hash,
            source: ActionOrderSource::Llm,
            lora: None,
            validated_accuracy: None,
        };

        // Load the existing model, or start from defaults if none exists yet.
        let scenario = &self.config.scenario;
        let model_result = self.learning_store.load_offline_model(scenario);

        let updated_model = match model_result {
            Ok(mut model) => {
                model.action_order = Some(action_order.clone());
                model
            }
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                crate::learn::OfflineModel {
                    action_order: Some(action_order.clone()),
                    ..Default::default()
                }
            }
            Err(e) => {
                tracing::warn!(
                    scenario = %scenario,
                    error = %e,
                    "Failed to load OfflineModel for action_order update"
                );
                return;
            }
        };

        match self
            .learning_store
            .save_offline_model(scenario, &updated_model)
        {
            Ok(()) => {
                tracing::info!(
                    scenario = %scenario,
                    discover = ?action_order.discover,
                    not_discover = ?action_order.not_discover,
                    action_set_hash = action_order.action_set_hash,
                    "Saved action_order to OfflineModel"
                );
            }
            Err(e) => {
                tracing::warn!(
                    scenario = %scenario,
                    error = %e,
                    "Failed to save action_order to OfflineModel"
                );
            }
        }
    }

    /// Evaluates the train trigger and runs a learning cycle when it fires.
    async fn check_and_train(&mut self) -> Result<(), DaemonError> {
        let current_count = self.sink.episode_count();
        let ctx = TriggerContext::with_count(current_count)
            .last_train_at(self.stats.last_train_at.unwrap_or(0))
            .last_train_count(self.last_train_count);

        if !self.trigger.should_train(&ctx).unwrap_or(false) {
            return Ok(());
        }

        tracing::info!(
            episode_count = current_count,
            trigger = self.trigger.name(),
            "Trigger fired, starting learning"
        );

        let result = self
            .processor
            .run(self.sink.episode_store().as_ref())
            .await?;

        self.stats.trainings_completed += 1;
        self.stats.last_train_at = Some(epoch_millis());
        self.last_train_count = current_count;

        // Auto-apply the trained LoRA model when an applier is configured.
        if let Some(applier) = &mut self.applier {
            if let Some(model) = result.lora_model() {
                let apply_result = applier.apply(model).await?;
                if apply_result.is_applied() {
                    self.stats.models_applied += 1;
                }
            }
        }

        tracing::info!(
            trainings = self.stats.trainings_completed,
            models_applied = self.stats.models_applied,
            "Learning cycle completed"
        );

        Ok(())
    }

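    /// Runs a learning cycle immediately, bypassing the trigger.
    ///
    /// A minimal sketch (illustrative; assumes records have already been
    /// ingested and an applier is configured):
    ///
    /// ```ignore
    /// let result = daemon.train_now().await?;
    /// if let Some(model) = result.lora_model() {
    ///     daemon.apply_model(model).await?;
    /// }
    /// ```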
    pub async fn train_now(&mut self) -> Result<ProcessResult, DaemonError> {
        tracing::info!("Manual training triggered");

        let result = self
            .processor
            .run(self.sink.episode_store().as_ref())
            .await?;

        self.stats.trainings_completed += 1;
        self.stats.last_train_at = Some(epoch_millis());
        self.last_train_count = self.sink.episode_count();

        Ok(result)
    }

    /// Applies a trained model through the configured applier.
    ///
    /// Returns a [`DaemonError::Config`] error if no applier is configured.
    pub async fn apply_model(&mut self, model: &TrainedModel) -> Result<ApplyResult, DaemonError> {
        let applier = self
            .applier
            .as_mut()
            .ok_or_else(|| DaemonError::Config("Applier not configured".into()))?;

        let result = applier.apply_now(model).await?;
        if result.is_applied() {
            self.stats.models_applied += 1;
        }

        Ok(result)
    }
}

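/// Builder for [`LearningDaemon`] with pluggable stores, trigger, and
/// applicator.
///
/// A minimal sketch of typical usage (scenario name and path are
/// illustrative):
///
/// ```ignore
/// let daemon = DaemonBuilder::new("my-scenario")
///     .data_dir("/tmp/learning")
///     .trigger(Arc::new(AlwaysTrigger))
///     .processor_mode(ProcessorMode::OfflineOnly)
///     .build()?;
/// ```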
pub struct DaemonBuilder {
    config: DaemonConfig,
    trigger: Option<Arc<dyn TrainTrigger>>,
    record_store: Option<Arc<dyn RecordStore>>,
    episode_store: Option<Arc<dyn EpisodeStore>>,
    learn_model: Option<Arc<dyn LearnModel>>,
    applicator: Option<Arc<dyn ModelApplicator>>,
}

impl DaemonBuilder {
    /// Creates a builder for the given scenario.
    pub fn new(scenario: impl Into<String>) -> Self {
        Self {
            config: DaemonConfig::new(scenario),
            trigger: None,
            record_store: None,
            episode_store: None,
            learn_model: None,
            applicator: None,
        }
    }

    /// Sets the data directory.
    pub fn data_dir(mut self, path: impl Into<PathBuf>) -> Self {
        self.config.data_dir = path.into();
        self
    }

    /// Sets the train trigger.
    pub fn trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
        self.trigger = Some(trigger);
        self
    }

    /// Sets the processor mode.
    pub fn processor_mode(mut self, mode: ProcessorMode) -> Self {
        self.config.processor_mode = mode;
        self
    }

    /// Enables automatic model application.
    pub fn auto_apply(mut self) -> Self {
        self.config.auto_apply = true;
        self
    }

    /// Sets the record store.
    pub fn record_store(mut self, store: Arc<dyn RecordStore>) -> Self {
        self.record_store = Some(store);
        self
    }

    /// Sets the episode store.
    pub fn episode_store(mut self, store: Arc<dyn EpisodeStore>) -> Self {
        self.episode_store = Some(store);
        self
    }

    /// Sets the learn model.
    pub fn learn_model(mut self, model: Arc<dyn LearnModel>) -> Self {
        self.learn_model = Some(model);
        self
    }

    /// Sets the model applicator used when applying trained models.
    pub fn applicator(mut self, applicator: Arc<dyn ModelApplicator>) -> Self {
        self.applicator = Some(applicator);
        self
    }

    /// Enables LoRA training with the given configuration.
    pub fn with_lora(mut self, config: LoraTrainerConfig) -> Self {
        self.config.lora_config = Some(config);
        self
    }

    /// Builds the daemon, falling back to in-memory stores and the default
    /// watch trigger where nothing was provided.
    pub fn build(self) -> Result<LearningDaemon, DaemonError> {
        let trigger = self
            .trigger
            .unwrap_or_else(TriggerBuilder::default_watch);

        let record_store = self
            .record_store
            .unwrap_or_else(|| Arc::new(InMemoryRecordStore::new()));

        let episode_store = self
            .episode_store
            .unwrap_or_else(|| Arc::new(InMemoryEpisodeStore::new()));

        let learn_model = self
            .learn_model
            .unwrap_or_else(|| Arc::new(WorkerDecisionSequenceLearn::new()));

        let mut daemon = LearningDaemon::with_stores(
            self.config,
            trigger,
            record_store,
            episode_store,
            learn_model,
        )?;

        // Wire any custom applicator into an Applier so `applicator(...)`
        // takes effect instead of being silently dropped.
        if let Some(applicator) = self.applicator {
            let applier_config = if daemon.config.auto_apply {
                ApplierConfig::default().auto_apply()
            } else {
                ApplierConfig::default()
            };
            daemon.applier = Some(Applier::new(applier_config, applicator));
        }

        Ok(daemon)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionContext, ActionEventBuilder, ActionEventResult};
    use crate::learn::trigger::AlwaysTrigger;
    use crate::types::WorkerId;

    fn make_test_records(count: usize) -> Vec<Record> {
        (0..count)
            .map(|i| {
                let event = ActionEventBuilder::new(i as u64, WorkerId(0), format!("Action{}", i))
                    .result(ActionEventResult::success())
                    .duration(std::time::Duration::from_millis(10))
                    .context(ActionContext::new())
                    .build();
                Record::from(&event)
            })
            .collect()
    }

    #[test]
    fn test_daemon_config_builder() {
        let config = DaemonConfig::new("test")
            .data_dir("/tmp/test")
            .check_interval(Duration::from_secs(30))
            .processor_mode(ProcessorMode::Full)
            .auto_apply(true);

        assert_eq!(config.scenario, "test");
        assert_eq!(config.data_dir, PathBuf::from("/tmp/test"));
        assert_eq!(config.check_interval, Duration::from_secs(30));
        assert_eq!(config.processor_mode, ProcessorMode::Full);
        assert!(config.auto_apply);
    }

    #[tokio::test]
    async fn test_daemon_creation() {
        let config = DaemonConfig::new("test");
        let trigger = TriggerBuilder::never();

        let daemon = LearningDaemon::new(config, trigger).unwrap();
        assert_eq!(daemon.config().scenario, "test");
        assert_eq!(daemon.stats().records_received, 0);
    }

    #[tokio::test]
    async fn test_daemon_record_ingestion() {
        let config = DaemonConfig::new("test");
        let trigger = TriggerBuilder::never();
        let mut daemon = LearningDaemon::new(config, trigger).unwrap();
        let sender = daemon.record_sender();

        // Queue a batch on the channel; it stays queued because run() is not executing.
        let records = make_test_records(5);
        sender.send(records).await.unwrap();

        // Process a batch directly; only these records are counted.
        daemon.handle_records(make_test_records(3)).await.unwrap();

        assert_eq!(daemon.stats().records_received, 3);
    }

    #[tokio::test]
    async fn test_daemon_builder() {
        let daemon = DaemonBuilder::new("test-scenario")
            .data_dir("/tmp/test")
            .trigger(Arc::new(AlwaysTrigger))
            .processor_mode(ProcessorMode::OfflineOnly)
            .build()
            .unwrap();

        assert_eq!(daemon.config().scenario, "test-scenario");
    }
}