swarm_engine_core/learn/learn_model/mod.rs
1//! LearnModel - 学習の統合モデル
2//!
3//! ## 設計思想
4//!
5//! LearnModel は「何を学習するか」を統合的に定義する。
6//!
7//! - **何を目的とするか** (objective)
8//! - **何を Episode として切り出すか** (build_episodes)
9//! - **何を Success/Failure とするか** (evaluate)
10//! - **どう TrainingData に変換するか** (convert)
11//!
12//! ## Learn の価値
13//!
14//! Core(Swarm本体)は性能制約で 3-gram までしか取れない。
15//! しかし Learn は非同期/オフラインなので、5-gram や 10-gram など
16//! 自由に分析できる。これが Learn モジュールの価値。
17//!
18//! ## モジュール構造
19//!
20//! ```text
21//! learn_model/
22//! ├── mod.rs # LearnModel trait 定義
23//! ├── dpo.rs # 汎用 DPO 基盤 (DpoLearnModel<F>)
24//! ├── dependency_graph.rs # DependencyGraph 推論の学習
25//! ├── worker_task.rs # Worker の Task 完了パターン学習
26//! ├── worker_decision.rs # Worker の意思決定シーケンス学習
27//! └── error.rs # エラー型
28//! ```
29//!
30//! ## 実装一覧
31//!
32//! | 実装 | 目的 | グルーピング |
33//! |------|------|-------------|
34//! | [`DpoLearnModel`] | 汎用 DPO 学習 | `group_id` |
35//! | [`DependencyGraphLearnModel`] | DependencyGraph 推論 | - |
36//! | [`WorkerTaskLearn`] | Task 完了パターン | `task_id` |
37//! | [`WorkerDecisionSequenceLearn`] | 意思決定シーケンス | `worker_id` |
38//!
39//! ## DPO (Direct Preference Optimization)
40//!
41//! DPO は成功/失敗 Episode のペアから学習する手法。
42//! `group_id` でグルーピングし、カスタム `extractor` で prompt/response を抽出。
43//!
44//! [`DependencyGraphLearnModel::extractor()`] を使って DPO を行う例:
45//!
46//! ```ignore
47//! let dpo = DpoLearnModel::new(
48//! DependencyGraphLearnModel::default_system_prompt(),
49//! DependencyGraphLearnModel::extractor(),
50//! );
51//! let pairs = dpo.build_pairs(&episodes);
52//! ```
53//!
54//! ## Record による抽象化
55//!
56//! ActionEvent と LlmDebugEvent を `Record` enum で統一的に扱う。
57//! LearnModel は Record のストリームから Episode を構築する。
58//!
59//! ```text
60//! ActionEvent ──┐
61//! ├──▶ Vec<Record> ──▶ LearnModel.build_episodes()
62//! LlmDebugEvent ┘ ↓
63//! Vec<Episode>
64//! ↓
65//! LearnModel.convert()
66//! ↓
67//! TrainingData
68//! ```
69
70mod dependency_graph;
71mod dpo;
72mod error;
73mod worker_decision;
74mod worker_task;
75
76pub use dependency_graph::DependencyGraphLearnModel;
77pub use dpo::{DpoConfig, DpoLearnModel, DpoPair};
78pub use error::LearnError;
79pub use worker_decision::WorkerDecisionSequenceLearn;
80pub use worker_task::WorkerTaskLearn;
81
82use crate::events::ActionEvent;
83
84use super::episode::{Episode, EpisodeContext, Outcome};
85use super::record::Record;
86use super::training::TrainingData;
87
88// ============================================================================
89// System Event Constants
90// ============================================================================
91
92/// システムイベント定数
93pub mod system_events {
94 /// Tick 開始イベント
95 pub const TICK_START: &str = "tick_start";
96 /// Tick 終了イベント
97 pub const TICK_END: &str = "tick_end";
98 /// タスク完了イベント
99 pub const DONE: &str = "done";
100
101 /// デフォルトのシステムイベント一覧
102 pub const DEFAULT_SYSTEM_EVENTS: &[&str] = &[TICK_START, TICK_END, DONE];
103}
104
105// ============================================================================
106// LearnModel Trait
107// ============================================================================
108
109/// 学習の統合モデル
110///
111/// 何を学習対象とし、何を成功とするかを統合的に定義する。
112/// Record[] から Episode を構築し、TrainingData に変換するまでの全責務を担う。
113///
114/// ## Record による統一インターフェース
115///
116/// ActionEvent も LlmDebugEvent も `Record` として統一的に扱う。
117/// これにより:
118/// - ActionEvent ベースの Learn
119/// - LlmDebugEvent ベースの Learn
120/// - 両方を混ぜた Learn
121/// 全て同じインターフェースで実装可能。
122pub trait LearnModel: Send + Sync {
123 /// 名前
124 fn name(&self) -> &str;
125
126 /// 目的を表す説明
127 fn objective(&self) -> &str;
128
129 /// Record のストリームから Episode を構築
130 ///
131 /// N-gram、Worker単位、任意のグルーピングが可能。
132 /// Core が 3-gram までしか取れなくても、Learn は 5-gram や 10-gram を
133 /// 自由に構築できる。
134 fn build_episodes(&self, records: &[Record]) -> Vec<Episode>;
135
136 /// Records から Success/Failure を判定
137 ///
138 /// 純粋なロジック: EpisodeContext (Records) → Outcome
139 /// build_episodes() 内でこれを呼んで Episode.outcome を設定する。
140 fn evaluate(&self, context: &EpisodeContext) -> Outcome;
141
142 /// Episode を TrainingData に変換
143 fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError>;
144
145 /// 複数 Episode を一括変換(デフォルト実装)
146 fn convert_batch(&self, episodes: &[Episode]) -> Vec<TrainingData> {
147 episodes
148 .iter()
149 .filter_map(|ep| self.convert(ep).ok())
150 .collect()
151 }
152
153 /// 便利メソッド: ActionEvent[] から直接変換
154 fn build_episodes_from_actions(&self, actions: &[ActionEvent]) -> Vec<Episode> {
155 let records: Vec<Record> = actions.iter().map(Record::from).collect();
156 self.build_episodes(&records)
157 }
158}
159
160// ============================================================================
161// Tests
162// ============================================================================
163
164#[cfg(test)]
165mod tests {
166 use super::*;
167 use crate::events::{ActionContext, ActionEventBuilder, ActionEventResult};
168 use crate::learn::record::RecordStream;
169 use crate::types::WorkerId;
170 use std::time::Duration;
171
172 fn make_action(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
173 let result = if success {
174 ActionEventResult::success()
175 } else {
176 ActionEventResult::failure("error")
177 };
178
179 ActionEventBuilder::new(tick, WorkerId(worker_id), action)
180 .result(result)
181 .duration(Duration::from_millis(10))
182 .context(ActionContext::new())
183 .build()
184 }
185
186 fn make_records(actions: &[ActionEvent]) -> Vec<Record> {
187 actions.iter().map(Record::from).collect()
188 }
189
190 #[test]
191 fn test_record_accessors() {
192 let action = make_action(1, 5, "CheckStatus", true);
193 let record = Record::from(&action);
194
195 assert!(record.is_action());
196 assert!(!record.is_llm());
197 assert_eq!(record.worker_id(), Some(5));
198 assert!(record.as_action().is_some());
199 assert!(record.as_llm().is_none());
200 }
201
202 #[test]
203 fn test_record_stream_group_by_worker() {
204 let actions = vec![
205 make_action(1, 0, "A", true),
206 make_action(2, 1, "B", true),
207 make_action(3, 0, "C", true),
208 make_action(4, 1, "D", true),
209 ];
210 let records = make_records(&actions);
211 let stream = RecordStream::new(&records);
212
213 let groups = stream.group_by_worker();
214 assert_eq!(groups.len(), 2);
215 assert_eq!(groups.get(&0).map(|v| v.len()), Some(2));
216 assert_eq!(groups.get(&1).map(|v| v.len()), Some(2));
217 }
218}