1use std::path::PathBuf;
50use std::sync::Arc;
51use std::time::Duration;
52
53use tokio::runtime::Handle;
54use tokio::sync::mpsc;
55use tokio::task::JoinHandle;
56
57use crate::agent::{BatchInvoker, ManagerAgent, WorkerAgent};
58use crate::error::SwarmError;
59use crate::events::{ActionEventPublisher, LearningEventChannel, LifecycleHook, TraceSubscriber};
60use crate::exploration::{DependencyGraph, DependencyGraphProvider, NodeRules, OperatorProvider};
61use crate::extensions::Extensions;
62use crate::orchestrator::{Orchestrator, OrchestratorBuilder, SwarmConfig, SwarmResult};
63use crate::types::SwarmTask;
64
65use super::daemon::{
66 ActionEventSubscriber, DaemonConfig, DaemonError, EventSubscriberConfig, LearningDaemon,
67 LearningEventSubscriber,
68};
69use super::profile_adapter::profile_to_offline_model;
70use super::scenario_profile::ScenarioProfile;
71use super::snapshot::LearningStore;
72use super::trigger::{TrainTrigger, TriggerBuilder};
73use super::LearningSnapshot;
74use super::OfflineModel;
75
/// Join handle of the task spawned around `LearningDaemon::run`; the inner
/// `Result` is the daemon's own exit status, which `LearnableSwarm::shutdown`
/// inspects and logs.
type DaemonHandle = JoinHandle<Result<(), DaemonError>>;
82
/// Handles and channels produced by `LearnableSwarmBuilder::setup_learning_components`.
///
/// `Default` represents the "learning disabled" case used by `build`: no
/// daemon task, no subscriber tasks and no shutdown channel.
#[derive(Default)]
struct LearningSetupResult {
    /// Join handle of the spawned `LearningDaemon::run` task.
    daemon_handle: Option<DaemonHandle>,
    /// Join handles of the spawned learning event-subscriber tasks.
    subscriber_handles: Vec<JoinHandle<()>>,
    /// Sender from `LearningDaemon::shutdown_sender`; `LearnableSwarm::shutdown`
    /// sends `()` on it to ask the daemon to stop.
    shutdown_tx: Option<mpsc::Sender<()>>,
}
93
/// Learning-related settings of a [`LearnableSwarm`].
///
/// All fields are public; `new`, `with_learning` and `data_dir` provide a
/// chainable way to build one on top of [`Default`].
#[derive(Debug, Clone)]
pub struct LearnableSwarmConfig {
    /// Scenario key. Must be non-empty when `learning_enabled` is set (checked
    /// by `LearnableSwarmBuilder::build`); also the key used to load prior
    /// snapshots and offline models from a `LearningStore`.
    pub scenario: String,
    /// Directory handed to the learning daemon via `DaemonConfig::data_dir`.
    pub data_dir: PathBuf,
    /// When true, building the swarm enables the global learning event channel
    /// and spawns the learning daemon plus its event subscribers.
    pub learning_enabled: bool,
    /// Batch size given to the learning event subscribers
    /// (`EventSubscriberConfig::batch_size`).
    pub subscriber_batch_size: usize,
    /// Flush interval in milliseconds given to the learning event subscribers
    /// (`EventSubscriberConfig::flush_interval_ms`).
    pub subscriber_flush_interval_ms: u64,
    /// Interval passed to `DaemonConfig::check_interval`.
    pub daemon_check_interval: Duration,
}
114
115impl Default for LearnableSwarmConfig {
116 fn default() -> Self {
117 Self {
118 scenario: String::new(),
119 data_dir: default_learning_dir(),
120 learning_enabled: false,
121 subscriber_batch_size: 100,
122 subscriber_flush_interval_ms: 1000,
123 daemon_check_interval: Duration::from_secs(10),
124 }
125 }
126}
127
128impl LearnableSwarmConfig {
129 pub fn new(scenario: impl Into<String>) -> Self {
131 Self {
132 scenario: scenario.into(),
133 ..Default::default()
134 }
135 }
136
137 pub fn with_learning(mut self, enabled: bool) -> Self {
139 self.learning_enabled = enabled;
140 self
141 }
142
143 pub fn data_dir(mut self, path: impl Into<PathBuf>) -> Self {
145 self.data_dir = path.into();
146 self
147 }
148}
149
150fn default_learning_dir() -> PathBuf {
151 dirs::data_dir()
152 .unwrap_or_else(|| PathBuf::from("."))
153 .join("swarm-engine")
154 .join("learning")
155}
156
/// Builder for a [`LearnableSwarm`]: an [`Orchestrator`] plus optional
/// learning machinery (learning daemon, event subscribers, trace subscriber).
pub struct LearnableSwarmBuilder {
    /// Runtime used to spawn background tasks and passed to `OrchestratorBuilder::build`.
    runtime: Handle,
    /// Learning settings (scenario, data dir, subscriber/daemon tuning).
    config: LearnableSwarmConfig,
    /// Orchestrator config; `SwarmConfig::default()` is used when `None`.
    swarm_config: Option<SwarmConfig>,
    /// Worker agents forwarded to the orchestrator builder.
    workers: Vec<Box<dyn WorkerAgent>>,
    /// Manager agents forwarded to the orchestrator builder.
    managers: Vec<Box<dyn ManagerAgent>>,
    /// Optional batch invoker forwarded to the orchestrator builder.
    batch_invoker: Option<Box<dyn BatchInvoker>>,
    /// Optional dependency-graph provider forwarded to the orchestrator builder.
    dependency_provider: Option<Box<dyn DependencyGraphProvider>>,
    /// Optional operator provider forwarded to the orchestrator builder.
    operator_provider: Option<Box<dyn OperatorProvider<NodeRules>>>,
    /// Base extensions; the dependency graph and prior snapshot are inserted at build time.
    extensions: Option<Extensions>,
    /// Dependency graph inserted into the extensions at build time.
    dependency_graph: Option<DependencyGraph>,
    /// Offline model handed to the orchestrator (set explicitly, from a
    /// scenario profile, or loaded from a learning store).
    offline_model: Option<OfflineModel>,
    /// Prior learning snapshot inserted into the extensions at build time.
    prior_snapshot: Option<LearningSnapshot>,
    /// Learning store carried through to the built swarm.
    learning_store: Option<LearningStore>,
    /// Daemon train trigger; `TriggerBuilder::never()` is used when `None`.
    train_trigger: Option<Arc<dyn TrainTrigger>>,
    /// Optional lifecycle hook forwarded to the orchestrator builder.
    lifecycle_hook: Option<Box<dyn LifecycleHook>>,
    /// When true, `OrchestratorBuilder::with_exploration` is called at build time.
    enable_exploration: bool,
    /// Error recorded by a chainable setter, returned from `build`.
    deferred_error: Option<SwarmError>,
    /// Optional subscriber fed every action event on a spawned task.
    trace_subscriber: Option<Arc<dyn TraceSubscriber>>,
}
202
203impl LearnableSwarmBuilder {
204 pub fn new(runtime: Handle) -> Self {
206 Self {
207 runtime,
208 config: LearnableSwarmConfig::default(),
209 swarm_config: None,
210 workers: Vec::new(),
211 managers: Vec::new(),
212 batch_invoker: None,
213 dependency_provider: None,
214 operator_provider: None,
215 extensions: None,
216 dependency_graph: None,
217 offline_model: None,
218 prior_snapshot: None,
219 learning_store: None,
220 train_trigger: None,
221 lifecycle_hook: None,
222 enable_exploration: false,
223 deferred_error: None,
224 trace_subscriber: None,
225 }
226 }
227
228 pub fn scenario(mut self, name: impl Into<String>) -> Self {
230 self.config.scenario = name.into();
231 self
232 }
233
234 pub fn with_learning(mut self, enabled: bool) -> Self {
238 self.config.learning_enabled = enabled;
239 self
240 }
241
242 pub fn data_dir(mut self, path: impl Into<PathBuf>) -> Self {
244 self.config.data_dir = path.into();
245 self
246 }
247
248 pub fn swarm_config(mut self, config: SwarmConfig) -> Self {
250 self.swarm_config = Some(config);
251 self
252 }
253
254 pub fn add_worker(mut self, worker: Box<dyn WorkerAgent>) -> Self {
256 self.workers.push(worker);
257 self
258 }
259
260 pub fn workers(mut self, workers: Vec<Box<dyn WorkerAgent>>) -> Self {
262 self.workers = workers;
263 self
264 }
265
266 pub fn add_manager(mut self, manager: Box<dyn ManagerAgent>) -> Self {
268 self.managers.push(manager);
269 self
270 }
271
272 pub fn managers(mut self, managers: Vec<Box<dyn ManagerAgent>>) -> Self {
274 self.managers = managers;
275 self
276 }
277
278 pub fn batch_invoker(mut self, invoker: Box<dyn BatchInvoker>) -> Self {
280 self.batch_invoker = Some(invoker);
281 self
282 }
283
284 pub fn dependency_provider(mut self, provider: Box<dyn DependencyGraphProvider>) -> Self {
286 self.dependency_provider = Some(provider);
287 self
288 }
289
290 pub fn operator_provider(mut self, provider: Box<dyn OperatorProvider<NodeRules>>) -> Self {
292 self.operator_provider = Some(provider);
293 self
294 }
295
296 pub fn extensions(mut self, extensions: Extensions) -> Self {
298 self.extensions = Some(extensions);
299 self
300 }
301
302 pub fn dependency_graph(mut self, graph: DependencyGraph) -> Self {
304 self.dependency_graph = Some(graph);
305 self
306 }
307
308 pub fn offline_model(mut self, model: OfflineModel) -> Self {
310 self.offline_model = Some(model);
311 self
312 }
313
314 pub fn with_scenario_profile(mut self, profile: &ScenarioProfile) -> Self {
330 let model = profile_to_offline_model(profile);
331 self.offline_model = Some(model);
332 if self.config.scenario.is_empty() {
334 self.config.scenario = profile.id.0.clone();
335 }
336 self
337 }
338
339 pub fn offline_model_ref(&self) -> Option<&OfflineModel> {
341 self.offline_model.as_ref()
342 }
343
344 pub fn prior_snapshot(mut self, snapshot: LearningSnapshot) -> Self {
346 self.prior_snapshot = Some(snapshot);
347 self
348 }
349
350 pub fn learning_store(mut self, store: LearningStore) -> Self {
352 self.learning_store = Some(store);
353 self
354 }
355
356 pub fn with_learning_store(mut self, store: LearningStore) -> Self {
365 if let Err(e) = self.load_from_store(&store) {
366 self.deferred_error = Some(e);
367 }
368 self.config.data_dir = store.storage().base_dir().to_path_buf();
369 self.config.learning_enabled = true;
370 self.learning_store = Some(store);
371 self
372 }
373
374 pub fn train_trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
376 self.train_trigger = Some(trigger);
377 self
378 }
379
380 pub fn lifecycle_hook(mut self, hook: Box<dyn LifecycleHook>) -> Self {
382 self.lifecycle_hook = Some(hook);
383 self
384 }
385
386 pub fn enable_exploration(mut self, enabled: bool) -> Self {
388 self.enable_exploration = enabled;
389 self
390 }
391
392 pub fn with_trace_subscriber(mut self, subscriber: Arc<dyn TraceSubscriber>) -> Self {
408 self.trace_subscriber = Some(subscriber);
409 self
410 }
411
412 pub fn with_learning_store_path(mut self, path: impl AsRef<std::path::Path>) -> Self {
419 match LearningStore::new(&path) {
420 Ok(store) => {
421 if let Err(e) = self.load_from_store(&store) {
422 self.deferred_error = Some(e);
423 }
424 self.config.learning_enabled = true;
425 self.learning_store = Some(store);
426 }
427 Err(e) => {
428 self.deferred_error = Some(SwarmError::Config {
430 message: format!(
431 "Failed to create LearningStore at '{}': {}",
432 path.as_ref().display(),
433 e
434 ),
435 });
436 }
437 }
438 self
439 }
440
441 fn load_from_store(&mut self, store: &LearningStore) -> Result<(), SwarmError> {
447 let scenario_key = &self.config.scenario;
448
449 match store.load_scenario(scenario_key) {
451 Ok(snapshot) => {
452 self.prior_snapshot = Some(snapshot);
453 }
454 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
455 tracing::debug!(scenario = %scenario_key, "No prior snapshot found (first run)");
457 }
458 Err(e) => {
459 return Err(SwarmError::Config {
461 message: format!(
462 "Failed to load prior snapshot for '{}': {}",
463 scenario_key, e
464 ),
465 });
466 }
467 }
468
469 match store.load_offline_model(scenario_key) {
471 Ok(model) => {
472 tracing::debug!(
473 ucb1_c = model.parameters.ucb1_c,
474 strategy = %model.strategy_config.initial_strategy,
475 action_order = model.action_order.is_some(),
476 "Offline model loaded"
477 );
478 self.offline_model = Some(model);
479 }
480 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
481 tracing::debug!(scenario = %scenario_key, "No offline model found (first run)");
483 }
484 Err(e) => {
485 return Err(SwarmError::Config {
487 message: format!("Failed to load offline model for '{}': {}", scenario_key, e),
488 });
489 }
490 }
491
492 Ok(())
493 }
494
495 pub fn build(mut self) -> Result<LearnableSwarm, SwarmError> {
497 if let Some(err) = self.deferred_error.take() {
502 return Err(err);
503 }
504
505 if self.config.learning_enabled && self.config.scenario.is_empty() {
506 return Err(SwarmError::Config {
507 message: "scenario is required when learning is enabled".into(),
508 });
509 }
510
511 if self.config.learning_enabled {
513 LearningEventChannel::global().enable();
514 }
515
516 let swarm_config = self.swarm_config.take().unwrap_or_default();
517
518 let (action_publisher, _initial_rx) = ActionEventPublisher::new(1024);
523
524 let mut subscriber_handles_trace = Vec::new();
526 if let Some(trace_subscriber) = self.trace_subscriber.take() {
527 let rx = action_publisher.subscribe();
528 let handle = self.runtime.spawn(async move {
529 run_trace_subscriber(rx, trace_subscriber).await;
530 });
531 subscriber_handles_trace.push(handle);
532 }
533
534 let LearningSetupResult {
535 daemon_handle,
536 mut subscriber_handles,
537 shutdown_tx,
538 } = if self.config.learning_enabled {
539 self.setup_learning_components(&action_publisher)?
540 } else {
541 LearningSetupResult::default()
542 };
543
544 subscriber_handles.extend(subscriber_handles_trace);
546
547 let extensions = self.build_extensions();
553
554 let mut orch_builder = OrchestratorBuilder::new()
555 .config(swarm_config)
556 .extensions(extensions);
557
558 for worker in self.workers {
560 orch_builder = orch_builder.add_worker_boxed(worker);
561 }
562
563 for manager in self.managers {
565 orch_builder = orch_builder.add_manager_boxed(manager);
566 }
567
568 if let Some(invoker) = self.batch_invoker {
570 orch_builder = orch_builder.batch_invoker_boxed(invoker);
571 }
572
573 if let Some(provider) = self.dependency_provider {
575 orch_builder = orch_builder.dependency_provider_boxed(provider);
576 }
577
578 if let Some(provider) = self.operator_provider {
580 orch_builder = orch_builder.operator_provider_boxed(provider);
581 }
582
583 if self.enable_exploration {
585 orch_builder = orch_builder.with_exploration();
586 }
587
588 if let Some(ref model) = self.offline_model {
590 orch_builder = orch_builder.with_offline_model(model.clone());
591 }
592
593 if let Some(hook) = self.lifecycle_hook {
595 orch_builder = orch_builder.lifecycle_hook(hook);
596 }
597
598 orch_builder = orch_builder.action_collector(action_publisher);
600
601 let orchestrator = orch_builder.build(self.runtime.clone());
603
604 Ok(LearnableSwarm {
605 orchestrator,
606 runtime: self.runtime,
607 config: self.config,
608 learning_store: self.learning_store,
609 offline_model: self.offline_model,
610 daemon_handle,
611 subscriber_handles,
612 shutdown_tx,
613 })
614 }
615
616 fn build_extensions(&mut self) -> Extensions {
618 let mut ext = self.extensions.take().unwrap_or_default();
619
620 if let Some(graph) = self.dependency_graph.take() {
621 ext.insert(graph);
622 }
623 if let Some(snapshot) = self.prior_snapshot.take() {
624 ext.insert(snapshot);
625 }
626
627 ext
628 }
629
630 fn setup_learning_components(
632 &self,
633 action_publisher: &ActionEventPublisher,
634 ) -> Result<LearningSetupResult, SwarmError> {
635 let daemon_config = DaemonConfig::new(&self.config.scenario)
637 .data_dir(&self.config.data_dir)
638 .check_interval(self.config.daemon_check_interval);
639
640 let trigger = self
641 .train_trigger
642 .clone()
643 .unwrap_or_else(|| TriggerBuilder::never());
644
645 let mut daemon =
646 LearningDaemon::new(daemon_config, trigger).map_err(|e| SwarmError::Config {
647 message: format!("Failed to create LearningDaemon: {}", e),
648 })?;
649
650 let record_tx = daemon.record_sender();
651 let shutdown_tx = daemon.shutdown_sender();
652
653 let sub_config = EventSubscriberConfig::new()
655 .batch_size(self.config.subscriber_batch_size)
656 .flush_interval_ms(self.config.subscriber_flush_interval_ms);
657
658 let mut subscriber_handles = Vec::new();
659
660 let action_sub = ActionEventSubscriber::with_config(
662 action_publisher.subscribe(),
663 record_tx.clone(),
664 sub_config.clone(),
665 );
666 let action_handle = self.runtime.spawn(async move {
667 action_sub.run().await;
668 });
669 subscriber_handles.push(action_handle);
670
671 let learning_channel = LearningEventChannel::global();
673 let learning_sub = LearningEventSubscriber::with_config(
674 learning_channel.subscribe(),
675 record_tx,
676 sub_config,
677 );
678 let learning_handle = self.runtime.spawn(async move {
679 learning_sub.run().await;
680 });
681 subscriber_handles.push(learning_handle);
682
683 let daemon_handle = self.runtime.spawn(async move { daemon.run().await });
685
686 Ok(LearningSetupResult {
687 daemon_handle: Some(daemon_handle),
688 subscriber_handles,
689 shutdown_tx: Some(shutdown_tx),
690 })
691 }
692}
693
/// A built swarm: the [`Orchestrator`] plus the background learning tasks
/// created by `LearnableSwarmBuilder::build`.
///
/// Call `shutdown` / `shutdown_blocking` to stop the learning daemon and
/// wait for the spawned tasks.
pub struct LearnableSwarm {
    /// The orchestrator that executes tasks.
    orchestrator: Orchestrator,
    /// Runtime handle from the builder; used by `shutdown_blocking`.
    runtime: Handle,
    /// Learning configuration the swarm was built with.
    config: LearnableSwarmConfig,
    /// Learning store attached via the builder, if any.
    learning_store: Option<LearningStore>,
    /// Offline model handed to the orchestrator at build time, if any.
    offline_model: Option<OfflineModel>,
    /// Join handle of the learning daemon task (only when learning is enabled).
    daemon_handle: Option<DaemonHandle>,
    /// Join handles of the learning subscribers and the trace subscriber task.
    subscriber_handles: Vec<JoinHandle<()>>,
    /// Sender used to ask the learning daemon to stop.
    shutdown_tx: Option<mpsc::Sender<()>>,
}
719
720impl LearnableSwarm {
721 pub fn run_task(&mut self, task: SwarmTask) -> Result<SwarmResult, SwarmError> {
723 self.orchestrator.run_task(task)
724 }
725
726 pub fn run(&mut self) -> SwarmResult {
728 self.orchestrator.run()
729 }
730
731 pub fn orchestrator(&self) -> &Orchestrator {
733 &self.orchestrator
734 }
735
736 pub fn orchestrator_mut(&mut self) -> &mut Orchestrator {
738 &mut self.orchestrator
739 }
740
741 pub fn dependency_graph(&self) -> Option<&DependencyGraph> {
743 self.orchestrator.dependency_graph()
744 }
745
746 pub fn config(&self) -> &LearnableSwarmConfig {
748 &self.config
749 }
750
751 pub fn learning_store(&self) -> Option<&LearningStore> {
753 self.learning_store.as_ref()
754 }
755
756 pub fn offline_model(&self) -> Option<&OfflineModel> {
758 self.offline_model.as_ref()
759 }
760
761 pub fn is_learning_enabled(&self) -> bool {
763 self.config.learning_enabled
764 }
765
766 fn emit_stats_snapshot(&self) {
770 use crate::events::{LearnStatsOutcome, LearningEvent};
771 use crate::util::epoch_millis;
772
773 let state = self.orchestrator.state();
774 let tick = state.shared.tick;
775 let total_actions = state.shared.stats.total_visits() as u64;
776
777 let stats_json = if let Some(provider) = self.orchestrator.learned_provider() {
779 provider
780 .stats()
781 .map(|stats| serde_json::to_string(stats).unwrap_or_default())
782 .unwrap_or_default()
783 } else {
784 String::new()
785 };
786
787 let session_id = format!("{}", epoch_millis());
789
790 let outcome = if state.shared.environment_done {
792 LearnStatsOutcome::Success { score: 1.0 }
793 } else {
794 LearnStatsOutcome::Timeout {
795 partial_score: None,
796 }
797 };
798
799 let event = LearningEvent::learn_stats_snapshot(&self.config.scenario)
801 .session_id(session_id)
802 .stats_json(stats_json)
803 .total_ticks(tick)
804 .total_actions(total_actions);
805
806 let event = match outcome {
807 LearnStatsOutcome::Success { score } => event.success(score),
808 LearnStatsOutcome::Timeout { partial_score } => event.timeout(partial_score),
809 LearnStatsOutcome::Failure { reason } => event.failure(reason),
810 };
811
812 LearningEventChannel::global().emit(event.build());
813
814 tracing::debug!(
815 scenario = %self.config.scenario,
816 tick = tick,
817 total_actions = total_actions,
818 "LearnStatsSnapshot emitted"
819 );
820 }
821
822 pub fn take_shutdown_tx(&mut self) -> Option<mpsc::Sender<()>> {
827 self.shutdown_tx.take()
828 }
829
830 pub async fn shutdown(self) {
835 if self.config.learning_enabled {
837 self.emit_stats_snapshot();
838 }
839
840 if let Some(tx) = self.shutdown_tx {
842 let _ = tx.send(()).await;
843 }
844
845 if let Some(handle) = self.daemon_handle {
847 match handle.await {
848 Ok(Ok(())) => {
849 tracing::debug!("LearningDaemon shutdown completed");
850 }
851 Ok(Err(e)) => {
852 tracing::warn!("LearningDaemon error on shutdown: {}", e);
853 }
854 Err(e) => {
855 tracing::warn!("LearningDaemon join error: {}", e);
856 }
857 }
858 }
859
860 for handle in self.subscriber_handles {
862 let _ = handle.await;
863 }
864
865 tracing::debug!("LearnableSwarm shutdown completed");
866 }
867
868 pub fn shutdown_blocking(self) {
870 let runtime = self.runtime.clone();
871 runtime.block_on(self.shutdown());
872 }
873}
874
875async fn run_trace_subscriber(
881 mut rx: tokio::sync::broadcast::Receiver<crate::events::ActionEvent>,
882 subscriber: Arc<dyn TraceSubscriber>,
883) {
884 while let Ok(event) = rx.recv().await {
885 subscriber.on_event(&event);
886 }
887 subscriber.finish();
888}
889
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::GenericWorker;

    /// Single-threaded runtime; the builder only needs its handle.
    fn make_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    /// A boxed generic worker with id 0, shared by several tests.
    fn boxed_worker() -> Box<dyn WorkerAgent> {
        Box::new(GenericWorker::new(0))
    }

    #[test]
    fn test_config_default() {
        let LearnableSwarmConfig {
            learning_enabled,
            scenario,
            ..
        } = LearnableSwarmConfig::default();
        assert!(!learning_enabled);
        assert!(scenario.is_empty());
    }

    #[test]
    fn test_config_builder() {
        let cfg = LearnableSwarmConfig::new("test-scenario")
            .with_learning(true)
            .data_dir("/tmp/test");

        assert_eq!(cfg.scenario, "test-scenario");
        assert!(cfg.learning_enabled);
        assert_eq!(cfg.data_dir, PathBuf::from("/tmp/test"));
    }

    #[test]
    fn test_builder_basic() {
        let runtime = make_test_runtime();
        let handle = runtime.handle().clone();

        let b = LearnableSwarmBuilder::new(handle)
            .scenario("test")
            .add_worker(boxed_worker());

        assert_eq!(b.config.scenario, "test");
        assert_eq!(b.workers.len(), 1);
    }

    #[test]
    fn test_builder_with_learning() {
        let runtime = make_test_runtime();
        let handle = runtime.handle().clone();

        let b = LearnableSwarmBuilder::new(handle)
            .scenario("test")
            .with_learning(true)
            .add_worker(boxed_worker());

        assert!(b.config.learning_enabled);
    }

    #[test]
    fn test_builder_learning_without_scenario_fails() {
        let runtime = make_test_runtime();
        let handle = runtime.handle().clone();

        let outcome = LearnableSwarmBuilder::new(handle)
            .with_learning(true)
            .add_worker(boxed_worker())
            .build();

        match outcome {
            Ok(_) => panic!("build must fail when learning is enabled without a scenario"),
            Err(err) => assert!(err.to_string().contains("scenario is required")),
        }
    }

    #[test]
    fn test_builder_learning_disabled_without_scenario_ok() {
        let runtime = make_test_runtime();
        let handle = runtime.handle().clone();

        let outcome = LearnableSwarmBuilder::new(handle)
            .add_worker(boxed_worker())
            .build();

        assert!(outcome.is_ok());
    }

    #[test]
    fn test_builder_with_scenario_profile() {
        use crate::learn::learned_component::LearnedExploration;
        use crate::learn::scenario_profile::{ScenarioProfile, ScenarioSource};

        let runtime = make_test_runtime();

        let mut profile =
            ScenarioProfile::new("test-profile", ScenarioSource::from_path("/test.toml"));
        profile.exploration = Some(LearnedExploration::new(2.5, 0.4, 1.2));

        let b = LearnableSwarmBuilder::new(runtime.handle().clone())
            .with_scenario_profile(&profile)
            .add_worker(boxed_worker());

        // Scenario is adopted from the profile id when none was set.
        assert_eq!(b.config.scenario, "test-profile");

        let model = b
            .offline_model
            .as_ref()
            .expect("offline model should be derived from the profile");
        assert_eq!(model.parameters.ucb1_c, 2.5);
    }
}