Skip to main content

swarm_engine_core/learn/
episode.rs

//! Episode - the basic unit of learning
//!
//! ## Design philosophy
//!
//! An Episode is **derived** from the existing statistics systems
//! (SwarmStats, N-gram) and works as an additional layer that does not
//! affect those systems.
//!
//! - **Losing existing statistics is not acceptable**: Episode is strictly optional
//! - **Failing to capture an Episode is fine**: the existing Core statistics work independently
//! - **Conversion is trait-based**: existing events are converted flexibly via From/Into
//!
//! ## Architecture
//!
//! ```text
//! ActionEvent ──┬──▶ SwarmStats.record() [existing, required]
//!               │
//!               └──▶ conversion to Episode [new, optional]
//!                    └─▶ EpisodeStore.append()
//!
//! LlmDebugEvent ─┬──▶ StderrLlmSubscriber [existing, for debugging]
//!                │
//!                └──▶ conversion to Episode [new, optional]
//!                     └─▶ EpisodeStore.append()
//! ```
25
26use std::collections::HashMap;
27
28use serde::{Deserialize, Serialize};
29
30use super::record::{ActionRecord, FromRecord, Record};
31use crate::util::{epoch_millis, epoch_millis_for_ordering};
32
33// ============================================================================
34// EpisodeId
35// ============================================================================
36
37/// Episode ID - 一意識別子
38///
39/// timestamp (ms) + counter の組み合わせでユニーク性を保証。
40/// - timestamp: ソート可能性(ordering)
41/// - counter: 同一 ms 内の順序保証 + NTP 巻き戻り耐性
42#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
43pub struct EpisodeId {
44    /// タイムスタンプ部分(Unix epoch ms)
45    pub timestamp_ms: u64,
46    /// カウンタ部分(単調増加)
47    pub counter: u32,
48}
49
50impl EpisodeId {
51    pub fn new() -> Self {
52        use std::sync::atomic::{AtomicU32, Ordering};
53        static COUNTER: AtomicU32 = AtomicU32::new(0);
54
55        Self {
56            timestamp_ms: epoch_millis_for_ordering(),
57            counter: COUNTER.fetch_add(1, Ordering::Relaxed),
58        }
59    }
60
61    /// 既知の値から作成(テスト用)
62    pub fn from_parts(timestamp_ms: u64, counter: u32) -> Self {
63        Self {
64            timestamp_ms,
65            counter,
66        }
67    }
68}
69
70impl Default for EpisodeId {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl std::fmt::Display for EpisodeId {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        write!(f, "{}-{:08x}", self.timestamp_ms, self.counter)
79    }
80}
81
82// ============================================================================
83// Outcome
84// ============================================================================
85
86/// エピソードの結果
87///
88/// LearnModel::evaluate() で判定される。
89#[derive(Debug, Clone, Serialize, Deserialize)]
90#[serde(tag = "type")]
91#[derive(Default)]
92pub enum Outcome {
93    /// 成功
94    Success {
95        /// スコア(0.0〜1.0、または任意のスケール)
96        score: f64,
97    },
98    /// 失敗
99    Failure {
100        /// 失敗理由
101        reason: String,
102    },
103    /// タイムアウト
104    Timeout {
105        /// 部分スコア(タイムアウト時点での進捗)
106        partial_score: Option<f64>,
107    },
108    /// 不明(判定できない場合)
109    #[default]
110    Unknown,
111}
112
113impl Outcome {
114    pub fn success(score: f64) -> Self {
115        Self::Success { score }
116    }
117
118    pub fn success_binary() -> Self {
119        Self::Success { score: 1.0 }
120    }
121
122    pub fn failure(reason: impl Into<String>) -> Self {
123        Self::Failure {
124            reason: reason.into(),
125        }
126    }
127
128    pub fn timeout(partial_score: Option<f64>) -> Self {
129        Self::Timeout { partial_score }
130    }
131
132    pub fn is_success(&self) -> bool {
133        matches!(self, Self::Success { .. })
134    }
135
136    pub fn is_failure(&self) -> bool {
137        matches!(self, Self::Failure { .. } | Self::Timeout { .. })
138    }
139
140    /// スコアを取得(失敗なら0.0)
141    pub fn score(&self) -> f64 {
142        match self {
143            Self::Success { score } => *score,
144            Self::Timeout { partial_score } => partial_score.unwrap_or(0.0),
145            _ => 0.0,
146        }
147    }
148}
149
150
151// ============================================================================
152// EpisodeContext
153// ============================================================================
154
155/// エピソードのコンテキスト
156///
157/// Record のコレクションを保持。新しい Record 種別が追加されても
158/// 構造体を変更する必要がない。
159#[derive(Debug, Clone, Default, Serialize, Deserialize)]
160pub struct EpisodeContext {
161    /// Record のリスト(統一的に保持)
162    pub records: Vec<Record>,
163}
164
165impl EpisodeContext {
166    pub fn new() -> Self {
167        Self::default()
168    }
169
170    /// Record を追加
171    pub fn push(&mut self, record: impl Into<Record>) {
172        self.records.push(record.into());
173    }
174
175    /// Record を追加(builder pattern)
176    pub fn with_record(mut self, record: impl Into<Record>) -> Self {
177        self.records.push(record.into());
178        self
179    }
180
181    /// 型でフィルタしてイテレート
182    ///
183    /// ```ignore
184    /// context.iter::<ActionRecord>()
185    /// context.iter::<LlmCallRecord>()
186    /// ```
187    pub fn iter<'a, T: FromRecord + 'a>(&'a self) -> impl Iterator<Item = &'a T> {
188        self.records.iter().filter_map(T::from_record)
189    }
190
191    /// 全 Record 数
192    pub fn len(&self) -> usize {
193        self.records.len()
194    }
195
196    /// 空かどうか
197    pub fn is_empty(&self) -> bool {
198        self.records.is_empty()
199    }
200}
201
202// ============================================================================
203// EpisodeMetadata
204// ============================================================================
205
206/// エピソードのメタデータ
207#[derive(Debug, Clone, Default, Serialize, Deserialize)]
208pub struct EpisodeMetadata {
209    /// Strategy名(どの抽出戦略で生成されたか)
210    pub strategy_name: Option<String>,
211    /// シナリオ名
212    pub scenario_name: Option<String>,
213    /// 作成日時(Unix timestamp ms)
214    pub created_at: u64,
215    /// 開始日時(Unix timestamp ms)
216    pub started_at: Option<u64>,
217    /// 終了日時(Unix timestamp ms)
218    pub ended_at: Option<u64>,
219    /// 拡張タグ
220    pub tags: HashMap<String, String>,
221}
222
223impl EpisodeMetadata {
224    pub fn new() -> Self {
225        Self {
226            created_at: epoch_millis(),
227            ..Default::default()
228        }
229    }
230
231    pub fn with_strategy(mut self, name: impl Into<String>) -> Self {
232        self.strategy_name = Some(name.into());
233        self
234    }
235
236    pub fn with_scenario(mut self, name: impl Into<String>) -> Self {
237        self.scenario_name = Some(name.into());
238        self
239    }
240
241    pub fn with_duration(mut self, start: u64, end: u64) -> Self {
242        self.started_at = Some(start);
243        self.ended_at = Some(end);
244        self
245    }
246
247    pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
248        self.tags.insert(key.into(), value.into());
249        self
250    }
251
252    /// 実行時間(ミリ秒)
253    pub fn duration_ms(&self) -> Option<u64> {
254        match (self.started_at, self.ended_at) {
255            (Some(start), Some(end)) => Some(end.saturating_sub(start)),
256            _ => None,
257        }
258    }
259}
260
261// ============================================================================
262// Episode Entity
263// ============================================================================
264
265/// Episode - 学習の基本単位
266///
267/// Swarmの「経験」を表現する。LearnModel によって Record[] から構築され、
268/// TrainingData 生成の元となる。
269#[derive(Debug, Clone, Serialize, Deserialize)]
270pub struct Episode {
271    /// 一意識別子
272    pub id: EpisodeId,
273    /// どの LearnModel で生成されたか(e.g., "ngram-5", "worker_task")
274    pub learn_model: String,
275    /// コンテキスト(LLM呼び出し + アクション履歴)
276    pub context: EpisodeContext,
277    /// 結果(LearnModel が判定)
278    pub outcome: Outcome,
279    /// メタデータ
280    pub metadata: EpisodeMetadata,
281}
282
283impl Episode {
284    /// 新規作成
285    pub fn new(learn_model: impl Into<String>, outcome: Outcome) -> Self {
286        Self {
287            id: EpisodeId::new(),
288            learn_model: learn_model.into(),
289            context: EpisodeContext::default(),
290            outcome,
291            metadata: EpisodeMetadata::new(),
292        }
293    }
294
295    /// Builder を取得
296    pub fn builder() -> EpisodeBuilder {
297        EpisodeBuilder::default()
298    }
299
300    /// 成功したかどうか
301    pub fn is_success(&self) -> bool {
302        self.outcome.is_success()
303    }
304
305    /// Worker ID を取得(最初の Action から)
306    pub fn worker_id(&self) -> Option<usize> {
307        self.context
308            .iter::<ActionRecord>()
309            .next()
310            .map(|a| a.worker_id)
311    }
312}
313
314// ============================================================================
315// EpisodeBuilder
316// ============================================================================
317
318/// Episode を構築するためのビルダー
319#[derive(Debug, Default)]
320pub struct EpisodeBuilder {
321    id: Option<EpisodeId>,
322    learn_model: Option<String>,
323    context: EpisodeContext,
324    outcome: Option<Outcome>,
325    metadata: EpisodeMetadata,
326}
327
328impl EpisodeBuilder {
329    /// Episode ID を設定(永続化からの復元用)
330    pub fn id(mut self, id: EpisodeId) -> Self {
331        self.id = Some(id);
332        self
333    }
334
335    /// LearnModel 名を設定
336    pub fn learn_model(mut self, name: impl Into<String>) -> Self {
337        self.learn_model = Some(name.into());
338        self
339    }
340
341    /// Record を追加(汎用)
342    pub fn record(mut self, record: impl Into<Record>) -> Self {
343        self.context.push(record);
344        self
345    }
346
347    /// EpisodeContext を設定
348    pub fn context(mut self, context: EpisodeContext) -> Self {
349        self.context = context;
350        self
351    }
352
353    pub fn outcome(mut self, outcome: Outcome) -> Self {
354        self.outcome = Some(outcome);
355        self
356    }
357
358    pub fn scenario(mut self, name: impl Into<String>) -> Self {
359        self.metadata.scenario_name = Some(name.into());
360        self
361    }
362
363    pub fn tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
364        self.metadata.tags.insert(key.into(), value.into());
365        self
366    }
367
368    /// EpisodeMetadata を設定(永続化からの復元用)
369    pub fn metadata(mut self, metadata: EpisodeMetadata) -> Self {
370        self.metadata = metadata;
371        self
372    }
373
374    pub fn build(self) -> Episode {
375        Episode {
376            id: self.id.unwrap_or_default(),
377            learn_model: self.learn_model.unwrap_or_else(|| "unknown".to_string()),
378            context: self.context,
379            outcome: self.outcome.unwrap_or(Outcome::Unknown),
380            metadata: self.metadata,
381        }
382    }
383}
384
385// ============================================================================
386// Tests
387// ============================================================================
388
#[cfg(test)]
mod tests {
    //! Unit tests for episodes, their builder, and the record types they hold.

    use std::time::Duration;

    use super::*;
    use crate::events::{ActionContext, ActionEvent, ActionEventBuilder, ActionEventResult};
    use crate::learn::record::LlmCallRecord;
    use crate::types::WorkerId;

    /// Builds an `ActionEvent` with a 50 ms duration, selection logic "UCB1"
    /// and previous action "PrevAction"; the result follows `success`.
    fn make_action_event(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
        let result = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("test error")
        };

        ActionEventBuilder::new(tick, WorkerId(worker_id), action)
            .result(result)
            .duration(Duration::from_millis(50))
            .context(
                ActionContext::new()
                    .with_selection_logic("UCB1")
                    .with_previous_action("PrevAction"),
            )
            .build()
    }

    #[test]
    fn test_action_record_from_action_event() {
        let event = make_action_event(10, 1, "CheckStatus", true);
        let record = ActionRecord::from(&event);

        assert_eq!(record.tick, 10);
        assert_eq!(record.worker_id, 1);
        assert_eq!(record.action, "CheckStatus");
        assert!(record.success);
        assert_eq!(record.duration_ms, 50);
        assert_eq!(record.selection_logic, Some("UCB1".to_string()));
        assert_eq!(record.previous_action, Some("PrevAction".to_string()));
    }

    #[test]
    fn test_episode_builder_with_actions() {
        let event1 = make_action_event(1, 0, "Grep", true);
        let event2 = make_action_event(2, 0, "Read", true);
        let event3 = make_action_event(3, 0, "done", true);

        let episode = Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::from(&event1))
            .record(ActionRecord::from(&event2))
            .record(ActionRecord::from(&event3))
            .outcome(Outcome::success_binary())
            .scenario("troubleshooting")
            .build();

        assert_eq!(episode.learn_model, "worker_task");
        assert_eq!(episode.context.iter::<ActionRecord>().count(), 3);

        let actions: Vec<&str> = episode
            .context
            .iter::<ActionRecord>()
            .map(|a| a.action.as_str())
            .collect();
        assert_eq!(actions, vec!["Grep", "Read", "done"]);

        assert!(episode.is_success());
        assert_eq!(
            episode.metadata.scenario_name,
            Some("troubleshooting".to_string())
        );
    }

    #[test]
    fn test_episode_builder_with_llm_call() {
        let llm_record = LlmCallRecord::new("decide", "qwen2.5")
            .prompt("What action?")
            .response("CheckStatus")
            .latency_ms(150)
            .worker_id(0);

        // The record is not needed afterwards, so it is moved rather than cloned.
        let episode = Episode::builder()
            .learn_model("llm_call")
            .record(llm_record)
            .outcome(Outcome::success(0.9))
            .build();

        assert_eq!(episode.learn_model, "llm_call");
        assert_eq!(episode.context.iter::<LlmCallRecord>().count(), 1);

        let llm_call = episode.context.iter::<LlmCallRecord>().next().unwrap();
        assert_eq!(llm_call.prompt, "What action?");
        assert_eq!(llm_call.response, "CheckStatus");
    }

    #[test]
    fn test_episode_builder_defaults() {
        // Nothing set: the builder falls back to its documented defaults.
        let episode = Episode::builder().build();

        assert_eq!(episode.learn_model, "unknown");
        assert!(matches!(episode.outcome, Outcome::Unknown));
        assert!(episode.context.is_empty());
        assert_eq!(episode.context.len(), 0);
        assert_eq!(episode.worker_id(), None);
        assert!(!episode.is_success());
    }

    #[test]
    fn test_outcome_variants() {
        assert!(Outcome::success(1.0).is_success());
        assert!(!Outcome::success(1.0).is_failure());
        assert_eq!(Outcome::success(0.8).score(), 0.8);

        assert!(!Outcome::failure("test").is_success());
        assert!(Outcome::failure("test").is_failure());
        assert_eq!(Outcome::failure("test").score(), 0.0);

        assert!(!Outcome::timeout(Some(0.5)).is_success());
        assert!(Outcome::timeout(Some(0.5)).is_failure());
        assert_eq!(Outcome::timeout(Some(0.5)).score(), 0.5);

        assert!(!Outcome::Unknown.is_success());
        assert!(!Outcome::Unknown.is_failure());
    }

    #[test]
    fn test_outcome_score_without_value() {
        // Outcomes that carry no usable score report 0.0.
        assert_eq!(Outcome::timeout(None).score(), 0.0);
        assert!(Outcome::timeout(None).is_failure());
        assert_eq!(Outcome::Unknown.score(), 0.0);

        assert_eq!(Outcome::success_binary().score(), 1.0);
        assert!(matches!(Outcome::default(), Outcome::Unknown));
    }

    #[test]
    fn test_episode_id_display() {
        // Decimal timestamp, dash, counter as 8-digit zero-padded lowercase hex.
        assert_eq!(EpisodeId::from_parts(12345, 1).to_string(), "12345-00000001");
        assert_eq!(
            EpisodeId::from_parts(0, 0xdead_beef).to_string(),
            "0-deadbeef"
        );
    }

    #[test]
    fn test_episode_id_new_is_unique() {
        // Consecutive IDs differ at least in their counter.
        let first = EpisodeId::new();
        let second = EpisodeId::new();
        assert_ne!(first, second);
    }

    #[test]
    fn test_metadata_duration_ms() {
        // No start/end recorded yet.
        assert_eq!(EpisodeMetadata::new().duration_ms(), None);

        let metadata = EpisodeMetadata::new().with_duration(100, 250);
        assert_eq!(metadata.duration_ms(), Some(150));

        // An end earlier than the start saturates to zero.
        let reversed = EpisodeMetadata::new().with_duration(250, 100);
        assert_eq!(reversed.duration_ms(), Some(0));
    }

    #[test]
    fn test_episode_context_iter() {
        let mut context = EpisodeContext::new();
        context.push(ActionRecord::new(1, 0, "A").success(true));
        context.push(ActionRecord::new(2, 0, "B").success(true));
        context.push(ActionRecord::new(3, 0, "C").success(false));

        // Count via iter::<ActionRecord>()
        assert_eq!(context.iter::<ActionRecord>().count(), 3);

        // Count of successful actions
        let success_count = context.iter::<ActionRecord>().filter(|a| a.success).count();
        assert_eq!(success_count, 2);

        // Action sequence
        let actions: Vec<&str> = context
            .iter::<ActionRecord>()
            .map(|a| a.action.as_str())
            .collect();
        assert_eq!(actions, vec!["A", "B", "C"]);
    }

    #[test]
    fn test_episode_serialization() {
        let episode = Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::new(1, 0, "CheckStatus").success(true))
            .outcome(Outcome::success_binary())
            .build();

        // Serialize
        let json = serde_json::to_string(&episode).unwrap();
        assert!(json.contains("\"learn_model\":\"worker_task\""));
        assert!(json.contains("\"action\":\"CheckStatus\""));

        // Deserialize
        let restored: Episode = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.learn_model, "worker_task");
        assert_eq!(restored.context.iter::<ActionRecord>().count(), 1);
        assert!(restored.is_success());
    }

    #[test]
    fn test_llm_call_record_builder() {
        let record = LlmCallRecord::new("decide", "qwen2.5")
            .prompt("prompt")
            .response("response")
            .endpoint("http://localhost:11434")
            .lora("adapter1")
            .latency_ms(100)
            .worker_id(5);

        assert_eq!(record.call_type, "decide");
        assert_eq!(record.model, "qwen2.5");
        assert_eq!(record.prompt, "prompt");
        assert_eq!(record.response, "response");
        assert_eq!(record.lora, Some("adapter1".to_string()));
        assert_eq!(record.worker_id, Some(5));
        assert!(record.is_success());

        let error_record = LlmCallRecord::new("decide", "model").error("timeout");
        assert!(!error_record.is_success());
    }

    #[test]
    fn test_episode_builder_with_id_and_metadata() {
        let custom_id = EpisodeId::from_parts(12345, 1);
        let mut custom_metadata = EpisodeMetadata::new();
        custom_metadata.scenario_name = Some("custom-scenario".to_string());
        custom_metadata
            .tags
            .insert("key".to_string(), "value".to_string());

        let episode = Episode::builder()
            .id(custom_id.clone())
            .learn_model("test")
            .metadata(custom_metadata)
            .outcome(Outcome::Unknown)
            .build();

        assert_eq!(episode.id, custom_id);
        assert_eq!(
            episode.metadata.scenario_name,
            Some("custom-scenario".to_string())
        );
        assert_eq!(episode.metadata.tags.get("key"), Some(&"value".to_string()));
    }
}