swarm_engine_core/learn/learn_model/mod.rs
1//! LearnModel - 学習の統合モデル
2//!
3//! ## 設計思想
4//!
5//! LearnModel は「何を学習するか」を統合的に定義する。
6//!
7//! - **何を目的とするか** (objective)
8//! - **何を Episode として切り出すか** (build_episodes)
9//! - **何を Success/Failure とするか** (evaluate)
10//! - **どう TrainingData に変換するか** (convert)
11//!
12//! ## Learn の価値
13//!
14//! Core(Swarm本体)は性能制約で 3-gram までしか取れない。
15//! しかし Learn は非同期/オフラインなので、5-gram や 10-gram など
16//! 自由に分析できる。これが Learn モジュールの価値。
17//!
18//! ## モジュール構造
19//!
20//! ```text
21//! learn_model/
22//! ├── mod.rs # LearnModel trait 定義
23//! ├── dpo.rs # 汎用 DPO 基盤 (DpoLearnModel<F>)
24//! ├── dependency_graph.rs # DependencyGraph 推論の学習
25//! ├── worker_task.rs # Worker の Task 完了パターン学習
26//! ├── worker_decision.rs # Worker の意思決定シーケンス学習
27//! └── error.rs # エラー型
28//! ```
29//!
30//! ## 実装一覧
31//!
32//! | 実装 | 目的 | グルーピング |
33//! |------|------|-------------|
34//! | [`DpoLearnModel`] | 汎用 DPO 学習 | `group_id` |
35//! | [`DependencyGraphLearnModel`] | DependencyGraph 推論 | - |
36//! | [`WorkerTaskLearn`] | Task 完了パターン | `task_id` |
37//! | [`WorkerDecisionSequenceLearn`] | 意思決定シーケンス | `worker_id` |
38//!
39//! ## DPO (Direct Preference Optimization)
40//!
41//! DPO は成功/失敗 Episode のペアから学習する手法。
42//! `group_id` でグルーピングし、カスタム `extractor` で prompt/response を抽出。
43//!
44//! [`DependencyGraphLearnModel::extractor()`] を使って DPO を行う例:
45//!
46//! ```ignore
47//! let dpo = DpoLearnModel::new(
48//! DependencyGraphLearnModel::default_system_prompt(),
49//! DependencyGraphLearnModel::extractor(),
50//! );
51//! let pairs = dpo.build_pairs(&episodes);
52//! ```
53//!
54//! ## Record による抽象化
55//!
56//! ActionEvent と LlmDebugEvent を `Record` enum で統一的に扱う。
57//! LearnModel は Record のストリームから Episode を構築する。
58//!
59//! ```text
60//! ActionEvent ──┐
61//! ├──▶ Vec<Record> ──▶ LearnModel.build_episodes()
62//! LlmDebugEvent ┘ ↓
63//! Vec<Episode>
64//! ↓
65//! LearnModel.convert()
66//! ↓
67//! TrainingData
68//! ```
69
70mod dependency_graph;
71mod dpo;
72mod error;
73mod worker_decision;
74mod worker_task;
75
76pub use dependency_graph::DependencyGraphLearnModel;
77pub use dpo::{DpoConfig, DpoLearnModel, DpoPair};
78pub use error::LearnError;
79pub use worker_decision::WorkerDecisionSequenceLearn;
80pub use worker_task::WorkerTaskLearn;
81
82use crate::events::ActionEvent;
83
84use super::episode::{Episode, EpisodeContext, Outcome};
85use super::record::Record;
86use super::training::TrainingData;
87
88// ============================================================================
89// System Event Constants
90// ============================================================================
91
92/// システムイベント定数
93pub mod system_events {
94 /// Tick 開始イベント
95 pub const TICK_START: &str = "tick_start";
96 /// Tick 終了イベント
97 pub const TICK_END: &str = "tick_end";
98 /// タスク完了イベント
99 pub const DONE: &str = "done";
100
101 /// デフォルトのシステムイベント一覧
102 pub const DEFAULT_SYSTEM_EVENTS: &[&str] = &[TICK_START, TICK_END, DONE];
103}
104
105// ============================================================================
106// LearnModel Trait
107// ============================================================================
108
109/// 学習の統合モデル
110///
111/// 何を学習対象とし、何を成功とするかを統合的に定義する。
112/// Record[] から Episode を構築し、TrainingData に変換するまでの全責務を担う。
113///
114/// ## Record による統一インターフェース
115///
116/// ActionEvent も LlmDebugEvent も `Record` として統一的に扱う。
117/// これにより:
118/// - ActionEvent ベースの Learn
119/// - LlmDebugEvent ベースの Learn
120/// - 両方を混ぜた Learn
121///
122/// 全て同じインターフェースで実装可能。
123pub trait LearnModel: Send + Sync {
124 /// 名前
125 fn name(&self) -> &str;
126
127 /// 目的を表す説明
128 fn objective(&self) -> &str;
129
130 /// Record のストリームから Episode を構築
131 ///
132 /// N-gram、Worker単位、任意のグルーピングが可能。
133 /// Core が 3-gram までしか取れなくても、Learn は 5-gram や 10-gram を
134 /// 自由に構築できる。
135 fn build_episodes(&self, records: &[Record]) -> Vec<Episode>;
136
137 /// Records から Success/Failure を判定
138 ///
139 /// 純粋なロジック: EpisodeContext (Records) → Outcome
140 /// build_episodes() 内でこれを呼んで Episode.outcome を設定する。
141 fn evaluate(&self, context: &EpisodeContext) -> Outcome;
142
143 /// Episode を TrainingData に変換
144 fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError>;
145
146 /// 複数 Episode を一括変換(デフォルト実装)
147 fn convert_batch(&self, episodes: &[Episode]) -> Vec<TrainingData> {
148 episodes
149 .iter()
150 .filter_map(|ep| self.convert(ep).ok())
151 .collect()
152 }
153
154 /// 便利メソッド: ActionEvent[] から直接変換
155 fn build_episodes_from_actions(&self, actions: &[ActionEvent]) -> Vec<Episode> {
156 let records: Vec<Record> = actions.iter().map(Record::from).collect();
157 self.build_episodes(&records)
158 }
159}
160
161// ============================================================================
162// Tests
163// ============================================================================
164
165#[cfg(test)]
166mod tests {
167 use super::*;
168 use crate::events::{ActionContext, ActionEventBuilder, ActionEventResult};
169 use crate::learn::record::RecordStream;
170 use crate::types::WorkerId;
171 use std::time::Duration;
172
173 fn make_action(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
174 let result = if success {
175 ActionEventResult::success()
176 } else {
177 ActionEventResult::failure("error")
178 };
179
180 ActionEventBuilder::new(tick, WorkerId(worker_id), action)
181 .result(result)
182 .duration(Duration::from_millis(10))
183 .context(ActionContext::new())
184 .build()
185 }
186
187 fn make_records(actions: &[ActionEvent]) -> Vec<Record> {
188 actions.iter().map(Record::from).collect()
189 }
190
191 #[test]
192 fn test_record_accessors() {
193 let action = make_action(1, 5, "CheckStatus", true);
194 let record = Record::from(&action);
195
196 assert!(record.is_action());
197 assert!(!record.is_llm());
198 assert_eq!(record.worker_id(), Some(5));
199 assert!(record.as_action().is_some());
200 assert!(record.as_llm().is_none());
201 }
202
203 #[test]
204 fn test_record_stream_group_by_worker() {
205 let actions = vec![
206 make_action(1, 0, "A", true),
207 make_action(2, 1, "B", true),
208 make_action(3, 0, "C", true),
209 make_action(4, 1, "D", true),
210 ];
211 let records = make_records(&actions);
212 let stream = RecordStream::new(&records);
213
214 let groups = stream.group_by_worker();
215 assert_eq!(groups.len(), 2);
216 assert_eq!(groups.get(&0).map(|v| v.len()), Some(2));
217 assert_eq!(groups.get(&1).map(|v| v.len()), Some(2));
218 }
219}