1use std::path::PathBuf;
50use std::sync::Arc;
51use std::time::Duration;
52
53use tokio::runtime::Handle;
54use tokio::sync::mpsc;
55use tokio::task::JoinHandle;
56
57use crate::agent::{BatchInvoker, ManagerAgent, WorkerAgent};
58use crate::error::SwarmError;
59use crate::events::{ActionEventPublisher, LearningEventChannel, LifecycleHook, TraceSubscriber};
60use crate::exploration::{DependencyGraph, DependencyGraphProvider, NodeRules, OperatorProvider};
61use crate::extensions::Extensions;
62use crate::orchestrator::{Orchestrator, OrchestratorBuilder, SwarmConfig, SwarmResult};
63use crate::types::SwarmTask;
64
65use super::daemon::{
66 ActionEventSubscriber, DaemonConfig, DaemonError, EventSubscriberConfig, LearningDaemon,
67 LearningEventSubscriber,
68};
69use super::profile_adapter::profile_to_offline_model;
70use super::scenario_profile::ScenarioProfile;
71use super::snapshot::LearningStore;
72use super::trigger::{TrainTrigger, TriggerBuilder};
73use super::LearningSnapshot;
74use super::OfflineModel;
75
76type DaemonHandle = JoinHandle<Result<(), DaemonError>>;
82
83#[derive(Default)]
85struct LearningSetupResult {
86 daemon_handle: Option<DaemonHandle>,
88 subscriber_handles: Vec<JoinHandle<()>>,
90 shutdown_tx: Option<mpsc::Sender<()>>,
92}
93
/// Settings that control the learning side of a `LearnableSwarm`.
#[derive(Debug, Clone)]
pub struct LearnableSwarmConfig {
    /// Scenario name. Must be non-empty when `learning_enabled` is true,
    /// otherwise `LearnableSwarmBuilder::build` returns an error.
    pub scenario: String,
    /// Directory handed to the learning daemon's configuration.
    pub data_dir: PathBuf,
    /// Whether the learning daemon and its event subscribers are started.
    pub learning_enabled: bool,
    /// Batch size passed to the event subscribers' configuration.
    pub subscriber_batch_size: usize,
    /// Flush interval, in milliseconds, passed to the event subscribers' configuration.
    pub subscriber_flush_interval_ms: u64,
    /// Check interval passed to the learning daemon's configuration.
    pub daemon_check_interval: Duration,
}
114
115impl Default for LearnableSwarmConfig {
116 fn default() -> Self {
117 Self {
118 scenario: String::new(),
119 data_dir: default_learning_dir(),
120 learning_enabled: false,
121 subscriber_batch_size: 100,
122 subscriber_flush_interval_ms: 100,
123 daemon_check_interval: Duration::from_secs(10),
124 }
125 }
126}
127
128impl LearnableSwarmConfig {
129 pub fn new(scenario: impl Into<String>) -> Self {
131 Self {
132 scenario: scenario.into(),
133 ..Default::default()
134 }
135 }
136
137 pub fn with_learning(mut self, enabled: bool) -> Self {
139 self.learning_enabled = enabled;
140 self
141 }
142
143 pub fn data_dir(mut self, path: impl Into<PathBuf>) -> Self {
145 self.data_dir = path.into();
146 self
147 }
148}
149
150fn default_learning_dir() -> PathBuf {
151 dirs::data_dir()
152 .unwrap_or_else(|| PathBuf::from("."))
153 .join("swarm-engine")
154 .join("learning")
155}
156
157pub struct LearnableSwarmBuilder {
165 runtime: Handle,
167 config: LearnableSwarmConfig,
169 swarm_config: Option<SwarmConfig>,
171 workers: Vec<Box<dyn WorkerAgent>>,
173 managers: Vec<Box<dyn ManagerAgent>>,
175 batch_invoker: Option<Box<dyn BatchInvoker>>,
177 dependency_provider: Option<Box<dyn DependencyGraphProvider>>,
179 operator_provider: Option<Box<dyn OperatorProvider<NodeRules>>>,
181 extensions: Option<Extensions>,
183 dependency_graph: Option<DependencyGraph>,
185 offline_model: Option<OfflineModel>,
187 prior_snapshot: Option<LearningSnapshot>,
189 learning_store: Option<LearningStore>,
191 train_trigger: Option<Arc<dyn TrainTrigger>>,
193 lifecycle_hook: Option<Box<dyn LifecycleHook>>,
195 enable_exploration: bool,
197 deferred_error: Option<SwarmError>,
199 trace_subscriber: Option<Arc<dyn TraceSubscriber>>,
201}
202
203impl LearnableSwarmBuilder {
204 pub fn new(runtime: Handle) -> Self {
206 Self {
207 runtime,
208 config: LearnableSwarmConfig::default(),
209 swarm_config: None,
210 workers: Vec::new(),
211 managers: Vec::new(),
212 batch_invoker: None,
213 dependency_provider: None,
214 operator_provider: None,
215 extensions: None,
216 dependency_graph: None,
217 offline_model: None,
218 prior_snapshot: None,
219 learning_store: None,
220 train_trigger: None,
221 lifecycle_hook: None,
222 enable_exploration: false,
223 deferred_error: None,
224 trace_subscriber: None,
225 }
226 }
227
228 pub fn scenario(mut self, name: impl Into<String>) -> Self {
230 self.config.scenario = name.into();
231 self
232 }
233
234 pub fn with_learning(mut self, enabled: bool) -> Self {
238 self.config.learning_enabled = enabled;
239 self
240 }
241
242 pub fn data_dir(mut self, path: impl Into<PathBuf>) -> Self {
244 self.config.data_dir = path.into();
245 self
246 }
247
248 pub fn swarm_config(mut self, config: SwarmConfig) -> Self {
250 self.swarm_config = Some(config);
251 self
252 }
253
254 pub fn add_worker(mut self, worker: Box<dyn WorkerAgent>) -> Self {
256 self.workers.push(worker);
257 self
258 }
259
260 pub fn workers(mut self, workers: Vec<Box<dyn WorkerAgent>>) -> Self {
262 self.workers = workers;
263 self
264 }
265
266 pub fn add_manager(mut self, manager: Box<dyn ManagerAgent>) -> Self {
268 self.managers.push(manager);
269 self
270 }
271
272 pub fn managers(mut self, managers: Vec<Box<dyn ManagerAgent>>) -> Self {
274 self.managers = managers;
275 self
276 }
277
278 pub fn batch_invoker(mut self, invoker: Box<dyn BatchInvoker>) -> Self {
280 self.batch_invoker = Some(invoker);
281 self
282 }
283
284 pub fn dependency_provider(mut self, provider: Box<dyn DependencyGraphProvider>) -> Self {
286 self.dependency_provider = Some(provider);
287 self
288 }
289
290 pub fn operator_provider(mut self, provider: Box<dyn OperatorProvider<NodeRules>>) -> Self {
292 self.operator_provider = Some(provider);
293 self
294 }
295
296 pub fn extensions(mut self, extensions: Extensions) -> Self {
298 self.extensions = Some(extensions);
299 self
300 }
301
302 pub fn dependency_graph(mut self, graph: DependencyGraph) -> Self {
304 self.dependency_graph = Some(graph);
305 self
306 }
307
308 pub fn offline_model(mut self, model: OfflineModel) -> Self {
310 self.offline_model = Some(model);
311 self
312 }
313
314 pub fn with_scenario_profile(mut self, profile: &ScenarioProfile) -> Self {
330 let model = profile_to_offline_model(profile);
331 self.offline_model = Some(model);
332 if self.config.scenario.is_empty() {
334 self.config.scenario = profile.id.0.clone();
335 }
336 self
337 }
338
339 pub fn offline_model_ref(&self) -> Option<&OfflineModel> {
341 self.offline_model.as_ref()
342 }
343
344 pub fn prior_snapshot(mut self, snapshot: LearningSnapshot) -> Self {
346 self.prior_snapshot = Some(snapshot);
347 self
348 }
349
350 pub fn learning_store(mut self, store: LearningStore) -> Self {
352 self.learning_store = Some(store);
353 self
354 }
355
356 pub fn with_learning_store(mut self, store: LearningStore) -> Self {
365 if let Err(e) = self.load_from_store(&store) {
366 self.deferred_error = Some(e);
367 }
368 self.config.data_dir = store.storage().base_dir().to_path_buf();
369 self.config.learning_enabled = true;
370 self.learning_store = Some(store);
371 self
372 }
373
374 pub fn train_trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
376 self.train_trigger = Some(trigger);
377 self
378 }
379
380 pub fn lifecycle_hook(mut self, hook: Box<dyn LifecycleHook>) -> Self {
382 self.lifecycle_hook = Some(hook);
383 self
384 }
385
386 pub fn enable_exploration(mut self, enabled: bool) -> Self {
388 self.enable_exploration = enabled;
389 self
390 }
391
392 pub fn with_trace_subscriber(mut self, subscriber: Arc<dyn TraceSubscriber>) -> Self {
408 self.trace_subscriber = Some(subscriber);
409 self
410 }
411
412 pub fn with_learning_store_path(mut self, path: impl AsRef<std::path::Path>) -> Self {
419 match LearningStore::new(&path) {
420 Ok(store) => {
421 if let Err(e) = self.load_from_store(&store) {
422 self.deferred_error = Some(e);
423 }
424 self.config.learning_enabled = true;
425 self.learning_store = Some(store);
426 }
427 Err(e) => {
428 self.deferred_error = Some(SwarmError::Config {
430 message: format!(
431 "Failed to create LearningStore at '{}': {}",
432 path.as_ref().display(),
433 e
434 ),
435 });
436 }
437 }
438 self
439 }
440
441 fn load_from_store(&mut self, store: &LearningStore) -> Result<(), SwarmError> {
447 let scenario_key = &self.config.scenario;
448
449 match store.load_scenario(scenario_key) {
451 Ok(snapshot) => {
452 self.prior_snapshot = Some(snapshot);
453 }
454 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
455 tracing::debug!(scenario = %scenario_key, "No prior snapshot found (first run)");
457 }
458 Err(e) => {
459 return Err(SwarmError::Config {
461 message: format!(
462 "Failed to load prior snapshot for '{}': {}",
463 scenario_key, e
464 ),
465 });
466 }
467 }
468
469 match store.load_offline_model(scenario_key) {
471 Ok(model) => {
472 tracing::debug!(
473 ucb1_c = model.parameters.ucb1_c,
474 strategy = %model.strategy_config.initial_strategy,
475 action_order = model.action_order.is_some(),
476 "Offline model loaded"
477 );
478 self.offline_model = Some(model);
479 }
480 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
481 tracing::debug!(scenario = %scenario_key, "No offline model found (first run)");
483 }
484 Err(e) => {
485 return Err(SwarmError::Config {
487 message: format!("Failed to load offline model for '{}': {}", scenario_key, e),
488 });
489 }
490 }
491
492 Ok(())
493 }
494
495 pub fn build(mut self) -> Result<LearnableSwarm, SwarmError> {
497 if let Some(err) = self.deferred_error.take() {
502 return Err(err);
503 }
504
505 if self.config.learning_enabled && self.config.scenario.is_empty() {
506 return Err(SwarmError::Config {
507 message: "scenario is required when learning is enabled".into(),
508 });
509 }
510
511 if self.config.learning_enabled {
513 LearningEventChannel::global().enable();
514 }
515
516 let swarm_config = self.swarm_config.take().unwrap_or_default();
517
518 let (action_publisher, _initial_rx) = ActionEventPublisher::new(1024);
523
524 let mut subscriber_handles_trace = Vec::new();
526 if let Some(trace_subscriber) = self.trace_subscriber.take() {
527 let rx = action_publisher.subscribe();
528 let handle = self.runtime.spawn(async move {
529 run_trace_subscriber(rx, trace_subscriber).await;
530 });
531 subscriber_handles_trace.push(handle);
532 }
533
534 let LearningSetupResult {
535 daemon_handle,
536 mut subscriber_handles,
537 shutdown_tx,
538 } = if self.config.learning_enabled {
539 self.setup_learning_components(&action_publisher)?
540 } else {
541 LearningSetupResult::default()
542 };
543
544 subscriber_handles.extend(subscriber_handles_trace);
546
547 let extensions = self.build_extensions();
553
554 let mut orch_builder = OrchestratorBuilder::new()
555 .config(swarm_config)
556 .extensions(extensions);
557
558 for worker in self.workers {
560 orch_builder = orch_builder.add_worker_boxed(worker);
561 }
562
563 for manager in self.managers {
565 orch_builder = orch_builder.add_manager_boxed(manager);
566 }
567
568 if let Some(invoker) = self.batch_invoker {
570 orch_builder = orch_builder.batch_invoker_boxed(invoker);
571 }
572
573 if let Some(provider) = self.dependency_provider {
575 orch_builder = orch_builder.dependency_provider_boxed(provider);
576 }
577
578 if let Some(provider) = self.operator_provider {
580 orch_builder = orch_builder.operator_provider_boxed(provider);
581 }
582
583 if self.enable_exploration {
585 orch_builder = orch_builder.with_exploration();
586 }
587
588 if let Some(ref model) = self.offline_model {
590 orch_builder = orch_builder.with_offline_model(model.clone());
591 }
592
593 if let Some(hook) = self.lifecycle_hook {
595 orch_builder = orch_builder.lifecycle_hook(hook);
596 }
597
598 orch_builder = orch_builder.action_collector(action_publisher);
600
601 let orchestrator = orch_builder.build(self.runtime.clone());
603
604 Ok(LearnableSwarm {
605 orchestrator,
606 runtime: self.runtime,
607 config: self.config,
608 learning_store: self.learning_store,
609 offline_model: self.offline_model,
610 daemon_handle,
611 subscriber_handles,
612 shutdown_tx,
613 })
614 }
615
616 fn build_extensions(&mut self) -> Extensions {
618 let mut ext = self.extensions.take().unwrap_or_default();
619
620 if let Some(graph) = self.dependency_graph.take() {
621 ext.insert(graph);
622 }
623 if let Some(snapshot) = self.prior_snapshot.take() {
624 ext.insert(snapshot);
625 }
626
627 ext
628 }
629
630 fn setup_learning_components(
632 &self,
633 action_publisher: &ActionEventPublisher,
634 ) -> Result<LearningSetupResult, SwarmError> {
635 let daemon_config = DaemonConfig::new(&self.config.scenario)
637 .data_dir(&self.config.data_dir)
638 .check_interval(self.config.daemon_check_interval);
639
640 let trigger = self
641 .train_trigger
642 .clone()
643 .unwrap_or_else(|| TriggerBuilder::never());
644
645 let mut daemon =
646 LearningDaemon::new(daemon_config, trigger).map_err(|e| SwarmError::Config {
647 message: format!("Failed to create LearningDaemon: {}", e),
648 })?;
649
650 let record_tx = daemon.record_sender();
651 let shutdown_tx = daemon.shutdown_sender();
652
653 let sub_config = EventSubscriberConfig::new()
655 .batch_size(self.config.subscriber_batch_size)
656 .flush_interval_ms(self.config.subscriber_flush_interval_ms);
657
658 let mut subscriber_handles = Vec::new();
659
660 let action_sub = ActionEventSubscriber::with_config(
662 action_publisher.subscribe(),
663 record_tx.clone(),
664 sub_config.clone(),
665 );
666 let action_handle = self.runtime.spawn(async move {
667 action_sub.run().await;
668 });
669 subscriber_handles.push(action_handle);
670
671 let learning_channel = LearningEventChannel::global();
673 let learning_sub = LearningEventSubscriber::with_config(
674 learning_channel.subscribe(),
675 record_tx,
676 sub_config,
677 );
678 let learning_handle = self.runtime.spawn(async move {
679 learning_sub.run().await;
680 });
681 subscriber_handles.push(learning_handle);
682
683 let daemon_handle = self.runtime.spawn(async move { daemon.run().await });
685
686 Ok(LearningSetupResult {
687 daemon_handle: Some(daemon_handle),
688 subscriber_handles,
689 shutdown_tx: Some(shutdown_tx),
690 })
691 }
692}
693
694pub struct LearnableSwarm {
702 orchestrator: Orchestrator,
704 runtime: Handle,
706 config: LearnableSwarmConfig,
708 learning_store: Option<LearningStore>,
710 offline_model: Option<OfflineModel>,
712 daemon_handle: Option<DaemonHandle>,
714 subscriber_handles: Vec<JoinHandle<()>>,
716 shutdown_tx: Option<mpsc::Sender<()>>,
718}
719
720impl LearnableSwarm {
721 pub fn run_task(&mut self, task: SwarmTask) -> Result<SwarmResult, SwarmError> {
723 self.orchestrator.run_task(task)
724 }
725
726 pub fn run(&mut self) -> SwarmResult {
728 self.orchestrator.run()
729 }
730
731 pub fn orchestrator(&self) -> &Orchestrator {
733 &self.orchestrator
734 }
735
736 pub fn orchestrator_mut(&mut self) -> &mut Orchestrator {
738 &mut self.orchestrator
739 }
740
741 pub fn dependency_graph(&self) -> Option<&DependencyGraph> {
743 self.orchestrator.dependency_graph()
744 }
745
746 pub fn config(&self) -> &LearnableSwarmConfig {
748 &self.config
749 }
750
751 pub fn learning_store(&self) -> Option<&LearningStore> {
753 self.learning_store.as_ref()
754 }
755
756 pub fn offline_model(&self) -> Option<&OfflineModel> {
758 self.offline_model.as_ref()
759 }
760
761 pub fn is_learning_enabled(&self) -> bool {
763 self.config.learning_enabled
764 }
765
766 pub fn emit_stats_snapshot(&self) {
771 use crate::events::{LearnStatsOutcome, LearningEvent};
772 use crate::util::epoch_millis;
773
774 let state = self.orchestrator.state();
775 let tick = state.shared.tick;
776 let total_actions = state.shared.stats.total_visits() as u64;
777
778 let stats_json = if let Some(provider) = self.orchestrator.learned_provider() {
780 provider
781 .stats()
782 .map(|stats| serde_json::to_string(stats).unwrap_or_default())
783 .unwrap_or_default()
784 } else {
785 String::new()
786 };
787
788 let session_id = format!("{}", epoch_millis());
790
791 let outcome = if state.shared.environment_done {
793 LearnStatsOutcome::Success { score: 1.0 }
794 } else {
795 LearnStatsOutcome::Timeout {
796 partial_score: None,
797 }
798 };
799
800 let event = LearningEvent::learn_stats_snapshot(&self.config.scenario)
802 .session_id(session_id)
803 .stats_json(stats_json)
804 .total_ticks(tick)
805 .total_actions(total_actions);
806
807 let event = match outcome {
808 LearnStatsOutcome::Success { score } => event.success(score),
809 LearnStatsOutcome::Timeout { partial_score } => event.timeout(partial_score),
810 LearnStatsOutcome::Failure { reason } => event.failure(reason),
811 };
812
813 LearningEventChannel::global().emit(event.build());
814
815 tracing::debug!(
816 scenario = %self.config.scenario,
817 tick = tick,
818 total_actions = total_actions,
819 "LearnStatsSnapshot emitted"
820 );
821 }
822
823 pub fn take_shutdown_tx(&mut self) -> Option<mpsc::Sender<()>> {
828 self.shutdown_tx.take()
829 }
830
831 pub async fn shutdown(self) {
836 if self.config.learning_enabled {
838 self.emit_stats_snapshot();
839 }
840
841 if let Some(tx) = self.shutdown_tx {
843 let _ = tx.send(()).await;
844 }
845
846 if let Some(handle) = self.daemon_handle {
848 match handle.await {
849 Ok(Ok(())) => {
850 tracing::debug!("LearningDaemon shutdown completed");
851 }
852 Ok(Err(e)) => {
853 tracing::warn!("LearningDaemon error on shutdown: {}", e);
854 }
855 Err(e) => {
856 tracing::warn!("LearningDaemon join error: {}", e);
857 }
858 }
859 }
860
861 for handle in self.subscriber_handles {
863 let _ = handle.await;
864 }
865
866 tracing::debug!("LearnableSwarm shutdown completed");
867 }
868
869 pub fn shutdown_blocking(self) {
871 let runtime = self.runtime.clone();
872 runtime.block_on(self.shutdown());
873 }
874}
875
876async fn run_trace_subscriber(
882 mut rx: tokio::sync::broadcast::Receiver<crate::events::ActionEvent>,
883 subscriber: Arc<dyn TraceSubscriber>,
884) {
885 while let Ok(event) = rx.recv().await {
886 subscriber.on_event(&event);
887 }
888 subscriber.finish();
889}
890
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::GenericWorker;

    /// Single-threaded Tokio runtime; the tests only need its `Handle`.
    fn make_test_runtime() -> tokio::runtime::Runtime {
        let mut rt_builder = tokio::runtime::Builder::new_current_thread();
        rt_builder.enable_all();
        rt_builder.build().unwrap()
    }

    /// The default configuration has learning off and no scenario.
    #[test]
    fn test_config_default() {
        let LearnableSwarmConfig {
            scenario,
            learning_enabled,
            ..
        } = LearnableSwarmConfig::default();

        assert!(!learning_enabled);
        assert!(scenario.is_empty());
    }

    /// The fluent setters on the config store what they are given.
    #[test]
    fn test_config_builder() {
        let expected_dir = PathBuf::from("/tmp/test");

        let config = LearnableSwarmConfig::new("test-scenario")
            .with_learning(true)
            .data_dir("/tmp/test");

        assert_eq!(config.scenario, "test-scenario");
        assert!(config.learning_enabled);
        assert_eq!(config.data_dir, expected_dir);
    }

    /// Scenario and workers are recorded on the builder.
    #[test]
    fn test_builder_basic() {
        let rt = make_test_runtime();
        let worker = Box::new(GenericWorker::new(0));

        let builder = LearnableSwarmBuilder::new(rt.handle().clone())
            .scenario("test")
            .add_worker(worker);

        assert_eq!(builder.config.scenario, "test");
        assert_eq!(builder.workers.len(), 1);
    }

    /// `with_learning(true)` flips the flag in the builder's config.
    #[test]
    fn test_builder_with_learning() {
        let rt = make_test_runtime();
        let worker = Box::new(GenericWorker::new(0));

        let builder = LearnableSwarmBuilder::new(rt.handle().clone())
            .scenario("test")
            .with_learning(true)
            .add_worker(worker);

        assert!(builder.config.learning_enabled);
    }

    /// Enabling learning without a scenario must make `build` fail.
    #[test]
    fn test_builder_learning_without_scenario_fails() {
        let rt = make_test_runtime();

        let result = LearnableSwarmBuilder::new(rt.handle().clone())
            .with_learning(true)
            .add_worker(Box::new(GenericWorker::new(0)))
            .build();

        match result {
            Ok(_) => panic!("build must fail when learning is enabled without a scenario"),
            Err(err) => assert!(err.to_string().contains("scenario is required")),
        }
    }

    /// Without learning, a scenario is not required.
    #[test]
    fn test_builder_learning_disabled_without_scenario_ok() {
        let rt = make_test_runtime();

        let swarm = LearnableSwarmBuilder::new(rt.handle().clone())
            .add_worker(Box::new(GenericWorker::new(0)))
            .build();

        assert!(swarm.is_ok());
    }

    /// A scenario profile supplies both the scenario name and the offline model.
    #[test]
    fn test_builder_with_scenario_profile() {
        use crate::learn::learned_component::LearnedExploration;
        use crate::learn::scenario_profile::{ScenarioProfile, ScenarioSource};

        let rt = make_test_runtime();

        let source = ScenarioSource::from_path("/test.toml");
        let mut profile = ScenarioProfile::new("test-profile", source);
        profile.exploration = Some(LearnedExploration::new(2.5, 0.4, 1.2));

        let builder = LearnableSwarmBuilder::new(rt.handle().clone())
            .with_scenario_profile(&profile)
            .add_worker(Box::new(GenericWorker::new(0)));

        assert_eq!(builder.config.scenario, "test-profile");

        let model = builder
            .offline_model
            .as_ref()
            .expect("offline model should be set by with_scenario_profile");
        assert_eq!(model.parameters.ucb1_c, 2.5);
    }
}