Skip to main content

swarm_engine_core/learn/
episode.rs

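//! Episode - the basic unit of learning
//!
//! ## Design philosophy
//!
//! An Episode is **derived** from the existing statistics systems (SwarmStats, N-gram)
//! and works as an additional layer that does not affect those systems.
//!
//! - **Losing the existing statistics is not acceptable**: Episode is strictly optional
//! - **Failing to capture an Episode is fine**: the existing Core statistics run independently
//! - **Conversion is trait-based**: existing events are converted flexibly via From/Into
//!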
//! ## Architecture
//!
//! ```text
//! ActionEvent ──┬──▶ SwarmStats.record() [existing, required]
//!               │
//!               └──▶ conversion to Episode [new, optional]
//!                    └─▶ EpisodeStore.append()
//!
//! LlmDebugEvent ─┬──▶ StderrLlmSubscriber [existing, for debugging]
//!                │
//!                └──▶ conversion to Episode [new, optional]
//!                     └─▶ EpisodeStore.append()
//! ```
25
26use std::collections::HashMap;
27
28use serde::{Deserialize, Serialize};
29
30use super::record::{ActionRecord, FromRecord, Record};
31use crate::types::{GroupId, TaskId};
32use crate::util::{epoch_millis, epoch_millis_for_ordering};
33
34// ============================================================================
35// Episode Trait
36// ============================================================================
37
38/// Episode trait - 学習の基本単位を表すインターフェース
39///
40/// 各 LearnModel に対応する具体的な Episode 型はこの trait を実装する。
41/// これにより、学習対象ごとに最適化された構造を持ちながら、
42/// 共通のインターフェースで扱うことができる。
43///
44/// ## 実装例
45///
46/// ```ignore
47/// // WorkerTask 学習用の Episode
48/// pub struct WorkerTaskEpisode {
49///     id: EpisodeId,
50///     task_id: TaskId,
51///     actions: Vec<ActionRecord>,
52///     outcome: Outcome,
53/// }
54///
55/// impl EpisodeTrait for WorkerTaskEpisode {
56///     fn id(&self) -> &EpisodeId { &self.id }
57///     fn task_id(&self) -> TaskId { self.task_id }
58///     // ...
59/// }
60/// ```
61pub trait EpisodeTrait: Send + Sync {
62    /// Episode ID
63    fn id(&self) -> &EpisodeId;
64
65    /// 対応する LearnModel 名
66    fn learn_model_name(&self) -> &str;
67
68    /// Task ID(どのタスクの Episode か)
69    fn task_id(&self) -> Option<TaskId>;
70
71    /// Group ID(DPO 学習での比較グループ)
72    fn group_id(&self) -> Option<GroupId>;
73
74    /// 結果(成功/失敗/タイムアウト等)
75    fn outcome(&self) -> &Outcome;
76
77    /// 成功したかどうか
78    fn is_success(&self) -> bool {
79        self.outcome().is_success()
80    }
81
82    /// シナリオ名(あれば)
83    fn scenario_name(&self) -> Option<&str>;
84}
85
86// ============================================================================
87// EpisodeId
88// ============================================================================
89
90/// Episode ID - 一意識別子
91///
92/// timestamp (ms) + counter の組み合わせでユニーク性を保証。
93/// - timestamp: ソート可能性(ordering)
94/// - counter: 同一 ms 内の順序保証 + NTP 巻き戻り耐性
95#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
96pub struct EpisodeId {
97    /// タイムスタンプ部分(Unix epoch ms)
98    pub timestamp_ms: u64,
99    /// カウンタ部分(単調増加)
100    pub counter: u32,
101}
102
103impl EpisodeId {
104    pub fn new() -> Self {
105        use std::sync::atomic::{AtomicU32, Ordering};
106        static COUNTER: AtomicU32 = AtomicU32::new(0);
107
108        Self {
109            timestamp_ms: epoch_millis_for_ordering(),
110            counter: COUNTER.fetch_add(1, Ordering::Relaxed),
111        }
112    }
113
114    /// 既知の値から作成(テスト用)
115    pub fn from_parts(timestamp_ms: u64, counter: u32) -> Self {
116        Self {
117            timestamp_ms,
118            counter,
119        }
120    }
121}
122
123impl Default for EpisodeId {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129impl std::fmt::Display for EpisodeId {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        write!(f, "{}-{:08x}", self.timestamp_ms, self.counter)
132    }
133}
134
135// ============================================================================
136// Outcome
137// ============================================================================
138
139/// エピソードの結果
140///
141/// LearnModel::evaluate() で判定される。
142///
143/// ## 構造
144///
145/// - **Single Episode**: Success / Failure / Timeout
146/// - **未判定**: Unknown
147///
148/// ## 設計方針
149///
150/// Episode は「1 Task 実行」「1 生成」などの単一単位を表す。
151/// 複数 Episode の集計は Eval レポート側 (`AggregatedResults`) で行う。
152#[derive(Debug, Clone, Serialize, Deserialize)]
153#[serde(tag = "type")]
154#[derive(Default)]
155pub enum Outcome {
156    // ========================================================================
157    // Single Episode
158    // ========================================================================
159    /// 成功
160    Success {
161        /// スコア(0.0〜1.0、または任意のスケール)
162        score: f64,
163    },
164    /// 失敗
165    Failure {
166        /// 失敗理由
167        reason: String,
168    },
169    /// タイムアウト
170    Timeout {
171        /// 部分スコア(タイムアウト時点での進捗)
172        partial_score: Option<f64>,
173    },
174
175    // ========================================================================
176    // Unknown
177    // ========================================================================
178    /// 不明(判定できない場合)
179    #[default]
180    Unknown,
181}
182
183impl Outcome {
184    // ========================================================================
185    // Constructors - Single Task
186    // ========================================================================
187
188    pub fn success(score: f64) -> Self {
189        Self::Success { score }
190    }
191
192    pub fn success_binary() -> Self {
193        Self::Success { score: 1.0 }
194    }
195
196    pub fn failure(reason: impl Into<String>) -> Self {
197        Self::Failure {
198            reason: reason.into(),
199        }
200    }
201
202    pub fn timeout(partial_score: Option<f64>) -> Self {
203        Self::Timeout { partial_score }
204    }
205
206    // ========================================================================
207    // Predicates
208    // ========================================================================
209
210    pub fn is_success(&self) -> bool {
211        matches!(self, Self::Success { .. })
212    }
213
214    pub fn is_failure(&self) -> bool {
215        matches!(self, Self::Failure { .. } | Self::Timeout { .. })
216    }
217
218    /// Unknown かどうか
219    pub fn is_unknown(&self) -> bool {
220        matches!(self, Self::Unknown)
221    }
222
223    // ========================================================================
224    // Accessors
225    // ========================================================================
226
227    /// スコアを取得(失敗なら0.0)
228    pub fn score(&self) -> f64 {
229        match self {
230            Self::Success { score } => *score,
231            Self::Timeout { partial_score } => partial_score.unwrap_or(0.0),
232            _ => 0.0,
233        }
234    }
235}
236
237// ============================================================================
238// EpisodeContext
239// ============================================================================
240
241/// エピソードのコンテキスト
242///
243/// Record のコレクションを保持。新しい Record 種別が追加されても
244/// 構造体を変更する必要がない。
245#[derive(Debug, Clone, Default, Serialize, Deserialize)]
246pub struct EpisodeContext {
247    /// Record のリスト(統一的に保持)
248    pub records: Vec<Record>,
249}
250
251impl EpisodeContext {
252    pub fn new() -> Self {
253        Self::default()
254    }
255
256    /// Record を追加
257    pub fn push(&mut self, record: impl Into<Record>) {
258        self.records.push(record.into());
259    }
260
261    /// Record を追加(builder pattern)
262    pub fn with_record(mut self, record: impl Into<Record>) -> Self {
263        self.records.push(record.into());
264        self
265    }
266
267    /// 型でフィルタしてイテレート
268    ///
269    /// ```ignore
270    /// context.iter::<ActionRecord>()
271    /// context.iter::<LlmCallRecord>()
272    /// ```
273    pub fn iter<'a, T: FromRecord + 'a>(&'a self) -> impl Iterator<Item = &'a T> {
274        self.records.iter().filter_map(T::from_record)
275    }
276
277    /// 指定した型の最初の Record を取得
278    ///
279    /// ```ignore
280    /// context.first::<DependencyGraphRecord>()
281    /// context.first::<LlmCallRecord>()
282    /// ```
283    pub fn first<T: FromRecord>(&self) -> Option<&T> {
284        self.iter::<T>().next()
285    }
286
287    /// 全 Record 数
288    pub fn len(&self) -> usize {
289        self.records.len()
290    }
291
292    /// 空かどうか
293    pub fn is_empty(&self) -> bool {
294        self.records.is_empty()
295    }
296}
297
298// ============================================================================
299// EpisodeMetadata
300// ============================================================================
301
302/// エピソードのメタデータ
303#[derive(Debug, Clone, Default, Serialize, Deserialize)]
304pub struct EpisodeMetadata {
305    /// Strategy名(どの抽出戦略で生成されたか)
306    pub strategy_name: Option<String>,
307    /// シナリオ名
308    pub scenario_name: Option<String>,
309    /// 作成日時(Unix timestamp ms)
310    pub created_at: u64,
311    /// 開始日時(Unix timestamp ms)
312    pub started_at: Option<u64>,
313    /// 終了日時(Unix timestamp ms)
314    pub ended_at: Option<u64>,
315    /// 拡張タグ
316    pub tags: HashMap<String, String>,
317}
318
319impl EpisodeMetadata {
320    pub fn new() -> Self {
321        Self {
322            created_at: epoch_millis(),
323            ..Default::default()
324        }
325    }
326
327    pub fn with_strategy(mut self, name: impl Into<String>) -> Self {
328        self.strategy_name = Some(name.into());
329        self
330    }
331
332    pub fn with_scenario(mut self, name: impl Into<String>) -> Self {
333        self.scenario_name = Some(name.into());
334        self
335    }
336
337    pub fn with_duration(mut self, start: u64, end: u64) -> Self {
338        self.started_at = Some(start);
339        self.ended_at = Some(end);
340        self
341    }
342
343    pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
344        self.tags.insert(key.into(), value.into());
345        self
346    }
347
348    /// 実行時間(ミリ秒)
349    pub fn duration_ms(&self) -> Option<u64> {
350        match (self.started_at, self.ended_at) {
351            (Some(start), Some(end)) => Some(end.saturating_sub(start)),
352            _ => None,
353        }
354    }
355}
356
357// ============================================================================
358// Episode Entity
359// ============================================================================
360
361/// Episode - 学習の基本単位(汎用実装)
362///
363/// Swarmの「経験」を表現する。LearnModel によって Record[] から構築され、
364/// TrainingData 生成の元となる。
365///
366/// ## Note
367///
368/// これは汎用の Episode 実装。学習対象ごとに最適化した Episode が必要な場合は、
369/// `EpisodeTrait` を実装した新しい型を定義すること。
370#[derive(Debug, Clone, Serialize, Deserialize)]
371pub struct Episode {
372    /// 一意識別子
373    pub id: EpisodeId,
374    /// どの LearnModel で生成されたか(e.g., "ngram-5", "worker_task")
375    pub learn_model: String,
376    /// Task ID(どのタスクの Episode か)
377    #[serde(default, skip_serializing_if = "Option::is_none")]
378    pub task_id: Option<TaskId>,
379    /// Group ID(DPO 学習での比較グループ)
380    #[serde(default, skip_serializing_if = "Option::is_none")]
381    pub group_id: Option<GroupId>,
382    /// コンテキスト(LLM呼び出し + アクション履歴)
383    pub context: EpisodeContext,
384    /// 結果(LearnModel が判定)
385    pub outcome: Outcome,
386    /// メタデータ
387    pub metadata: EpisodeMetadata,
388}
389
390impl Episode {
391    /// 新規作成
392    pub fn new(learn_model: impl Into<String>, outcome: Outcome) -> Self {
393        Self {
394            id: EpisodeId::new(),
395            learn_model: learn_model.into(),
396            task_id: None,
397            group_id: None,
398            context: EpisodeContext::default(),
399            outcome,
400            metadata: EpisodeMetadata::new(),
401        }
402    }
403
404    /// Builder を取得
405    pub fn builder() -> EpisodeBuilder {
406        EpisodeBuilder::default()
407    }
408
409    /// 成功したかどうか
410    pub fn is_success(&self) -> bool {
411        self.outcome.is_success()
412    }
413
414    /// Worker ID を取得(最初の Action から)
415    pub fn worker_id(&self) -> Option<usize> {
416        self.context
417            .iter::<ActionRecord>()
418            .next()
419            .map(|a| a.worker_id)
420    }
421
422    /// Task ID を取得(context の最初の ActionRecord から、または直接設定された値)
423    pub fn get_task_id(&self) -> Option<TaskId> {
424        self.task_id.or_else(|| {
425            self.context
426                .iter::<ActionRecord>()
427                .next()
428                .map(|a| a.task_id)
429        })
430    }
431
432    /// Group ID を取得(context の最初の ActionRecord から、または直接設定された値)
433    pub fn get_group_id(&self) -> Option<GroupId> {
434        self.group_id.or_else(|| {
435            self.context
436                .iter::<ActionRecord>()
437                .next()
438                .and_then(|a| a.group_id)
439        })
440    }
441}
442
443impl EpisodeTrait for Episode {
444    fn id(&self) -> &EpisodeId {
445        &self.id
446    }
447
448    fn learn_model_name(&self) -> &str {
449        &self.learn_model
450    }
451
452    fn task_id(&self) -> Option<TaskId> {
453        self.get_task_id()
454    }
455
456    fn group_id(&self) -> Option<GroupId> {
457        self.get_group_id()
458    }
459
460    fn outcome(&self) -> &Outcome {
461        &self.outcome
462    }
463
464    fn scenario_name(&self) -> Option<&str> {
465        self.metadata.scenario_name.as_deref()
466    }
467}
468
469// ============================================================================
470// EpisodeBuilder
471// ============================================================================
472
473/// Episode を構築するためのビルダー
474#[derive(Debug, Default)]
475pub struct EpisodeBuilder {
476    id: Option<EpisodeId>,
477    learn_model: Option<String>,
478    task_id: Option<TaskId>,
479    group_id: Option<GroupId>,
480    context: EpisodeContext,
481    outcome: Option<Outcome>,
482    metadata: EpisodeMetadata,
483}
484
485impl EpisodeBuilder {
486    /// Episode ID を設定(永続化からの復元用)
487    pub fn id(mut self, id: EpisodeId) -> Self {
488        self.id = Some(id);
489        self
490    }
491
492    /// LearnModel 名を設定
493    pub fn learn_model(mut self, name: impl Into<String>) -> Self {
494        self.learn_model = Some(name.into());
495        self
496    }
497
498    /// Task ID を設定
499    pub fn task_id(mut self, task_id: TaskId) -> Self {
500        self.task_id = Some(task_id);
501        self
502    }
503
504    /// Group ID を設定
505    pub fn group_id(mut self, group_id: GroupId) -> Self {
506        self.group_id = Some(group_id);
507        self
508    }
509
510    /// Record を追加(汎用)
511    pub fn record(mut self, record: impl Into<Record>) -> Self {
512        self.context.push(record);
513        self
514    }
515
516    /// EpisodeContext を設定
517    pub fn context(mut self, context: EpisodeContext) -> Self {
518        self.context = context;
519        self
520    }
521
522    pub fn outcome(mut self, outcome: Outcome) -> Self {
523        self.outcome = Some(outcome);
524        self
525    }
526
527    pub fn scenario(mut self, name: impl Into<String>) -> Self {
528        self.metadata.scenario_name = Some(name.into());
529        self
530    }
531
532    pub fn tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
533        self.metadata.tags.insert(key.into(), value.into());
534        self
535    }
536
537    /// EpisodeMetadata を設定(永続化からの復元用)
538    pub fn metadata(mut self, metadata: EpisodeMetadata) -> Self {
539        self.metadata = metadata;
540        self
541    }
542
543    pub fn build(self) -> Episode {
544        Episode {
545            id: self.id.unwrap_or_default(),
546            learn_model: self.learn_model.unwrap_or_else(|| "unknown".to_string()),
547            task_id: self.task_id,
548            group_id: self.group_id,
549            context: self.context,
550            outcome: self.outcome.unwrap_or(Outcome::Unknown),
551            metadata: self.metadata,
552        }
553    }
554}
555
556// ============================================================================
557// Tests
558// ============================================================================
559
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::events::{ActionContext, ActionEvent, ActionEventBuilder, ActionEventResult};
    use crate::learn::record::LlmCallRecord;
    use crate::types::WorkerId;

    /// Fixture: an `ActionEvent` with a fixed 50 ms duration and a fixed
    /// action context ("UCB1" selection logic, "PrevAction" previous action).
    fn make_action_event(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
        let outcome = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("test error")
        };
        let action_context = ActionContext::new()
            .with_selection_logic("UCB1")
            .with_previous_action("PrevAction");

        ActionEventBuilder::new(tick, WorkerId(worker_id), action)
            .result(outcome)
            .duration(Duration::from_millis(50))
            .context(action_context)
            .build()
    }

    /// Converting an `ActionEvent` keeps tick, worker, action, result,
    /// duration and the action context fields.
    #[test]
    fn test_action_record_from_action_event() {
        let source = make_action_event(10, 1, "CheckStatus", true);
        let converted = ActionRecord::from(&source);

        assert_eq!(converted.tick, 10);
        assert_eq!(converted.worker_id, 1);
        assert_eq!(converted.action, "CheckStatus");
        assert!(converted.success);
        assert_eq!(converted.duration_ms, 50);
        assert_eq!(converted.selection_logic.as_deref(), Some("UCB1"));
        assert_eq!(converted.previous_action.as_deref(), Some("PrevAction"));
    }

    /// Records added through the builder show up in order in the context.
    #[test]
    fn test_episode_builder_with_actions() {
        let events = [
            make_action_event(1, 0, "Grep", true),
            make_action_event(2, 0, "Read", true),
            make_action_event(3, 0, "done", true),
        ];

        let with_records = events.iter().fold(
            Episode::builder().learn_model("worker_task"),
            |builder, event| builder.record(ActionRecord::from(event)),
        );
        let episode = with_records
            .outcome(Outcome::success_binary())
            .scenario("troubleshooting")
            .build();

        assert_eq!(episode.learn_model, "worker_task");
        assert_eq!(episode.context.iter::<ActionRecord>().count(), 3);

        let action_names: Vec<&str> = episode
            .context
            .iter::<ActionRecord>()
            .map(|record| record.action.as_str())
            .collect();
        assert_eq!(action_names, ["Grep", "Read", "done"]);

        assert!(episode.is_success());
        assert_eq!(
            episode.metadata.scenario_name.as_deref(),
            Some("troubleshooting")
        );
    }

    /// An LLM call record can be stored and retrieved by type.
    #[test]
    fn test_episode_builder_with_llm_call() {
        let llm_record = LlmCallRecord::new("decide", "qwen2.5")
            .prompt("What action?")
            .response("CheckStatus")
            .latency_ms(150)
            .worker_id(0);

        let episode = Episode::builder()
            .learn_model("llm_call")
            .record(llm_record)
            .outcome(Outcome::success(0.9))
            .build();

        assert_eq!(episode.learn_model, "llm_call");
        assert_eq!(episode.context.iter::<LlmCallRecord>().count(), 1);

        let stored = episode.context.first::<LlmCallRecord>().unwrap();
        assert_eq!(stored.prompt, "What action?");
        assert_eq!(stored.response, "CheckStatus");
    }

    /// Predicates and score for every `Outcome` variant.
    #[test]
    fn test_outcome_variants() {
        // Single: Success
        let success = Outcome::success(1.0);
        assert!(success.is_success());
        assert!(!success.is_failure());
        assert_eq!(Outcome::success(0.8).score(), 0.8);

        // Single: Failure
        let failure = Outcome::failure("test");
        assert!(!failure.is_success());
        assert!(failure.is_failure());
        assert_eq!(failure.score(), 0.0);

        // Single: Timeout
        let timeout = Outcome::timeout(Some(0.5));
        assert!(!timeout.is_success());
        assert!(timeout.is_failure());
        assert_eq!(timeout.score(), 0.5);

        // Unknown is neither a success nor a failure
        let unknown = Outcome::Unknown;
        assert!(!unknown.is_success());
        assert!(!unknown.is_failure());
    }

    /// Typed iteration over the context: counting, filtering and ordering.
    #[test]
    fn test_episode_context_iter() {
        let mut context = EpisodeContext::new();
        for (tick, action, succeeded) in [(1, "A", true), (2, "B", true), (3, "C", false)] {
            context.push(ActionRecord::new(tick, 0, action).success(succeeded));
        }

        // Count through iter::<ActionRecord>()
        assert_eq!(context.iter::<ActionRecord>().count(), 3);

        // Count of successful actions
        let succeeded = context
            .iter::<ActionRecord>()
            .filter(|record| record.success)
            .count();
        assert_eq!(succeeded, 2);

        // Action sequence
        let sequence: Vec<&str> = context
            .iter::<ActionRecord>()
            .map(|record| record.action.as_str())
            .collect();
        assert_eq!(sequence, ["A", "B", "C"]);
    }

    /// An episode survives a JSON round trip.
    #[test]
    fn test_episode_serialization() {
        let original = Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::new(1, 0, "CheckStatus").success(true))
            .outcome(Outcome::success_binary())
            .build();

        // Serialize
        let json = serde_json::to_string(&original).unwrap();
        assert!(json.contains("\"learn_model\":\"worker_task\""));
        assert!(json.contains("\"action\":\"CheckStatus\""));

        // Deserialize
        let restored: Episode = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.learn_model, "worker_task");
        assert_eq!(restored.context.iter::<ActionRecord>().count(), 1);
        assert!(restored.is_success());
    }

    /// The `LlmCallRecord` builder fills every field; an error marks failure.
    #[test]
    fn test_llm_call_record_builder() {
        let ok_record = LlmCallRecord::new("decide", "qwen2.5")
            .prompt("prompt")
            .response("response")
            .endpoint("http://localhost:11434")
            .lora("adapter1")
            .latency_ms(100)
            .worker_id(5);

        assert_eq!(ok_record.call_type, "decide");
        assert_eq!(ok_record.model, "qwen2.5");
        assert_eq!(ok_record.prompt, "prompt");
        assert_eq!(ok_record.response, "response");
        assert_eq!(ok_record.lora.as_deref(), Some("adapter1"));
        assert_eq!(ok_record.worker_id, Some(5));
        assert!(ok_record.is_success());

        let failed_record = LlmCallRecord::new("decide", "model").error("timeout");
        assert!(!failed_record.is_success());
    }

    /// An explicit id and explicit metadata are carried into the episode.
    #[test]
    fn test_episode_builder_with_id_and_metadata() {
        let custom_id = EpisodeId::from_parts(12345, 1);
        let custom_metadata = EpisodeMetadata::new()
            .with_scenario("custom-scenario")
            .with_tag("key", "value");

        let episode = Episode::builder()
            .id(custom_id.clone())
            .learn_model("test")
            .metadata(custom_metadata)
            .outcome(Outcome::Unknown)
            .build();

        assert_eq!(episode.id, custom_id);
        assert_eq!(
            episode.metadata.scenario_name.as_deref(),
            Some("custom-scenario")
        );
        assert_eq!(
            episode.metadata.tags.get("key").map(String::as_str),
            Some("value")
        );
    }
}
764}