Skip to main content

swarm_engine_core/learn/
episode.rs

1//! Episode - 学習の基本単位
2//!
3//! ## 設計思想
4//!
5//! Episode は既存の統計システム(SwarmStats, N-gram)の **派生** であり、
6//! 既存システムに影響を与えない追加レイヤーとして機能する。
7//!
8//! - **既存統計が取れないのは NG**: Episode はあくまでオプショナル
9//! - **Episode が取れなくても OK**: 既存の Core 統計は独立して動作
10//! - **変換は Trait ベース**: From/Into で柔軟に既存 Event から変換
11//!
12//! ## アーキテクチャ
13//!
14//! ```text
15//! ActionEvent ──┬──▶ SwarmStats.record() [既存、必須]
16//!               │
17//!               └──▶ Episode への変換 [新規、optional]
18//!                    └─▶ EpisodeStore.append()
19//!
20//! LlmDebugEvent ─┬──▶ StderrLlmSubscriber [既存、デバッグ用]
21//!                │
22//!                └──▶ Episode への変換 [新規、optional]
23//!                     └─▶ EpisodeStore.append()
24//! ```
25
26use std::collections::HashMap;
27
28use serde::{Deserialize, Serialize};
29
30use super::record::{ActionRecord, FromRecord, Record};
31use crate::types::{GroupId, TaskId};
32use crate::util::{epoch_millis, epoch_millis_for_ordering};
33
34// ============================================================================
35// Episode Trait
36// ============================================================================
37
38/// Episode trait - 学習の基本単位を表すインターフェース
39///
40/// 各 LearnModel に対応する具体的な Episode 型はこの trait を実装する。
41/// これにより、学習対象ごとに最適化された構造を持ちながら、
42/// 共通のインターフェースで扱うことができる。
43///
44/// ## 実装例
45///
46/// ```ignore
47/// // WorkerTask 学習用の Episode
48/// pub struct WorkerTaskEpisode {
49///     id: EpisodeId,
50///     task_id: TaskId,
51///     actions: Vec<ActionRecord>,
52///     outcome: Outcome,
53/// }
54///
55/// impl EpisodeTrait for WorkerTaskEpisode {
56///     fn id(&self) -> &EpisodeId { &self.id }
57///     fn task_id(&self) -> TaskId { self.task_id }
58///     // ...
59/// }
60/// ```
61pub trait EpisodeTrait: Send + Sync {
62    /// Episode ID
63    fn id(&self) -> &EpisodeId;
64
65    /// 対応する LearnModel 名
66    fn learn_model_name(&self) -> &str;
67
68    /// Task ID(どのタスクの Episode か)
69    fn task_id(&self) -> Option<TaskId>;
70
71    /// Group ID(DPO 学習での比較グループ)
72    fn group_id(&self) -> Option<GroupId>;
73
74    /// 結果(成功/失敗/タイムアウト等)
75    fn outcome(&self) -> &Outcome;
76
77    /// 成功したかどうか
78    fn is_success(&self) -> bool {
79        self.outcome().is_success()
80    }
81
82    /// シナリオ名(あれば)
83    fn scenario_name(&self) -> Option<&str>;
84}
85
86// ============================================================================
87// EpisodeId
88// ============================================================================
89
90/// Episode ID - 一意識別子
91///
92/// timestamp (ms) + counter の組み合わせでユニーク性を保証。
93/// - timestamp: ソート可能性(ordering)
94/// - counter: 同一 ms 内の順序保証 + NTP 巻き戻り耐性
95#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
96pub struct EpisodeId {
97    /// タイムスタンプ部分(Unix epoch ms)
98    pub timestamp_ms: u64,
99    /// カウンタ部分(単調増加)
100    pub counter: u32,
101}
102
103impl EpisodeId {
104    pub fn new() -> Self {
105        use std::sync::atomic::{AtomicU32, Ordering};
106        static COUNTER: AtomicU32 = AtomicU32::new(0);
107
108        Self {
109            timestamp_ms: epoch_millis_for_ordering(),
110            counter: COUNTER.fetch_add(1, Ordering::Relaxed),
111        }
112    }
113
114    /// 既知の値から作成(テスト用)
115    pub fn from_parts(timestamp_ms: u64, counter: u32) -> Self {
116        Self {
117            timestamp_ms,
118            counter,
119        }
120    }
121}
122
123impl Default for EpisodeId {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129impl std::fmt::Display for EpisodeId {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        write!(f, "{}-{:08x}", self.timestamp_ms, self.counter)
132    }
133}
134
135// ============================================================================
136// Outcome
137// ============================================================================
138
139/// エピソードの結果
140///
141/// LearnModel::evaluate() で判定される。
142///
143/// ## 構造
144///
145/// - **Single Task**: Success / Failure / Timeout
146/// - **Multiple Tasks**: Aggregated
147/// - **未判定**: Unknown
148#[derive(Debug, Clone, Serialize, Deserialize)]
149#[serde(tag = "type")]
150#[derive(Default)]
151pub enum Outcome {
152    // ========================================================================
153    // Single Task
154    // ========================================================================
155    /// 成功(1 Task)
156    Success {
157        /// スコア(0.0〜1.0、または任意のスケール)
158        score: f64,
159    },
160    /// 失敗(1 Task)
161    Failure {
162        /// 失敗理由
163        reason: String,
164    },
165    /// タイムアウト(1 Task)
166    Timeout {
167        /// 部分スコア(タイムアウト時点での進捗)
168        partial_score: Option<f64>,
169    },
170
171    // ========================================================================
172    // Multiple Tasks (Aggregated)
173    // ========================================================================
174    /// 複数 Task の集約結果
175    ///
176    /// Eval 実行後に複数タスクの結果を集約して評価する場合に使用。
177    Aggregated {
178        /// 成功率(0.0〜1.0)
179        success_rate: f64,
180        /// 総タスク数
181        total_tasks: u32,
182        /// 成功タスク数
183        successful_tasks: u32,
184        /// 合計 Tick 数
185        total_ticks: u32,
186    },
187
188    // ========================================================================
189    // Unknown
190    // ========================================================================
191    /// 不明(判定できない場合)
192    #[default]
193    Unknown,
194}
195
196impl Outcome {
197    // ========================================================================
198    // Constructors - Single Task
199    // ========================================================================
200
201    pub fn success(score: f64) -> Self {
202        Self::Success { score }
203    }
204
205    pub fn success_binary() -> Self {
206        Self::Success { score: 1.0 }
207    }
208
209    pub fn failure(reason: impl Into<String>) -> Self {
210        Self::Failure {
211            reason: reason.into(),
212        }
213    }
214
215    pub fn timeout(partial_score: Option<f64>) -> Self {
216        Self::Timeout { partial_score }
217    }
218
219    // ========================================================================
220    // Constructors - Aggregated
221    // ========================================================================
222
223    /// 複数 Task の集約結果を作成
224    pub fn aggregated(
225        success_rate: f64,
226        total_tasks: u32,
227        successful_tasks: u32,
228        total_ticks: u32,
229    ) -> Self {
230        Self::Aggregated {
231            success_rate,
232            total_tasks,
233            successful_tasks,
234            total_ticks,
235        }
236    }
237
238    /// 複数 Task の集計結果から Aggregated を作成
239    ///
240    /// **注意**: このメソッドは複数 Task の集計結果用です。
241    /// 1 Task の結果には以下を使用してください:
242    /// - `Outcome::success(score)` - タスク成功
243    /// - `Outcome::failure(reason)` - タスク失敗
244    /// - `Outcome::timeout(partial_score)` - タイムアウト
245    ///
246    /// # Arguments
247    /// - `total_tasks`: 総タスク数(2 以上を想定)
248    /// - `successful_tasks`: 成功タスク数
249    /// - `success_rate`: 成功率(0.0〜1.0)
250    /// - `total_ticks`: 合計 Tick 数
251    ///
252    /// # Example
253    ///
254    /// ```ignore
255    /// // DPO 学習: 同じ GroupId で 5 回実行した結果を集計
256    /// let outcome = Outcome::from_eval_result(5, 4, 0.8, 250);
257    /// ```
258    pub fn from_eval_result(
259        total_tasks: u32,
260        successful_tasks: u32,
261        success_rate: f64,
262        total_ticks: u32,
263    ) -> Self {
264        Self::Aggregated {
265            success_rate,
266            total_tasks,
267            successful_tasks,
268            total_ticks,
269        }
270    }
271
272    // ========================================================================
273    // Predicates
274    // ========================================================================
275
276    pub fn is_success(&self) -> bool {
277        matches!(self, Self::Success { .. })
278    }
279
280    pub fn is_failure(&self) -> bool {
281        matches!(self, Self::Failure { .. } | Self::Timeout { .. })
282    }
283
284    /// Aggregated かどうか
285    pub fn is_aggregated(&self) -> bool {
286        matches!(self, Self::Aggregated { .. })
287    }
288
289    /// Aggregated で閾値以上の成功率かどうか
290    pub fn is_aggregated_success(&self, threshold: f64) -> bool {
291        match self {
292            Self::Aggregated { success_rate, .. } => *success_rate >= threshold,
293            _ => false,
294        }
295    }
296
297    /// Aggregated で閾値未満の成功率かどうか
298    pub fn is_aggregated_failure(&self, threshold: f64) -> bool {
299        match self {
300            Self::Aggregated { success_rate, .. } => *success_rate < threshold,
301            _ => false,
302        }
303    }
304
305    // ========================================================================
306    // Accessors
307    // ========================================================================
308
309    /// スコアを取得(失敗なら0.0)
310    pub fn score(&self) -> f64 {
311        match self {
312            Self::Success { score } => *score,
313            Self::Timeout { partial_score } => partial_score.unwrap_or(0.0),
314            Self::Aggregated { success_rate, .. } => *success_rate,
315            _ => 0.0,
316        }
317    }
318
319    /// 成功率を取得(Aggregated 専用)
320    ///
321    /// Single Task (Success/Failure/Timeout) では None を返す。
322    /// Single Task のスコアは `score()` を使用すること。
323    pub fn success_rate(&self) -> Option<f64> {
324        match self {
325            Self::Aggregated { success_rate, .. } => Some(*success_rate),
326            _ => None,
327        }
328    }
329
330    /// Tick 数を取得(Aggregated 用)
331    pub fn ticks(&self) -> Option<u32> {
332        match self {
333            Self::Aggregated { total_ticks, .. } => Some(*total_ticks),
334            _ => None,
335        }
336    }
337}
338
339// ============================================================================
340// EpisodeContext
341// ============================================================================
342
343/// エピソードのコンテキスト
344///
345/// Record のコレクションを保持。新しい Record 種別が追加されても
346/// 構造体を変更する必要がない。
347#[derive(Debug, Clone, Default, Serialize, Deserialize)]
348pub struct EpisodeContext {
349    /// Record のリスト(統一的に保持)
350    pub records: Vec<Record>,
351}
352
353impl EpisodeContext {
354    pub fn new() -> Self {
355        Self::default()
356    }
357
358    /// Record を追加
359    pub fn push(&mut self, record: impl Into<Record>) {
360        self.records.push(record.into());
361    }
362
363    /// Record を追加(builder pattern)
364    pub fn with_record(mut self, record: impl Into<Record>) -> Self {
365        self.records.push(record.into());
366        self
367    }
368
369    /// 型でフィルタしてイテレート
370    ///
371    /// ```ignore
372    /// context.iter::<ActionRecord>()
373    /// context.iter::<LlmCallRecord>()
374    /// ```
375    pub fn iter<'a, T: FromRecord + 'a>(&'a self) -> impl Iterator<Item = &'a T> {
376        self.records.iter().filter_map(T::from_record)
377    }
378
379    /// 指定した型の最初の Record を取得
380    ///
381    /// ```ignore
382    /// context.first::<DependencyGraphRecord>()
383    /// context.first::<LlmCallRecord>()
384    /// ```
385    pub fn first<T: FromRecord>(&self) -> Option<&T> {
386        self.iter::<T>().next()
387    }
388
389    /// 全 Record 数
390    pub fn len(&self) -> usize {
391        self.records.len()
392    }
393
394    /// 空かどうか
395    pub fn is_empty(&self) -> bool {
396        self.records.is_empty()
397    }
398}
399
400// ============================================================================
401// EpisodeMetadata
402// ============================================================================
403
404/// エピソードのメタデータ
405#[derive(Debug, Clone, Default, Serialize, Deserialize)]
406pub struct EpisodeMetadata {
407    /// Strategy名(どの抽出戦略で生成されたか)
408    pub strategy_name: Option<String>,
409    /// シナリオ名
410    pub scenario_name: Option<String>,
411    /// 作成日時(Unix timestamp ms)
412    pub created_at: u64,
413    /// 開始日時(Unix timestamp ms)
414    pub started_at: Option<u64>,
415    /// 終了日時(Unix timestamp ms)
416    pub ended_at: Option<u64>,
417    /// 拡張タグ
418    pub tags: HashMap<String, String>,
419}
420
421impl EpisodeMetadata {
422    pub fn new() -> Self {
423        Self {
424            created_at: epoch_millis(),
425            ..Default::default()
426        }
427    }
428
429    pub fn with_strategy(mut self, name: impl Into<String>) -> Self {
430        self.strategy_name = Some(name.into());
431        self
432    }
433
434    pub fn with_scenario(mut self, name: impl Into<String>) -> Self {
435        self.scenario_name = Some(name.into());
436        self
437    }
438
439    pub fn with_duration(mut self, start: u64, end: u64) -> Self {
440        self.started_at = Some(start);
441        self.ended_at = Some(end);
442        self
443    }
444
445    pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
446        self.tags.insert(key.into(), value.into());
447        self
448    }
449
450    /// 実行時間(ミリ秒)
451    pub fn duration_ms(&self) -> Option<u64> {
452        match (self.started_at, self.ended_at) {
453            (Some(start), Some(end)) => Some(end.saturating_sub(start)),
454            _ => None,
455        }
456    }
457}
458
459// ============================================================================
460// Episode Entity
461// ============================================================================
462
463/// Episode - 学習の基本単位(汎用実装)
464///
465/// Swarmの「経験」を表現する。LearnModel によって Record[] から構築され、
466/// TrainingData 生成の元となる。
467///
468/// ## Note
469///
470/// これは汎用の Episode 実装。学習対象ごとに最適化した Episode が必要な場合は、
471/// `EpisodeTrait` を実装した新しい型を定義すること。
472#[derive(Debug, Clone, Serialize, Deserialize)]
473pub struct Episode {
474    /// 一意識別子
475    pub id: EpisodeId,
476    /// どの LearnModel で生成されたか(e.g., "ngram-5", "worker_task")
477    pub learn_model: String,
478    /// Task ID(どのタスクの Episode か)
479    #[serde(default, skip_serializing_if = "Option::is_none")]
480    pub task_id: Option<TaskId>,
481    /// Group ID(DPO 学習での比較グループ)
482    #[serde(default, skip_serializing_if = "Option::is_none")]
483    pub group_id: Option<GroupId>,
484    /// コンテキスト(LLM呼び出し + アクション履歴)
485    pub context: EpisodeContext,
486    /// 結果(LearnModel が判定)
487    pub outcome: Outcome,
488    /// メタデータ
489    pub metadata: EpisodeMetadata,
490}
491
492impl Episode {
493    /// 新規作成
494    pub fn new(learn_model: impl Into<String>, outcome: Outcome) -> Self {
495        Self {
496            id: EpisodeId::new(),
497            learn_model: learn_model.into(),
498            task_id: None,
499            group_id: None,
500            context: EpisodeContext::default(),
501            outcome,
502            metadata: EpisodeMetadata::new(),
503        }
504    }
505
506    /// Builder を取得
507    pub fn builder() -> EpisodeBuilder {
508        EpisodeBuilder::default()
509    }
510
511    /// 成功したかどうか
512    pub fn is_success(&self) -> bool {
513        self.outcome.is_success()
514    }
515
516    /// Worker ID を取得(最初の Action から)
517    pub fn worker_id(&self) -> Option<usize> {
518        self.context
519            .iter::<ActionRecord>()
520            .next()
521            .map(|a| a.worker_id)
522    }
523
524    /// Task ID を取得(context の最初の ActionRecord から、または直接設定された値)
525    pub fn get_task_id(&self) -> Option<TaskId> {
526        self.task_id.or_else(|| {
527            self.context
528                .iter::<ActionRecord>()
529                .next()
530                .map(|a| a.task_id)
531        })
532    }
533
534    /// Group ID を取得(context の最初の ActionRecord から、または直接設定された値)
535    pub fn get_group_id(&self) -> Option<GroupId> {
536        self.group_id.or_else(|| {
537            self.context
538                .iter::<ActionRecord>()
539                .next()
540                .and_then(|a| a.group_id)
541        })
542    }
543}
544
545impl EpisodeTrait for Episode {
546    fn id(&self) -> &EpisodeId {
547        &self.id
548    }
549
550    fn learn_model_name(&self) -> &str {
551        &self.learn_model
552    }
553
554    fn task_id(&self) -> Option<TaskId> {
555        self.get_task_id()
556    }
557
558    fn group_id(&self) -> Option<GroupId> {
559        self.get_group_id()
560    }
561
562    fn outcome(&self) -> &Outcome {
563        &self.outcome
564    }
565
566    fn scenario_name(&self) -> Option<&str> {
567        self.metadata.scenario_name.as_deref()
568    }
569}
570
571// ============================================================================
572// EpisodeBuilder
573// ============================================================================
574
575/// Episode を構築するためのビルダー
576#[derive(Debug, Default)]
577pub struct EpisodeBuilder {
578    id: Option<EpisodeId>,
579    learn_model: Option<String>,
580    task_id: Option<TaskId>,
581    group_id: Option<GroupId>,
582    context: EpisodeContext,
583    outcome: Option<Outcome>,
584    metadata: EpisodeMetadata,
585}
586
587impl EpisodeBuilder {
588    /// Episode ID を設定(永続化からの復元用)
589    pub fn id(mut self, id: EpisodeId) -> Self {
590        self.id = Some(id);
591        self
592    }
593
594    /// LearnModel 名を設定
595    pub fn learn_model(mut self, name: impl Into<String>) -> Self {
596        self.learn_model = Some(name.into());
597        self
598    }
599
600    /// Task ID を設定
601    pub fn task_id(mut self, task_id: TaskId) -> Self {
602        self.task_id = Some(task_id);
603        self
604    }
605
606    /// Group ID を設定
607    pub fn group_id(mut self, group_id: GroupId) -> Self {
608        self.group_id = Some(group_id);
609        self
610    }
611
612    /// Record を追加(汎用)
613    pub fn record(mut self, record: impl Into<Record>) -> Self {
614        self.context.push(record);
615        self
616    }
617
618    /// EpisodeContext を設定
619    pub fn context(mut self, context: EpisodeContext) -> Self {
620        self.context = context;
621        self
622    }
623
624    pub fn outcome(mut self, outcome: Outcome) -> Self {
625        self.outcome = Some(outcome);
626        self
627    }
628
629    pub fn scenario(mut self, name: impl Into<String>) -> Self {
630        self.metadata.scenario_name = Some(name.into());
631        self
632    }
633
634    pub fn tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
635        self.metadata.tags.insert(key.into(), value.into());
636        self
637    }
638
639    /// EpisodeMetadata を設定(永続化からの復元用)
640    pub fn metadata(mut self, metadata: EpisodeMetadata) -> Self {
641        self.metadata = metadata;
642        self
643    }
644
645    pub fn build(self) -> Episode {
646        Episode {
647            id: self.id.unwrap_or_default(),
648            learn_model: self.learn_model.unwrap_or_else(|| "unknown".to_string()),
649            task_id: self.task_id,
650            group_id: self.group_id,
651            context: self.context,
652            outcome: self.outcome.unwrap_or(Outcome::Unknown),
653            metadata: self.metadata,
654        }
655    }
656}
657
658// ============================================================================
659// Tests
660// ============================================================================
661
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::events::{ActionContext, ActionEvent, ActionEventBuilder, ActionEventResult};
    use crate::learn::record::LlmCallRecord;
    use crate::types::WorkerId;

    /// Builds an `ActionEvent` with a fixed 50 ms duration and a fixed
    /// selection-logic / previous-action context.
    fn make_action_event(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
        let context = ActionContext::new()
            .with_selection_logic("UCB1")
            .with_previous_action("PrevAction");

        let result = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("test error")
        };

        ActionEventBuilder::new(tick, WorkerId(worker_id), action)
            .result(result)
            .duration(Duration::from_millis(50))
            .context(context)
            .build()
    }

    #[test]
    fn test_action_record_from_action_event() {
        let source = make_action_event(10, 1, "CheckStatus", true);
        let converted = ActionRecord::from(&source);

        assert_eq!(converted.tick, 10);
        assert_eq!(converted.worker_id, 1);
        assert_eq!(converted.action, "CheckStatus");
        assert!(converted.success);
        assert_eq!(converted.duration_ms, 50);
        assert_eq!(converted.selection_logic, Some(String::from("UCB1")));
        assert_eq!(converted.previous_action, Some(String::from("PrevAction")));
    }

    #[test]
    fn test_episode_builder_with_actions() {
        let events = [
            make_action_event(1, 0, "Grep", true),
            make_action_event(2, 0, "Read", true),
            make_action_event(3, 0, "done", true),
        ];

        // Add the records in event order, then finish the builder.
        let episode = events
            .iter()
            .fold(Episode::builder().learn_model("worker_task"), |builder, event| {
                builder.record(ActionRecord::from(event))
            })
            .outcome(Outcome::success_binary())
            .scenario("troubleshooting")
            .build();

        assert_eq!(episode.learn_model, "worker_task");
        assert_eq!(episode.context.iter::<ActionRecord>().count(), 3);

        let action_names: Vec<&str> = episode
            .context
            .iter::<ActionRecord>()
            .map(|record| record.action.as_str())
            .collect();
        assert_eq!(action_names, vec!["Grep", "Read", "done"]);

        assert!(episode.is_success());
        assert_eq!(
            episode.metadata.scenario_name,
            Some(String::from("troubleshooting"))
        );
    }

    #[test]
    fn test_episode_builder_with_llm_call() {
        let llm_record = LlmCallRecord::new("decide", "qwen2.5")
            .prompt("What action?")
            .response("CheckStatus")
            .latency_ms(150)
            .worker_id(0);

        let episode = Episode::builder()
            .learn_model("llm_call")
            .record(llm_record.clone())
            .outcome(Outcome::success(0.9))
            .build();

        assert_eq!(episode.learn_model, "llm_call");
        assert_eq!(episode.context.iter::<LlmCallRecord>().count(), 1);

        let stored = episode.context.first::<LlmCallRecord>().unwrap();
        assert_eq!(stored.prompt, "What action?");
        assert_eq!(stored.response, "CheckStatus");
    }

    #[test]
    fn test_outcome_variants() {
        // Single task: Success
        let succeeded = Outcome::success(1.0);
        assert!(succeeded.is_success());
        assert!(!succeeded.is_failure());
        assert_eq!(Outcome::success(0.8).score(), 0.8);

        // Single task: Failure
        let failed = Outcome::failure("test");
        assert!(!failed.is_success());
        assert!(failed.is_failure());
        assert_eq!(failed.score(), 0.0);

        // Single task: Timeout
        let timed_out = Outcome::timeout(Some(0.5));
        assert!(!timed_out.is_success());
        assert!(timed_out.is_failure());
        assert_eq!(timed_out.score(), 0.5);

        // Unknown is neither a success nor a failure
        let unknown = Outcome::Unknown;
        assert!(!unknown.is_success());
        assert!(!unknown.is_failure());
    }

    #[test]
    fn test_outcome_aggregated() {
        // Aggregated with a high success rate
        let high = Outcome::aggregated(0.9, 10, 9, 100);
        assert!(high.is_aggregated());
        assert!(high.is_aggregated_success(0.5));
        assert!(!high.is_aggregated_failure(0.5));
        assert_eq!(high.score(), 0.9);
        assert_eq!(high.success_rate(), Some(0.9));
        assert_eq!(high.ticks(), Some(100));

        // Aggregated with a low success rate
        let low = Outcome::aggregated(0.3, 10, 3, 50);
        assert!(low.is_aggregated());
        assert!(!low.is_aggregated_success(0.5));
        assert!(low.is_aggregated_failure(0.5));
        assert_eq!(low.score(), 0.3);

        // from_eval_result takes (total_tasks, successful_tasks, success_rate, total_ticks)
        let from_eval = Outcome::from_eval_result(5, 4, 0.8, 20);
        assert!(from_eval.is_aggregated());
        assert_eq!(from_eval.success_rate(), Some(0.8));
        assert_eq!(from_eval.ticks(), Some(20));
    }

    #[test]
    fn test_episode_context_iter() {
        let mut context = EpisodeContext::new();
        for (tick, action, succeeded) in [(1, "A", true), (2, "B", true), (3, "C", false)] {
            context.push(ActionRecord::new(tick, 0, action).success(succeeded));
        }

        // Count through iter::<ActionRecord>()
        assert_eq!(context.iter::<ActionRecord>().count(), 3);

        // Count only the successful actions
        let success_count = context
            .iter::<ActionRecord>()
            .filter(|record| record.success)
            .count();
        assert_eq!(success_count, 2);

        // The action sequence keeps insertion order
        let sequence: Vec<&str> = context
            .iter::<ActionRecord>()
            .map(|record| record.action.as_str())
            .collect();
        assert_eq!(sequence, vec!["A", "B", "C"]);
    }

    #[test]
    fn test_episode_serialization() {
        let original = Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::new(1, 0, "CheckStatus").success(true))
            .outcome(Outcome::success_binary())
            .build();

        // Serialize
        let json = serde_json::to_string(&original).unwrap();
        assert!(json.contains("\"learn_model\":\"worker_task\""));
        assert!(json.contains("\"action\":\"CheckStatus\""));

        // Deserialize
        let restored: Episode = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.learn_model, "worker_task");
        assert_eq!(restored.context.iter::<ActionRecord>().count(), 1);
        assert!(restored.is_success());
    }

    #[test]
    fn test_llm_call_record_builder() {
        let record = LlmCallRecord::new("decide", "qwen2.5")
            .prompt("prompt")
            .response("response")
            .endpoint("http://localhost:11434")
            .lora("adapter1")
            .latency_ms(100)
            .worker_id(5);

        assert_eq!(record.call_type, "decide");
        assert_eq!(record.model, "qwen2.5");
        assert_eq!(record.prompt, "prompt");
        assert_eq!(record.response, "response");
        assert_eq!(record.lora, Some(String::from("adapter1")));
        assert_eq!(record.worker_id, Some(5));
        assert!(record.is_success());

        // A record carrying an error does not count as a success
        let failed = LlmCallRecord::new("decide", "model").error("timeout");
        assert!(!failed.is_success());
    }

    #[test]
    fn test_episode_builder_with_id_and_metadata() {
        let custom_id = EpisodeId::from_parts(12345, 1);

        let mut custom_metadata = EpisodeMetadata::new();
        custom_metadata.scenario_name = Some(String::from("custom-scenario"));
        custom_metadata
            .tags
            .insert(String::from("key"), String::from("value"));

        let episode = Episode::builder()
            .id(custom_id.clone())
            .learn_model("test")
            .metadata(custom_metadata)
            .outcome(Outcome::Unknown)
            .build();

        assert_eq!(episode.id, custom_id);
        assert_eq!(
            episode.metadata.scenario_name,
            Some(String::from("custom-scenario"))
        );
        assert_eq!(
            episode.metadata.tags.get("key"),
            Some(&String::from("value"))
        );
    }
}