Skip to main content

swarm_engine_core/learn/learn_model/
mod.rs

1//! LearnModel - 学習の統合モデル
2//!
3//! ## 設計思想
4//!
5//! LearnModel は「何を学習するか」を統合的に定義する。
6//!
7//! - **何を目的とするか** (objective)
8//! - **何を Episode として切り出すか** (build_episodes)
9//! - **何を Success/Failure とするか** (evaluate)
10//! - **どう TrainingData に変換するか** (convert)
11//!
12//! ## Learn の価値
13//!
14//! Core(Swarm本体)は性能制約で 3-gram までしか取れない。
15//! しかし Learn は非同期/オフラインなので、5-gram や 10-gram など
16//! 自由に分析できる。これが Learn モジュールの価値。
17//!
18//! ## モジュール構造
19//!
20//! ```text
21//! learn_model/
22//! ├── mod.rs              # LearnModel trait 定義
23//! ├── dpo.rs              # 汎用 DPO 基盤 (DpoLearnModel<F>)
24//! ├── dependency_graph.rs # DependencyGraph 推論の学習
25//! ├── worker_task.rs      # Worker の Task 完了パターン学習
26//! ├── worker_decision.rs  # Worker の意思決定シーケンス学習
27//! └── error.rs            # エラー型
28//! ```
29//!
30//! ## 実装一覧
31//!
32//! | 実装 | 目的 | グルーピング |
33//! |------|------|-------------|
34//! | [`DpoLearnModel`] | 汎用 DPO 学習 | `group_id` |
35//! | [`DependencyGraphLearnModel`] | DependencyGraph 推論 | - |
36//! | [`WorkerTaskLearn`] | Task 完了パターン | `task_id` |
37//! | [`WorkerDecisionSequenceLearn`] | 意思決定シーケンス | `worker_id` |
38//!
39//! ## DPO (Direct Preference Optimization)
40//!
41//! DPO は成功/失敗 Episode のペアから学習する手法。
42//! `group_id` でグルーピングし、カスタム `extractor` で prompt/response を抽出。
43//!
44//! [`DependencyGraphLearnModel::extractor()`] を使って DPO を行う例:
45//!
46//! ```ignore
47//! let dpo = DpoLearnModel::new(
48//!     DependencyGraphLearnModel::default_system_prompt(),
49//!     DependencyGraphLearnModel::extractor(),
50//! );
51//! let pairs = dpo.build_pairs(&episodes);
52//! ```
53//!
54//! ## Record による抽象化
55//!
56//! ActionEvent と LlmDebugEvent を `Record` enum で統一的に扱う。
57//! LearnModel は Record のストリームから Episode を構築する。
58//!
59//! ```text
60//! ActionEvent ──┐
61//!               ├──▶ Vec<Record> ──▶ LearnModel.build_episodes()
62//! LlmDebugEvent ┘                         ↓
63//!                                    Vec<Episode>
64//!                                         ↓
65//!                                    LearnModel.convert()
66//!                                         ↓
67//!                                    TrainingData
68//! ```
69
70mod dependency_graph;
71mod dpo;
72mod error;
73mod worker_decision;
74mod worker_task;
75
76pub use dependency_graph::DependencyGraphLearnModel;
77pub use dpo::{DpoConfig, DpoLearnModel, DpoPair};
78pub use error::LearnError;
79pub use worker_decision::WorkerDecisionSequenceLearn;
80pub use worker_task::WorkerTaskLearn;
81
82use crate::events::ActionEvent;
83
84use super::episode::{Episode, EpisodeContext, Outcome};
85use super::record::Record;
86use super::training::TrainingData;
87
88// ============================================================================
89// System Event Constants
90// ============================================================================
91
92/// システムイベント定数
93pub mod system_events {
94    /// Tick 開始イベント
95    pub const TICK_START: &str = "tick_start";
96    /// Tick 終了イベント
97    pub const TICK_END: &str = "tick_end";
98    /// タスク完了イベント
99    pub const DONE: &str = "done";
100
101    /// デフォルトのシステムイベント一覧
102    pub const DEFAULT_SYSTEM_EVENTS: &[&str] = &[TICK_START, TICK_END, DONE];
103}
104
105// ============================================================================
106// LearnModel Trait
107// ============================================================================
108
109/// 学習の統合モデル
110///
111/// 何を学習対象とし、何を成功とするかを統合的に定義する。
112/// Record[] から Episode を構築し、TrainingData に変換するまでの全責務を担う。
113///
114/// ## Record による統一インターフェース
115///
116/// ActionEvent も LlmDebugEvent も `Record` として統一的に扱う。
117/// これにより:
118/// - ActionEvent ベースの Learn
119/// - LlmDebugEvent ベースの Learn
120/// - 両方を混ぜた Learn
121/// 全て同じインターフェースで実装可能。
122pub trait LearnModel: Send + Sync {
123    /// 名前
124    fn name(&self) -> &str;
125
126    /// 目的を表す説明
127    fn objective(&self) -> &str;
128
129    /// Record のストリームから Episode を構築
130    ///
131    /// N-gram、Worker単位、任意のグルーピングが可能。
132    /// Core が 3-gram までしか取れなくても、Learn は 5-gram や 10-gram を
133    /// 自由に構築できる。
134    fn build_episodes(&self, records: &[Record]) -> Vec<Episode>;
135
136    /// Records から Success/Failure を判定
137    ///
138    /// 純粋なロジック: EpisodeContext (Records) → Outcome
139    /// build_episodes() 内でこれを呼んで Episode.outcome を設定する。
140    fn evaluate(&self, context: &EpisodeContext) -> Outcome;
141
142    /// Episode を TrainingData に変換
143    fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError>;
144
145    /// 複数 Episode を一括変換(デフォルト実装)
146    fn convert_batch(&self, episodes: &[Episode]) -> Vec<TrainingData> {
147        episodes
148            .iter()
149            .filter_map(|ep| self.convert(ep).ok())
150            .collect()
151    }
152
153    /// 便利メソッド: ActionEvent[] から直接変換
154    fn build_episodes_from_actions(&self, actions: &[ActionEvent]) -> Vec<Episode> {
155        let records: Vec<Record> = actions.iter().map(Record::from).collect();
156        self.build_episodes(&records)
157    }
158}
159
160// ============================================================================
161// Tests
162// ============================================================================
163
164#[cfg(test)]
165mod tests {
166    use super::*;
167    use crate::events::{ActionContext, ActionEventBuilder, ActionEventResult};
168    use crate::learn::record::RecordStream;
169    use crate::types::WorkerId;
170    use std::time::Duration;
171
172    fn make_action(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
173        let result = if success {
174            ActionEventResult::success()
175        } else {
176            ActionEventResult::failure("error")
177        };
178
179        ActionEventBuilder::new(tick, WorkerId(worker_id), action)
180            .result(result)
181            .duration(Duration::from_millis(10))
182            .context(ActionContext::new())
183            .build()
184    }
185
186    fn make_records(actions: &[ActionEvent]) -> Vec<Record> {
187        actions.iter().map(Record::from).collect()
188    }
189
190    #[test]
191    fn test_record_accessors() {
192        let action = make_action(1, 5, "CheckStatus", true);
193        let record = Record::from(&action);
194
195        assert!(record.is_action());
196        assert!(!record.is_llm());
197        assert_eq!(record.worker_id(), Some(5));
198        assert!(record.as_action().is_some());
199        assert!(record.as_llm().is_none());
200    }
201
202    #[test]
203    fn test_record_stream_group_by_worker() {
204        let actions = vec![
205            make_action(1, 0, "A", true),
206            make_action(2, 1, "B", true),
207            make_action(3, 0, "C", true),
208            make_action(4, 1, "D", true),
209        ];
210        let records = make_records(&actions);
211        let stream = RecordStream::new(&records);
212
213        let groups = stream.group_by_worker();
214        assert_eq!(groups.len(), 2);
215        assert_eq!(groups.get(&0).map(|v| v.len()), Some(2));
216        assert_eq!(groups.get(&1).map(|v| v.len()), Some(2));
217    }
218}