swarm_engine_core/events/
lifecycle.rs1use std::path::PathBuf;
34use std::sync::atomic::{AtomicUsize, Ordering};
35use std::sync::Arc;
36
37use crate::learn::{AlwaysTrigger, TrainTrigger, TriggerContext};
38use crate::orchestrator::SwarmResult;
39use crate::state::SwarmState;
40
41#[derive(Debug, Clone)]
47pub enum LifecycleEvent {
48 Started {
50 worker_count: usize,
52 },
53 Terminated {
55 result: SwarmResult,
57 stats: TerminationStats,
59 },
60}
61
/// Counters and optional labels describing a finished swarm run.
///
/// `Default` yields all counters at zero and both labels unset.
#[derive(Debug, Clone, Default)]
pub struct TerminationStats {
    /// Tick count at termination.
    pub total_ticks: u64,
    /// Total number of actions recorded.
    pub total_actions: u64,
    /// Number of actions recorded as successful.
    pub successful_actions: u64,
    /// Number of actions recorded as failed.
    pub failed_actions: u64,
    /// Optional scenario label; `None` unless set explicitly.
    pub scenario: Option<String>,
    /// Optional group identifier; `None` unless set explicitly.
    pub group_id: Option<String>,
}
78
79impl TerminationStats {
80 pub fn from_state(state: &SwarmState) -> Self {
82 Self {
83 total_ticks: state.shared.tick,
84 total_actions: state.shared.stats.total_visits() as u64,
85 successful_actions: state.shared.stats.total_successes() as u64,
86 failed_actions: state.shared.stats.total_failures() as u64,
87 scenario: None,
88 group_id: None,
89 }
90 }
91
92 pub fn with_scenario(mut self, scenario: impl Into<String>) -> Self {
94 self.scenario = Some(scenario.into());
95 self
96 }
97
98 pub fn with_group_id(mut self, group_id: impl Into<String>) -> Self {
100 self.group_id = Some(group_id.into());
101 self
102 }
103}
104
105pub trait LifecycleHook: Send + Sync {
124 fn on_start(&mut self, worker_count: usize) {
126 let _ = worker_count;
127 }
128
129 fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult);
134
135 fn name(&self) -> &str {
137 "lifecycle_hook"
138 }
139}
140
141pub struct LearningLifecycleHook {
167 learning_path: PathBuf,
169 trigger: Arc<dyn TrainTrigger>,
171 eval_count: Arc<AtomicUsize>,
173 last_learn_count: usize,
175 scenario: Option<String>,
177 learn_callback: Option<Box<dyn Fn(&str, &SwarmState) + Send + Sync>>,
179}
180
181impl LearningLifecycleHook {
182 pub fn new(learning_path: impl Into<PathBuf>) -> Self {
186 Self {
187 learning_path: learning_path.into(),
188 trigger: Arc::new(AlwaysTrigger),
189 eval_count: Arc::new(AtomicUsize::new(0)),
190 last_learn_count: 0,
191 scenario: None,
192 learn_callback: None,
193 }
194 }
195
196 pub fn with_trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
208 self.trigger = trigger;
209 self
210 }
211
212 pub fn with_scenario(mut self, scenario: impl Into<String>) -> Self {
214 self.scenario = Some(scenario.into());
215 self
216 }
217
218 pub fn with_learn_callback<F>(mut self, callback: F) -> Self
223 where
224 F: Fn(&str, &SwarmState) + Send + Sync + 'static,
225 {
226 self.learn_callback = Some(Box::new(callback));
227 self
228 }
229
230 pub fn eval_count_handle(&self) -> Arc<AtomicUsize> {
234 Arc::clone(&self.eval_count)
235 }
236
237 pub fn with_shared_eval_count(mut self, count: Arc<AtomicUsize>) -> Self {
239 self.eval_count = count;
240 self
241 }
242
243 pub fn current_eval_count(&self) -> usize {
245 self.eval_count.load(Ordering::SeqCst)
246 }
247
248 pub fn learning_path(&self) -> &PathBuf {
250 &self.learning_path
251 }
252
253 fn should_learn(&self) -> bool {
255 let current = self.eval_count.load(Ordering::SeqCst);
256 let ctx = TriggerContext::with_count(current).last_train_count(self.last_learn_count);
257 self.trigger.should_train(&ctx).unwrap_or(false)
258 }
259
260 fn run_learn(&mut self, state: &SwarmState) {
262 let scenario = self.scenario.as_deref().unwrap_or("unknown");
263
264 tracing::info!(
265 scenario = scenario,
266 eval_count = self.current_eval_count(),
267 trigger = self.trigger.name(),
268 "Running learning after trigger condition met"
269 );
270
271 if let Some(ref callback) = self.learn_callback {
273 callback(scenario, state);
274 } else {
275 self.run_default_learn(scenario);
277 }
278
279 self.last_learn_count = self.eval_count.load(Ordering::SeqCst);
281 }
282
283 fn run_default_learn(&self, scenario: &str) {
285 use crate::learn::LearningStore;
286
287 match LearningStore::new(&self.learning_path) {
288 Ok(store) => match store.run_offline_learning(scenario, 20) {
289 Ok(model) => {
290 tracing::info!(
291 scenario = scenario,
292 sessions = model.analyzed_sessions,
293 "Offline learning completed"
294 );
295 }
296 Err(e) => {
297 tracing::warn!(
298 scenario = scenario,
299 error = %e,
300 "Offline learning failed"
301 );
302 }
303 },
304 Err(e) => {
305 tracing::error!(
306 path = %self.learning_path.display(),
307 error = %e,
308 "Failed to create LearningStore"
309 );
310 }
311 }
312 }
313}
314
315impl LifecycleHook for LearningLifecycleHook {
316 fn on_start(&mut self, worker_count: usize) {
317 tracing::debug!(
318 worker_count = worker_count,
319 eval_count = self.current_eval_count(),
320 "LearningLifecycleHook: Swarm started"
321 );
322 }
323
324 fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult) {
325 let new_count = self.eval_count.fetch_add(1, Ordering::SeqCst) + 1;
327
328 tracing::debug!(
329 eval_count = new_count,
330 total_ticks = result.total_ticks,
331 trigger = self.trigger.name(),
332 "LearningLifecycleHook: Swarm terminated"
333 );
334
335 if self.should_learn() {
337 self.run_learn(state);
338 } else {
339 tracing::debug!(
340 eval_count = new_count,
341 last_learn = self.last_learn_count,
342 trigger = self.trigger.name(),
343 "Trigger not met, skipping learning"
344 );
345 }
346 }
347
348 fn name(&self) -> &str {
349 "learning_lifecycle_hook"
350 }
351}
352
353pub struct CompositeLifecycleHook {
369 hooks: Vec<Box<dyn LifecycleHook>>,
370}
371
372impl CompositeLifecycleHook {
373 pub fn new() -> Self {
375 Self { hooks: Vec::new() }
376 }
377
378 pub fn add(mut self, hook: Box<dyn LifecycleHook>) -> Self {
380 self.hooks.push(hook);
381 self
382 }
383}
384
385impl Default for CompositeLifecycleHook {
386 fn default() -> Self {
387 Self::new()
388 }
389}
390
391impl LifecycleHook for CompositeLifecycleHook {
392 fn on_start(&mut self, worker_count: usize) {
393 for hook in &mut self.hooks {
394 hook.on_start(worker_count);
395 }
396 }
397
398 fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult) {
399 for hook in &mut self.hooks {
400 hook.on_terminate(state, result);
401 }
402 }
403
404 fn name(&self) -> &str {
405 "composite_lifecycle_hook"
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;

    /// Hook that records which callbacks fired through shared flags.
    struct TestHook {
        started: Arc<AtomicBool>,
        terminated: Arc<AtomicBool>,
    }

    impl TestHook {
        /// Returns the hook plus handles to its `started` and `terminated` flags.
        fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>) {
            let started = Arc::new(AtomicBool::new(false));
            let terminated = Arc::new(AtomicBool::new(false));
            let hook = Self {
                started: Arc::clone(&started),
                terminated: Arc::clone(&terminated),
            };
            (hook, started, terminated)
        }
    }

    impl LifecycleHook for TestHook {
        fn on_start(&mut self, _worker_count: usize) {
            self.started.store(true, Ordering::SeqCst);
        }

        fn on_terminate(&mut self, _state: &SwarmState, _result: &SwarmResult) {
            self.terminated.store(true, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_termination_stats_from_state() {
        let fresh_state = SwarmState::new(4);
        let stats = TerminationStats::from_state(&fresh_state);

        assert_eq!(stats.total_ticks, 0);
        assert!(stats.scenario.is_none());
    }

    #[test]
    fn test_learning_lifecycle_hook_eval_count() {
        let hook = LearningLifecycleHook::new("/tmp/test");
        assert_eq!(hook.current_eval_count(), 0);

        // Bumping the shared handle must be visible through the hook itself.
        let shared_counter = hook.eval_count_handle();
        shared_counter.fetch_add(5, Ordering::SeqCst);
        assert_eq!(hook.current_eval_count(), 5);
    }

    #[test]
    fn test_composite_hook() {
        let (first, first_started, first_terminated) = TestHook::new();
        let (second, second_started, second_terminated) = TestHook::new();

        let mut composite = CompositeLifecycleHook::new()
            .add(Box::new(first))
            .add(Box::new(second));

        composite.on_start(4);
        assert!(first_started.load(Ordering::SeqCst));
        assert!(second_started.load(Ordering::SeqCst));

        let state = SwarmState::new(4);
        let result = SwarmResult {
            total_ticks: 10,
            total_duration: std::time::Duration::from_secs(1),
            completed: true,
        };
        composite.on_terminate(&state, &result);
        assert!(first_terminated.load(Ordering::SeqCst));
        assert!(second_terminated.load(Ordering::SeqCst));
    }
}