Skip to main content

swarm_engine_core/events/
lifecycle.rs

1//! Lifecycle Hooks - Orchestrator ライフサイクルフック
2//!
3//! Swarm の開始・終了時に同期的に呼び出されるコールバック。
4//!
5//! # Hook vs Sink の命名規則
6//!
7//! | 名前 | 実行モデル | 用途 |
8//! |------|----------|------|
9//! | **Hook** | 同期 callback | Orchestrator lifecycle など、同期的に呼ばれる処理 |
10//! | **Sink** | async trait | Pipeline 終端など、非同期でデータを消費する処理 |
11//!
12//! # 設計
13//!
14//! - `LifecycleEvent`: Swarm の状態変化を表すイベント
15//! - `LifecycleHook`: 同期的に呼び出される callback trait
16//! - `LearningLifecycleHook`: TrainTrigger + Learn 実行の実装
17//!
18//! # 使用例
19//!
20//! ```ignore
21//! use swarm_engine_core::events::{LifecycleHook, LearningLifecycleHook};
22//! use swarm_engine_core::learn::TriggerBuilder;
23//!
24//! // 10 回の Eval 実行後に Learn を実行
25//! let hook = LearningLifecycleHook::new(learning_path)
26//!     .with_trigger(TriggerBuilder::every_n_episodes(10));
27//!
28//! let orchestrator = OrchestratorBuilder::new()
29//!     .lifecycle_hook(Box::new(hook))
30//!     .build(runtime);
31//! ```
32
/// Callback type used to run learning.
///
/// It is called with the scenario name and a reference to the `SwarmState`
/// handed to the hook. The `Send + Sync` bounds allow a type that stores it
/// to remain `Send + Sync`.
type LearnCallback = Box<dyn Fn(&str, &SwarmState) + Send + Sync>;
35
36use std::path::PathBuf;
37use std::sync::atomic::{AtomicUsize, Ordering};
38use std::sync::Arc;
39
40use crate::learn::{AlwaysTrigger, TrainTrigger, TriggerContext};
41use crate::orchestrator::SwarmResult;
42use crate::state::SwarmState;
43
44// ============================================================================
45// LifecycleEvent
46// ============================================================================
47
/// Lifecycle event of the Orchestrator.
#[derive(Debug, Clone)]
pub enum LifecycleEvent {
    /// The Swarm has started.
    Started {
        /// Number of workers.
        worker_count: usize,
    },
    /// The Swarm has terminated.
    Terminated {
        /// Execution result.
        result: SwarmResult,
        /// Statistics at termination (intended to be in a serializable form).
        stats: TerminationStats,
    },
}
64
/// Statistics captured at termination (intended to be serializable).
///
/// NOTE(review): only `Debug`, `Clone` and `Default` are derived here; no
/// serialization derive is visible in this file — confirm if one is expected.
#[derive(Debug, Clone, Default)]
pub struct TerminationStats {
    /// Total number of ticks.
    pub total_ticks: u64,
    /// Total number of actions.
    pub total_actions: u64,
    /// Number of successful actions.
    pub successful_actions: u64,
    /// Number of failed actions.
    pub failed_actions: u64,
    /// Scenario name, if any.
    pub scenario: Option<String>,
    /// GroupId, if any.
    pub group_id: Option<String>,
}
81
82impl TerminationStats {
83    /// SwarmState から統計を抽出
84    pub fn from_state(state: &SwarmState) -> Self {
85        Self {
86            total_ticks: state.shared.tick,
87            total_actions: state.shared.stats.total_visits() as u64,
88            successful_actions: state.shared.stats.total_successes() as u64,
89            failed_actions: state.shared.stats.total_failures() as u64,
90            scenario: None,
91            group_id: None,
92        }
93    }
94
95    /// シナリオ名を設定
96    pub fn with_scenario(mut self, scenario: impl Into<String>) -> Self {
97        self.scenario = Some(scenario.into());
98        self
99    }
100
101    /// GroupId を設定
102    pub fn with_group_id(mut self, group_id: impl Into<String>) -> Self {
103        self.group_id = Some(group_id.into());
104        self
105    }
106}
107
108// ============================================================================
109// LifecycleHook trait
110// ============================================================================
111
/// Synchronous callback for the Orchestrator lifecycle.
///
/// Called **synchronously** when the Swarm starts and terminates.
/// Used for things such as running Learning automatically or persisting
/// statistics.
///
/// # Hook vs Sink
///
/// - **Hook**: a callback that is called synchronously (this trait)
/// - **Sink**: processing that consumes data asynchronously (`pipeline::EventSink`)
///
/// # Note
///
/// When doing heavy work inside `on_terminate`, use
/// `tokio::task::spawn_blocking` to avoid blocking.
pub trait LifecycleHook: Send + Sync {
    /// Called when the Swarm starts.
    ///
    /// The default implementation ignores `worker_count` and does nothing.
    fn on_start(&mut self, worker_count: usize) {
        let _ = worker_count;
    }

    /// Called when the Swarm terminates.
    ///
    /// `state` is a reference to the SwarmState at the time of termination.
    /// Learning and similar work is performed here.
    fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult);

    /// Name of the hook (for debugging).
    ///
    /// Defaults to `"lifecycle_hook"`.
    fn name(&self) -> &str {
        "lifecycle_hook"
    }
}
143
144// ============================================================================
145// LearningLifecycleHook
146// ============================================================================
147
/// `LifecycleHook` that runs Learning automatically.
///
/// When the Swarm terminates, the TrainTrigger is checked and Learn is run
/// if its condition is met.
///
/// # Eval count
///
/// The number of Eval runs is counted internally and works together with
/// CountTrigger.
/// Example: run Learn after 10 Evals.
///
/// # Difference from the async Sink
///
/// | Item | LearningLifecycleHook | LearningSink (pipeline) |
/// |------|----------------------|------------------------|
/// | Invoked | When the Orchestrator terminates | On file-watch events |
/// | Execution model | Synchronous callback | async trait |
/// | Use | Immediate learning after an Eval finishes | Continuous learning in daemon mode |
///
/// # Future extensions
///
/// - Communicate with an external Learner process over IPC/WebSocket
/// - Run DPO learning automatically
pub struct LearningLifecycleHook {
    /// Path of the learning data.
    learning_path: PathBuf,
    /// TrainTrigger that decides whether Learn should run.
    trigger: Arc<dyn TrainTrigger>,
    /// Eval run count (global); held in an `Arc` so it can be shared.
    eval_count: Arc<AtomicUsize>,
    /// Value of `eval_count` when Learn was last run.
    last_learn_count: usize,
    /// Scenario name (used when running Learn).
    scenario: Option<String>,
    /// Learn execution callback (for customization).
    learn_callback: Option<LearnCallback>,
}
183
184impl LearningLifecycleHook {
185    /// 新しい LearningLifecycleHook を作成
186    ///
187    /// デフォルトは AlwaysTrigger(毎回 Learn 実行)。
188    pub fn new(learning_path: impl Into<PathBuf>) -> Self {
189        Self {
190            learning_path: learning_path.into(),
191            trigger: Arc::new(AlwaysTrigger),
192            eval_count: Arc::new(AtomicUsize::new(0)),
193            last_learn_count: 0,
194            scenario: None,
195            learn_callback: None,
196        }
197    }
198
199    /// TrainTrigger を設定
200    ///
201    /// # Example
202    ///
203    /// ```ignore
204    /// use swarm_engine_core::learn::TriggerBuilder;
205    ///
206    /// // 10 回の Eval 後に Learn 実行
207    /// let hook = LearningLifecycleHook::new(path)
208    ///     .with_trigger(TriggerBuilder::every_n_episodes(10));
209    /// ```
210    pub fn with_trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
211        self.trigger = trigger;
212        self
213    }
214
215    /// シナリオ名を設定
216    pub fn with_scenario(mut self, scenario: impl Into<String>) -> Self {
217        self.scenario = Some(scenario.into());
218        self
219    }
220
221    /// Learn 実行コールバックを設定
222    ///
223    /// デフォルトの Learn 実行ロジックをオーバーライド。
224    /// DPO 学習など、カスタムの学習処理を実行できる。
225    pub fn with_learn_callback<F>(mut self, callback: F) -> Self
226    where
227        F: Fn(&str, &SwarmState) + Send + Sync + 'static,
228    {
229        self.learn_callback = Some(Box::new(callback));
230        self
231    }
232
233    /// 外部から eval_count を共有するための Arc を取得
234    ///
235    /// 複数の Orchestrator 間で Eval カウントを共有する場合に使用。
236    pub fn eval_count_handle(&self) -> Arc<AtomicUsize> {
237        Arc::clone(&self.eval_count)
238    }
239
240    /// 共有された eval_count を設定
241    pub fn with_shared_eval_count(mut self, count: Arc<AtomicUsize>) -> Self {
242        self.eval_count = count;
243        self
244    }
245
246    /// 現在の Eval カウントを取得
247    pub fn current_eval_count(&self) -> usize {
248        self.eval_count.load(Ordering::SeqCst)
249    }
250
251    /// Learning パスを取得
252    pub fn learning_path(&self) -> &PathBuf {
253        &self.learning_path
254    }
255
256    /// Trigger をチェック
257    fn should_learn(&self) -> bool {
258        let current = self.eval_count.load(Ordering::SeqCst);
259        let ctx = TriggerContext::with_count(current).last_train_count(self.last_learn_count);
260        self.trigger.should_train(&ctx).unwrap_or(false)
261    }
262
263    /// Learn を実行
264    fn run_learn(&mut self, state: &SwarmState) {
265        let scenario = self.scenario.as_deref().unwrap_or("unknown");
266
267        tracing::info!(
268            scenario = scenario,
269            eval_count = self.current_eval_count(),
270            trigger = self.trigger.name(),
271            "Running learning after trigger condition met"
272        );
273
274        // カスタムコールバックがあれば使用
275        if let Some(ref callback) = self.learn_callback {
276            callback(scenario, state);
277        } else {
278            // デフォルト: LearningStore を使った offline learning
279            self.run_default_learn(scenario);
280        }
281
282        // last_learn_count を更新
283        self.last_learn_count = self.eval_count.load(Ordering::SeqCst);
284    }
285
286    /// デフォルトの Learn 実行
287    fn run_default_learn(&self, scenario: &str) {
288        use crate::learn::LearningStore;
289
290        match LearningStore::new(&self.learning_path) {
291            Ok(store) => match store.run_offline_learning(scenario, 20) {
292                Ok(model) => {
293                    tracing::info!(
294                        scenario = scenario,
295                        sessions = model.analyzed_sessions,
296                        "Offline learning completed"
297                    );
298                }
299                Err(e) => {
300                    tracing::warn!(
301                        scenario = scenario,
302                        error = %e,
303                        "Offline learning failed"
304                    );
305                }
306            },
307            Err(e) => {
308                tracing::error!(
309                    path = %self.learning_path.display(),
310                    error = %e,
311                    "Failed to create LearningStore"
312                );
313            }
314        }
315    }
316}
317
318impl LifecycleHook for LearningLifecycleHook {
319    fn on_start(&mut self, worker_count: usize) {
320        tracing::debug!(
321            worker_count = worker_count,
322            eval_count = self.current_eval_count(),
323            "LearningLifecycleHook: Swarm started"
324        );
325    }
326
327    fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult) {
328        // Eval カウントをインクリメント
329        let new_count = self.eval_count.fetch_add(1, Ordering::SeqCst) + 1;
330
331        tracing::debug!(
332            eval_count = new_count,
333            total_ticks = result.total_ticks,
334            trigger = self.trigger.name(),
335            "LearningLifecycleHook: Swarm terminated"
336        );
337
338        // Trigger チェック
339        if self.should_learn() {
340            self.run_learn(state);
341        } else {
342            tracing::debug!(
343                eval_count = new_count,
344                last_learn = self.last_learn_count,
345                trigger = self.trigger.name(),
346                "Trigger not met, skipping learning"
347            );
348        }
349    }
350
351    fn name(&self) -> &str {
352        "learning_lifecycle_hook"
353    }
354}
355
356// ============================================================================
357// CompositeLifecycleHook
358// ============================================================================
359
/// Combines multiple `LifecycleHook`s.
///
/// Several hooks can be registered; they are called in order.
///
/// # Example
///
/// ```ignore
/// let composite = CompositeLifecycleHook::new()
///     .with_hook(Box::new(LearningLifecycleHook::new(path)))
///     .with_hook(Box::new(CustomHook::new()));
/// ```
pub struct CompositeLifecycleHook {
    /// Registered hooks, in the order they were added.
    hooks: Vec<Box<dyn LifecycleHook>>,
}
374
375impl CompositeLifecycleHook {
376    /// 新しい CompositeLifecycleHook を作成
377    pub fn new() -> Self {
378        Self { hooks: Vec::new() }
379    }
380
381    /// Hook を追加
382    pub fn with_hook(mut self, hook: Box<dyn LifecycleHook>) -> Self {
383        self.hooks.push(hook);
384        self
385    }
386}
387
388impl Default for CompositeLifecycleHook {
389    fn default() -> Self {
390        Self::new()
391    }
392}
393
394impl LifecycleHook for CompositeLifecycleHook {
395    fn on_start(&mut self, worker_count: usize) {
396        for hook in &mut self.hooks {
397            hook.on_start(worker_count);
398        }
399    }
400
401    fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult) {
402        for hook in &mut self.hooks {
403            hook.on_terminate(state, result);
404        }
405    }
406
407    fn name(&self) -> &str {
408        "composite_lifecycle_hook"
409    }
410}
411
412// ============================================================================
413// Tests
414// ============================================================================
415
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;

    /// Hook that records which lifecycle callbacks have fired.
    struct TestHook {
        started: Arc<AtomicBool>,
        terminated: Arc<AtomicBool>,
    }

    impl TestHook {
        /// Returns the hook together with handles to its two flags.
        fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>) {
            let started = Arc::new(AtomicBool::new(false));
            let terminated = Arc::new(AtomicBool::new(false));
            let hook = Self {
                started: started.clone(),
                terminated: terminated.clone(),
            };
            (hook, started, terminated)
        }
    }

    impl LifecycleHook for TestHook {
        fn on_start(&mut self, _worker_count: usize) {
            self.started.store(true, Ordering::SeqCst);
        }

        fn on_terminate(&mut self, _state: &SwarmState, _result: &SwarmResult) {
            self.terminated.store(true, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_termination_stats_from_state() {
        let stats = TerminationStats::from_state(&SwarmState::new(4));

        assert_eq!(stats.total_ticks, 0);
        assert!(stats.scenario.is_none());
    }

    #[test]
    fn test_learning_lifecycle_hook_eval_count() {
        let hook = LearningLifecycleHook::new("/tmp/test");
        assert_eq!(hook.current_eval_count(), 0);

        // Incrementing through the shared handle must be visible via the hook.
        let shared = hook.eval_count_handle();
        shared.fetch_add(5, Ordering::SeqCst);
        assert_eq!(hook.current_eval_count(), 5);
    }

    #[test]
    fn test_composite_hook() {
        let (first, started1, terminated1) = TestHook::new();
        let (second, started2, terminated2) = TestHook::new();

        let mut composite = CompositeLifecycleHook::new()
            .with_hook(Box::new(first))
            .with_hook(Box::new(second));

        // Both hooks must see the start notification.
        composite.on_start(4);
        assert!(started1.load(Ordering::SeqCst));
        assert!(started2.load(Ordering::SeqCst));

        // Both hooks must see the termination notification.
        let result = SwarmResult {
            total_ticks: 10,
            total_duration: std::time::Duration::from_secs(1),
            completed: true,
        };
        let state = SwarmState::new(4);
        composite.on_terminate(&state, &result);
        assert!(terminated1.load(Ordering::SeqCst));
        assert!(terminated2.load(Ordering::SeqCst));
    }
}