Skip to main content

swarm_engine_core/events/
lifecycle.rs

//! Lifecycle Hooks - Orchestrator ライフサイクルフック
//!
//! Swarm の開始・終了時に同期的に呼び出されるコールバック。
//!
//! # Hook vs Sink の命名規則
//!
//! | 名前 | 実行モデル | 用途 |
//! |------|----------|------|
//! | **Hook** | 同期 callback | Orchestrator lifecycle など、同期的に呼ばれる処理 |
//! | **Sink** | async trait | Pipeline 終端など、非同期でデータを消費する処理 |
//!
//! # 設計
//!
//! - `LifecycleEvent`: Swarm の状態変化を表すイベント
//! - `LifecycleHook`: 同期的に呼び出される callback trait
//! - `LearningLifecycleHook`: TrainTrigger + Learn 実行の実装
//!
//! # 使用例
//!
//! ```ignore
//! use swarm_engine_core::events::{LifecycleHook, LearningLifecycleHook};
//! use swarm_engine_core::learn::TriggerBuilder;
//!
//! // Eval を 10 回実行するごとに Learn を実行
//! let hook = LearningLifecycleHook::new(learning_path)
//!     .with_trigger(TriggerBuilder::every_n_episodes(10));
//!
//! let orchestrator = OrchestratorBuilder::new()
//!     .lifecycle_hook(Box::new(hook))
//!     .build(runtime);
//! ```
32
33use std::path::PathBuf;
34use std::sync::atomic::{AtomicUsize, Ordering};
35use std::sync::Arc;
36
37use crate::learn::{AlwaysTrigger, TrainTrigger, TriggerContext};
38use crate::orchestrator::SwarmResult;
39use crate::state::SwarmState;
40
41// ============================================================================
42// LifecycleEvent
43// ============================================================================
44
45/// Orchestrator のライフサイクルイベント
46#[derive(Debug, Clone)]
47pub enum LifecycleEvent {
48    /// Swarm 開始
49    Started {
50        /// Worker 数
51        worker_count: usize,
52    },
53    /// Swarm 終了
54    Terminated {
55        /// 実行結果
56        result: SwarmResult,
57        /// 終了時の統計(シリアライズ可能な形式)
58        stats: TerminationStats,
59    },
60}
61
/// Statistics captured when a swarm terminates.
///
/// This is an owned, plain-data snapshot (unsigned counters plus optional
/// strings). It holds no borrows, so it can be cloned, compared and moved
/// independently of any `SwarmState`.
///
/// Note: it does not derive `serde` traits. Serializing it requires adding
/// those derives or converting it into a serializable type first.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TerminationStats {
    /// Total number of ticks executed.
    pub total_ticks: u64,
    /// Total number of actions taken.
    pub total_actions: u64,
    /// Number of actions that succeeded.
    pub successful_actions: u64,
    /// Number of actions that failed.
    pub failed_actions: u64,
    /// Scenario name, if one was attached.
    pub scenario: Option<String>,
    /// Group id, if one was attached.
    pub group_id: Option<String>,
}
78
79impl TerminationStats {
80    /// SwarmState から統計を抽出
81    pub fn from_state(state: &SwarmState) -> Self {
82        Self {
83            total_ticks: state.shared.tick,
84            total_actions: state.shared.stats.total_visits() as u64,
85            successful_actions: state.shared.stats.total_successes() as u64,
86            failed_actions: state.shared.stats.total_failures() as u64,
87            scenario: None,
88            group_id: None,
89        }
90    }
91
92    /// シナリオ名を設定
93    pub fn with_scenario(mut self, scenario: impl Into<String>) -> Self {
94        self.scenario = Some(scenario.into());
95        self
96    }
97
98    /// GroupId を設定
99    pub fn with_group_id(mut self, group_id: impl Into<String>) -> Self {
100        self.group_id = Some(group_id.into());
101        self
102    }
103}
104
105// ============================================================================
106// LifecycleHook trait
107// ============================================================================
108
109/// Orchestrator ライフサイクルの同期コールバック
110///
111/// Swarm の開始・終了時に**同期的に**呼び出される。
112/// Learning の自動実行、統計の永続化などに使用。
113///
114/// # Hook vs Sink
115///
116/// - **Hook**: 同期的に呼ばれる callback(本 trait)
117/// - **Sink**: 非同期でデータを消費する処理(`pipeline::EventSink`)
118///
119/// # 注意
120///
121/// `on_terminate` 内で重い処理を行う場合は `tokio::task::spawn_blocking`
122/// を使用してブロッキングを回避すること。
123pub trait LifecycleHook: Send + Sync {
124    /// Swarm 開始時に呼ばれる
125    fn on_start(&mut self, worker_count: usize) {
126        let _ = worker_count;
127    }
128
129    /// Swarm 終了時に呼ばれる
130    ///
131    /// `state` は終了時点の SwarmState への参照。
132    /// Learning 実行などはここで行う。
133    fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult);
134
135    /// Hook の名前(デバッグ用)
136    fn name(&self) -> &str {
137        "lifecycle_hook"
138    }
139}
140
141// ============================================================================
142// LearningLifecycleHook
143// ============================================================================
144
145/// Learning を自動実行する LifecycleHook
146///
147/// Swarm 終了時に TrainTrigger をチェックし、条件を満たせば Learn を実行。
148///
149/// # Eval カウント
150///
151/// 内部で Eval 実行回数をカウントし、CountTrigger と連携。
152/// 例: 10 回の Eval 後に Learn を実行。
153///
154/// # 非同期 Sink との違い
155///
156/// | 項目 | LearningLifecycleHook | LearningSink (pipeline) |
157/// |------|----------------------|------------------------|
158/// | 呼び出し | Orchestrator 終了時 | ファイル監視イベント時 |
159/// | 実行モデル | 同期 callback | async trait |
160/// | 用途 | Eval 終了後の即時学習 | Daemon モードでの継続学習 |
161///
162/// # 将来拡張
163///
164/// - IPC/WebSocket で外部 Learner プロセスと通信
165/// - DPO 学習の自動実行
166pub struct LearningLifecycleHook {
167    /// 学習データのパス
168    learning_path: PathBuf,
169    /// TrainTrigger
170    trigger: Arc<dyn TrainTrigger>,
171    /// Eval 実行カウント(グローバル)
172    eval_count: Arc<AtomicUsize>,
173    /// 最後に Learn を実行した時の eval_count
174    last_learn_count: usize,
175    /// シナリオ名(Learn 実行時に使用)
176    scenario: Option<String>,
177    /// Learn 実行コールバック(カスタマイズ用)
178    learn_callback: Option<Box<dyn Fn(&str, &SwarmState) + Send + Sync>>,
179}
180
181impl LearningLifecycleHook {
182    /// 新しい LearningLifecycleHook を作成
183    ///
184    /// デフォルトは AlwaysTrigger(毎回 Learn 実行)。
185    pub fn new(learning_path: impl Into<PathBuf>) -> Self {
186        Self {
187            learning_path: learning_path.into(),
188            trigger: Arc::new(AlwaysTrigger),
189            eval_count: Arc::new(AtomicUsize::new(0)),
190            last_learn_count: 0,
191            scenario: None,
192            learn_callback: None,
193        }
194    }
195
196    /// TrainTrigger を設定
197    ///
198    /// # Example
199    ///
200    /// ```ignore
201    /// use swarm_engine_core::learn::TriggerBuilder;
202    ///
203    /// // 10 回の Eval 後に Learn 実行
204    /// let hook = LearningLifecycleHook::new(path)
205    ///     .with_trigger(TriggerBuilder::every_n_episodes(10));
206    /// ```
207    pub fn with_trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
208        self.trigger = trigger;
209        self
210    }
211
212    /// シナリオ名を設定
213    pub fn with_scenario(mut self, scenario: impl Into<String>) -> Self {
214        self.scenario = Some(scenario.into());
215        self
216    }
217
218    /// Learn 実行コールバックを設定
219    ///
220    /// デフォルトの Learn 実行ロジックをオーバーライド。
221    /// DPO 学習など、カスタムの学習処理を実行できる。
222    pub fn with_learn_callback<F>(mut self, callback: F) -> Self
223    where
224        F: Fn(&str, &SwarmState) + Send + Sync + 'static,
225    {
226        self.learn_callback = Some(Box::new(callback));
227        self
228    }
229
230    /// 外部から eval_count を共有するための Arc を取得
231    ///
232    /// 複数の Orchestrator 間で Eval カウントを共有する場合に使用。
233    pub fn eval_count_handle(&self) -> Arc<AtomicUsize> {
234        Arc::clone(&self.eval_count)
235    }
236
237    /// 共有された eval_count を設定
238    pub fn with_shared_eval_count(mut self, count: Arc<AtomicUsize>) -> Self {
239        self.eval_count = count;
240        self
241    }
242
243    /// 現在の Eval カウントを取得
244    pub fn current_eval_count(&self) -> usize {
245        self.eval_count.load(Ordering::SeqCst)
246    }
247
248    /// Learning パスを取得
249    pub fn learning_path(&self) -> &PathBuf {
250        &self.learning_path
251    }
252
253    /// Trigger をチェック
254    fn should_learn(&self) -> bool {
255        let current = self.eval_count.load(Ordering::SeqCst);
256        let ctx = TriggerContext::with_count(current).last_train_count(self.last_learn_count);
257        self.trigger.should_train(&ctx).unwrap_or(false)
258    }
259
260    /// Learn を実行
261    fn run_learn(&mut self, state: &SwarmState) {
262        let scenario = self.scenario.as_deref().unwrap_or("unknown");
263
264        tracing::info!(
265            scenario = scenario,
266            eval_count = self.current_eval_count(),
267            trigger = self.trigger.name(),
268            "Running learning after trigger condition met"
269        );
270
271        // カスタムコールバックがあれば使用
272        if let Some(ref callback) = self.learn_callback {
273            callback(scenario, state);
274        } else {
275            // デフォルト: LearningStore を使った offline learning
276            self.run_default_learn(scenario);
277        }
278
279        // last_learn_count を更新
280        self.last_learn_count = self.eval_count.load(Ordering::SeqCst);
281    }
282
283    /// デフォルトの Learn 実行
284    fn run_default_learn(&self, scenario: &str) {
285        use crate::learn::LearningStore;
286
287        match LearningStore::new(&self.learning_path) {
288            Ok(store) => match store.run_offline_learning(scenario, 20) {
289                Ok(model) => {
290                    tracing::info!(
291                        scenario = scenario,
292                        sessions = model.analyzed_sessions,
293                        "Offline learning completed"
294                    );
295                }
296                Err(e) => {
297                    tracing::warn!(
298                        scenario = scenario,
299                        error = %e,
300                        "Offline learning failed"
301                    );
302                }
303            },
304            Err(e) => {
305                tracing::error!(
306                    path = %self.learning_path.display(),
307                    error = %e,
308                    "Failed to create LearningStore"
309                );
310            }
311        }
312    }
313}
314
315impl LifecycleHook for LearningLifecycleHook {
316    fn on_start(&mut self, worker_count: usize) {
317        tracing::debug!(
318            worker_count = worker_count,
319            eval_count = self.current_eval_count(),
320            "LearningLifecycleHook: Swarm started"
321        );
322    }
323
324    fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult) {
325        // Eval カウントをインクリメント
326        let new_count = self.eval_count.fetch_add(1, Ordering::SeqCst) + 1;
327
328        tracing::debug!(
329            eval_count = new_count,
330            total_ticks = result.total_ticks,
331            trigger = self.trigger.name(),
332            "LearningLifecycleHook: Swarm terminated"
333        );
334
335        // Trigger チェック
336        if self.should_learn() {
337            self.run_learn(state);
338        } else {
339            tracing::debug!(
340                eval_count = new_count,
341                last_learn = self.last_learn_count,
342                trigger = self.trigger.name(),
343                "Trigger not met, skipping learning"
344            );
345        }
346    }
347
348    fn name(&self) -> &str {
349        "learning_lifecycle_hook"
350    }
351}
352
353// ============================================================================
354// CompositeLifecycleHook
355// ============================================================================
356
357/// 複数の LifecycleHook を合成
358///
359/// 複数の Hook を登録し、順番に呼び出す。
360///
361/// # Example
362///
363/// ```ignore
364/// let composite = CompositeLifecycleHook::new()
365///     .add(Box::new(LearningLifecycleHook::new(path)))
366///     .add(Box::new(CustomHook::new()));
367/// ```
368pub struct CompositeLifecycleHook {
369    hooks: Vec<Box<dyn LifecycleHook>>,
370}
371
372impl CompositeLifecycleHook {
373    /// 新しい CompositeLifecycleHook を作成
374    pub fn new() -> Self {
375        Self { hooks: Vec::new() }
376    }
377
378    /// Hook を追加
379    pub fn add(mut self, hook: Box<dyn LifecycleHook>) -> Self {
380        self.hooks.push(hook);
381        self
382    }
383}
384
385impl Default for CompositeLifecycleHook {
386    fn default() -> Self {
387        Self::new()
388    }
389}
390
391impl LifecycleHook for CompositeLifecycleHook {
392    fn on_start(&mut self, worker_count: usize) {
393        for hook in &mut self.hooks {
394            hook.on_start(worker_count);
395        }
396    }
397
398    fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult) {
399        for hook in &mut self.hooks {
400            hook.on_terminate(state, result);
401        }
402    }
403
404    fn name(&self) -> &str {
405        "composite_lifecycle_hook"
406    }
407}
408
409// ============================================================================
410// Tests
411// ============================================================================
412
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;

    /// Hook that flips shared flags, so tests can see which callbacks fired.
    struct TestHook {
        started: Arc<AtomicBool>,
        terminated: Arc<AtomicBool>,
    }

    impl TestHook {
        /// Returns the hook together with handles to its `started` and
        /// `terminated` flags.
        fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>) {
            let hook = Self {
                started: Arc::new(AtomicBool::new(false)),
                terminated: Arc::new(AtomicBool::new(false)),
            };
            let started = Arc::clone(&hook.started);
            let terminated = Arc::clone(&hook.terminated);
            (hook, started, terminated)
        }
    }

    impl LifecycleHook for TestHook {
        fn on_start(&mut self, _worker_count: usize) {
            self.started.store(true, Ordering::SeqCst);
        }

        fn on_terminate(&mut self, _state: &SwarmState, _result: &SwarmResult) {
            self.terminated.store(true, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_termination_stats_from_state() {
        let stats = TerminationStats::from_state(&SwarmState::new(4));
        assert_eq!(stats.total_ticks, 0);
        assert!(stats.scenario.is_none());
    }

    #[test]
    fn test_learning_lifecycle_hook_eval_count() {
        let hook = LearningLifecycleHook::new("/tmp/test");
        assert_eq!(hook.current_eval_count(), 0);

        hook.eval_count_handle().fetch_add(5, Ordering::SeqCst);
        assert_eq!(hook.current_eval_count(), 5);
    }

    #[test]
    fn test_composite_hook() {
        let (first, first_started, first_terminated) = TestHook::new();
        let (second, second_started, second_terminated) = TestHook::new();

        let mut composite = CompositeLifecycleHook::new()
            .add(Box::new(first))
            .add(Box::new(second));

        composite.on_start(4);
        for flag in [&first_started, &second_started] {
            assert!(flag.load(Ordering::SeqCst));
        }

        let result = SwarmResult {
            total_ticks: 10,
            total_duration: std::time::Duration::from_secs(1),
            completed: true,
        };
        composite.on_terminate(&SwarmState::new(4), &result);
        for flag in [&first_terminated, &second_terminated] {
            assert!(flag.load(Ordering::SeqCst));
        }
    }
}