swarm_engine_core/events/
lifecycle.rs1type LearnCallback = Box<dyn Fn(&str, &SwarmState) + Send + Sync>;
35
36use std::path::PathBuf;
37use std::sync::atomic::{AtomicUsize, Ordering};
38use std::sync::Arc;
39
40use crate::learn::{AlwaysTrigger, TrainTrigger, TriggerContext};
41use crate::orchestrator::SwarmResult;
42use crate::state::SwarmState;
43
44#[derive(Debug, Clone)]
50pub enum LifecycleEvent {
51 Started {
53 worker_count: usize,
55 },
56 Terminated {
58 result: SwarmResult,
60 stats: TerminationStats,
62 },
63}
64
/// Summary counters captured when a swarm terminates.
///
/// Usually built with `TerminationStats::from_state`; `scenario` and `group_id`
/// are optional labels attached afterwards. `Default` yields all-zero counters and
/// no labels.
#[derive(Clone, Default, Debug)]
pub struct TerminationStats {
    /// Tick counter value at termination.
    pub total_ticks: u64,
    /// Total number of actions taken.
    pub total_actions: u64,
    /// Number of actions that succeeded.
    pub successful_actions: u64,
    /// Number of actions that failed.
    pub failed_actions: u64,
    /// Optional scenario label.
    pub scenario: Option<String>,
    /// Optional group identifier.
    pub group_id: Option<String>,
}
81
82impl TerminationStats {
83 pub fn from_state(state: &SwarmState) -> Self {
85 Self {
86 total_ticks: state.shared.tick,
87 total_actions: state.shared.stats.total_visits() as u64,
88 successful_actions: state.shared.stats.total_successes() as u64,
89 failed_actions: state.shared.stats.total_failures() as u64,
90 scenario: None,
91 group_id: None,
92 }
93 }
94
95 pub fn with_scenario(mut self, scenario: impl Into<String>) -> Self {
97 self.scenario = Some(scenario.into());
98 self
99 }
100
101 pub fn with_group_id(mut self, group_id: impl Into<String>) -> Self {
103 self.group_id = Some(group_id.into());
104 self
105 }
106}
107
108pub trait LifecycleHook: Send + Sync {
127 fn on_start(&mut self, worker_count: usize) {
129 let _ = worker_count;
130 }
131
132 fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult);
137
138 fn name(&self) -> &str {
140 "lifecycle_hook"
141 }
142}
143
144pub struct LearningLifecycleHook {
170 learning_path: PathBuf,
172 trigger: Arc<dyn TrainTrigger>,
174 eval_count: Arc<AtomicUsize>,
176 last_learn_count: usize,
178 scenario: Option<String>,
180 learn_callback: Option<LearnCallback>,
182}
183
184impl LearningLifecycleHook {
185 pub fn new(learning_path: impl Into<PathBuf>) -> Self {
189 Self {
190 learning_path: learning_path.into(),
191 trigger: Arc::new(AlwaysTrigger),
192 eval_count: Arc::new(AtomicUsize::new(0)),
193 last_learn_count: 0,
194 scenario: None,
195 learn_callback: None,
196 }
197 }
198
199 pub fn with_trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
211 self.trigger = trigger;
212 self
213 }
214
215 pub fn with_scenario(mut self, scenario: impl Into<String>) -> Self {
217 self.scenario = Some(scenario.into());
218 self
219 }
220
221 pub fn with_learn_callback<F>(mut self, callback: F) -> Self
226 where
227 F: Fn(&str, &SwarmState) + Send + Sync + 'static,
228 {
229 self.learn_callback = Some(Box::new(callback));
230 self
231 }
232
233 pub fn eval_count_handle(&self) -> Arc<AtomicUsize> {
237 Arc::clone(&self.eval_count)
238 }
239
240 pub fn with_shared_eval_count(mut self, count: Arc<AtomicUsize>) -> Self {
242 self.eval_count = count;
243 self
244 }
245
246 pub fn current_eval_count(&self) -> usize {
248 self.eval_count.load(Ordering::SeqCst)
249 }
250
251 pub fn learning_path(&self) -> &PathBuf {
253 &self.learning_path
254 }
255
256 fn should_learn(&self) -> bool {
258 let current = self.eval_count.load(Ordering::SeqCst);
259 let ctx = TriggerContext::with_count(current).last_train_count(self.last_learn_count);
260 self.trigger.should_train(&ctx).unwrap_or(false)
261 }
262
263 fn run_learn(&mut self, state: &SwarmState) {
265 let scenario = self.scenario.as_deref().unwrap_or("unknown");
266
267 tracing::info!(
268 scenario = scenario,
269 eval_count = self.current_eval_count(),
270 trigger = self.trigger.name(),
271 "Running learning after trigger condition met"
272 );
273
274 if let Some(ref callback) = self.learn_callback {
276 callback(scenario, state);
277 } else {
278 self.run_default_learn(scenario);
280 }
281
282 self.last_learn_count = self.eval_count.load(Ordering::SeqCst);
284 }
285
286 fn run_default_learn(&self, scenario: &str) {
288 use crate::learn::LearningStore;
289
290 match LearningStore::new(&self.learning_path) {
291 Ok(store) => match store.run_offline_learning(scenario, 20) {
292 Ok(model) => {
293 tracing::info!(
294 scenario = scenario,
295 sessions = model.analyzed_sessions,
296 "Offline learning completed"
297 );
298 }
299 Err(e) => {
300 tracing::warn!(
301 scenario = scenario,
302 error = %e,
303 "Offline learning failed"
304 );
305 }
306 },
307 Err(e) => {
308 tracing::error!(
309 path = %self.learning_path.display(),
310 error = %e,
311 "Failed to create LearningStore"
312 );
313 }
314 }
315 }
316}
317
318impl LifecycleHook for LearningLifecycleHook {
319 fn on_start(&mut self, worker_count: usize) {
320 tracing::debug!(
321 worker_count = worker_count,
322 eval_count = self.current_eval_count(),
323 "LearningLifecycleHook: Swarm started"
324 );
325 }
326
327 fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult) {
328 let new_count = self.eval_count.fetch_add(1, Ordering::SeqCst) + 1;
330
331 tracing::debug!(
332 eval_count = new_count,
333 total_ticks = result.total_ticks,
334 trigger = self.trigger.name(),
335 "LearningLifecycleHook: Swarm terminated"
336 );
337
338 if self.should_learn() {
340 self.run_learn(state);
341 } else {
342 tracing::debug!(
343 eval_count = new_count,
344 last_learn = self.last_learn_count,
345 trigger = self.trigger.name(),
346 "Trigger not met, skipping learning"
347 );
348 }
349 }
350
351 fn name(&self) -> &str {
352 "learning_lifecycle_hook"
353 }
354}
355
356pub struct CompositeLifecycleHook {
372 hooks: Vec<Box<dyn LifecycleHook>>,
373}
374
375impl CompositeLifecycleHook {
376 pub fn new() -> Self {
378 Self { hooks: Vec::new() }
379 }
380
381 pub fn with_hook(mut self, hook: Box<dyn LifecycleHook>) -> Self {
383 self.hooks.push(hook);
384 self
385 }
386}
387
388impl Default for CompositeLifecycleHook {
389 fn default() -> Self {
390 Self::new()
391 }
392}
393
394impl LifecycleHook for CompositeLifecycleHook {
395 fn on_start(&mut self, worker_count: usize) {
396 for hook in &mut self.hooks {
397 hook.on_start(worker_count);
398 }
399 }
400
401 fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult) {
402 for hook in &mut self.hooks {
403 hook.on_terminate(state, result);
404 }
405 }
406
407 fn name(&self) -> &str {
408 "composite_lifecycle_hook"
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;

    /// Hook that only records whether each callback fired.
    struct TestHook {
        started: Arc<AtomicBool>,
        terminated: Arc<AtomicBool>,
    }

    impl TestHook {
        /// Returns the hook plus handles to its `started` / `terminated` flags.
        fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>) {
            let started = Arc::new(AtomicBool::new(false));
            let terminated = Arc::new(AtomicBool::new(false));
            (
                Self {
                    started: Arc::clone(&started),
                    terminated: Arc::clone(&terminated),
                },
                started,
                terminated,
            )
        }
    }

    impl LifecycleHook for TestHook {
        fn on_start(&mut self, _worker_count: usize) {
            self.started.store(true, Ordering::SeqCst);
        }

        fn on_terminate(&mut self, _state: &SwarmState, _result: &SwarmResult) {
            self.terminated.store(true, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_termination_stats_from_state() {
        let state = SwarmState::new(4);
        let stats = TerminationStats::from_state(&state);
        assert_eq!(stats.total_ticks, 0);
        assert!(stats.scenario.is_none());
    }

    #[test]
    fn test_termination_stats_builders() {
        let stats = TerminationStats::default()
            .with_scenario("maze")
            .with_group_id("group-1");
        assert_eq!(stats.scenario.as_deref(), Some("maze"));
        assert_eq!(stats.group_id.as_deref(), Some("group-1"));
        assert_eq!(stats.total_ticks, 0);
    }

    #[test]
    fn test_learning_lifecycle_hook_eval_count() {
        let hook = LearningLifecycleHook::new("/tmp/test");
        assert_eq!(hook.current_eval_count(), 0);

        let handle = hook.eval_count_handle();
        handle.fetch_add(5, Ordering::SeqCst);
        assert_eq!(hook.current_eval_count(), 5);
    }

    #[test]
    fn test_learning_lifecycle_hook_shared_eval_count() {
        let shared = Arc::new(AtomicUsize::new(3));
        let hook =
            LearningLifecycleHook::new("/tmp/test").with_shared_eval_count(Arc::clone(&shared));
        assert_eq!(hook.current_eval_count(), 3);

        shared.fetch_add(2, Ordering::SeqCst);
        assert_eq!(hook.current_eval_count(), 5);
        assert!(Arc::ptr_eq(&shared, &hook.eval_count_handle()));
        assert_eq!(hook.learning_path(), &PathBuf::from("/tmp/test"));
    }

    #[test]
    fn test_hook_names() {
        let (hook, _, _) = TestHook::new();
        assert_eq!(hook.name(), "lifecycle_hook");
        assert_eq!(
            CompositeLifecycleHook::default().name(),
            "composite_lifecycle_hook"
        );
        assert_eq!(
            LearningLifecycleHook::new("/tmp/test").name(),
            "learning_lifecycle_hook"
        );
    }

    #[test]
    fn test_composite_hook() {
        let (hook1, started1, terminated1) = TestHook::new();
        let (hook2, started2, terminated2) = TestHook::new();

        let mut composite = CompositeLifecycleHook::new()
            .with_hook(Box::new(hook1))
            .with_hook(Box::new(hook2));

        composite.on_start(4);
        assert!(started1.load(Ordering::SeqCst));
        assert!(started2.load(Ordering::SeqCst));

        let state = SwarmState::new(4);
        let result = SwarmResult {
            total_ticks: 10,
            total_duration: std::time::Duration::from_secs(1),
            completed: true,
        };
        composite.on_terminate(&state, &result);
        assert!(terminated1.load(Ordering::SeqCst));
        assert!(terminated2.load(Ordering::SeqCst));
    }
}