Skip to main content

swarm_engine_core/learn/
learn_model.rs

1//! LearnModel - 学習の統合モデル
2//!
3//! ## 設計思想
4//!
5//! LearnModel は「何を学習するか」を統合的に定義する。
6//!
7//! - **何を目的とするか** (objective)
8//! - **何を Episode として切り出すか** (build_episodes)
9//! - **何を Success/Failure とするか** (evaluate)
10//! - **どう TrainingData に変換するか** (convert)
11//!
12//! ## Learn の価値
13//!
14//! Core(Swarm本体)は性能制約で 3-gram までしか取れない。
15//! しかし Learn は非同期/オフラインなので、5-gram や 10-gram など
16//! 自由に分析できる。これが Learn モジュールの価値。
17//!
18//! ## Record による抽象化
19//!
20//! ActionEvent と LlmDebugEvent を `Record` enum で統一的に扱う。
21//! LearnModel は Record のストリームから Episode を構築する。
22//!
23//! ```text
24//! ActionEvent ──┐
25//!               ├──▶ Vec<Record> ──▶ LearnModel.build_episodes()
26//! LlmDebugEvent ┘                         ↓
27//!                                    Vec<Episode>
28//!                                         ↓
29//!                                    LearnModel.convert()
30//!                                         ↓
31//!                                    TrainingData
32//! ```
33
34use crate::events::ActionEvent;
35
36use super::episode::{Episode, EpisodeContext, Outcome};
37use super::record::{ActionRecord, Record, RecordStream};
38use super::training::TrainingData;
39
40// ============================================================================
41// System Event Constants
42// ============================================================================
43
/// Names of the built-in system events.
pub mod system_events {
    /// Event name emitted when a tick starts.
    pub const TICK_START: &str = "tick_start";

    /// Event name emitted when a tick ends.
    pub const TICK_END: &str = "tick_end";

    /// Event name emitted when a task completes.
    pub const DONE: &str = "done";

    /// Default list of system events: tick start, tick end and done, in that order.
    pub const DEFAULT_SYSTEM_EVENTS: &[&str] = &[self::TICK_START, self::TICK_END, self::DONE];
}
56
57// ============================================================================
58// LearnModel Trait
59// ============================================================================
60
61/// 学習の統合モデル
62///
63/// 何を学習対象とし、何を成功とするかを統合的に定義する。
64/// Record[] から Episode を構築し、TrainingData に変換するまでの全責務を担う。
65///
66/// ## Record による統一インターフェース
67///
68/// ActionEvent も LlmDebugEvent も `Record` として統一的に扱う。
69/// これにより:
70/// - ActionEvent ベースの Learn
71/// - LlmDebugEvent ベースの Learn
72/// - 両方を混ぜた Learn
73/// 全て同じインターフェースで実装可能。
74pub trait LearnModel: Send + Sync {
75    /// 名前
76    fn name(&self) -> &str;
77
78    /// 目的を表す説明
79    fn objective(&self) -> &str;
80
81    /// Record のストリームから Episode を構築
82    ///
83    /// N-gram、Worker単位、任意のグルーピングが可能。
84    /// Core が 3-gram までしか取れなくても、Learn は 5-gram や 10-gram を
85    /// 自由に構築できる。
86    fn build_episodes(&self, records: &[Record]) -> Vec<Episode>;
87
88    /// Records から Success/Failure を判定
89    ///
90    /// 純粋なロジック: EpisodeContext (Records) → Outcome
91    /// build_episodes() 内でこれを呼んで Episode.outcome を設定する。
92    fn evaluate(&self, context: &EpisodeContext) -> Outcome;
93
94    /// Episode を TrainingData に変換
95    fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError>;
96
97    /// 複数 Episode を一括変換(デフォルト実装)
98    fn convert_batch(&self, episodes: &[Episode]) -> Vec<TrainingData> {
99        episodes
100            .iter()
101            .filter_map(|ep| self.convert(ep).ok())
102            .collect()
103    }
104
105    /// 便利メソッド: ActionEvent[] から直接変換
106    fn build_episodes_from_actions(&self, actions: &[ActionEvent]) -> Vec<Episode> {
107        let records: Vec<Record> = actions.iter().map(Record::from).collect();
108        self.build_episodes(&records)
109    }
110}
111
112// ============================================================================
113// LearnError
114// ============================================================================
115
/// Error type returned by [`LearnModel`] operations.
///
/// Every variant carries a free-form message that is interpolated into the
/// `Display` output defined by its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum LearnError {
    /// Build failure; displayed as `Build error: <message>`.
    #[error("Build error: {0}")]
    Build(String),

    /// Conversion failure; displayed as `Conversion error: <message>`.
    #[error("Conversion error: {0}")]
    Conversion(String),

    /// Required data is absent; displayed as `Missing data: <message>`.
    #[error("Missing data: {0}")]
    MissingData(String),

    /// The episode is not acceptable; displayed as `Invalid episode: <message>`.
    #[error("Invalid episode: {0}")]
    InvalidEpisode(String),

    /// Any other failure; displayed as the bare message.
    #[error("{0}")]
    Other(String),
}
134
135// ============================================================================
136// 組み込み LearnModel 実装
137// ============================================================================
138
/// [`LearnModel`] based on worker task completion.
///
/// An episode is the run of actions a worker executed from its start up to a
/// terminal (`done`) action.
#[derive(Debug, Clone)]
pub struct WorkerTaskLearn {
    /// System prompt placed in the generated training data.
    system_prompt: String,
    /// Minimum number of action records an episode must contain.
    min_actions: usize,
}

impl WorkerTaskLearn {
    /// Creates a model with the built-in system prompt and a minimum of two
    /// actions per episode.
    pub fn new() -> Self {
        let system_prompt =
            String::from("You are an intelligent agent that diagnoses and resolves system issues.");

        Self {
            system_prompt,
            min_actions: 2,
        }
    }

    /// Replaces the system prompt and returns the updated model.
    pub fn with_system_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: prompt.into(),
            ..self
        }
    }

    /// Replaces the minimum action count and returns the updated model.
    pub fn with_min_actions(self, min: usize) -> Self {
        Self {
            min_actions: min,
            ..self
        }
    }
}

impl Default for WorkerTaskLearn {
    /// Equivalent to [`WorkerTaskLearn::new`].
    fn default() -> Self {
        Self::new()
    }
}
175
176impl LearnModel for WorkerTaskLearn {
177    fn name(&self) -> &str {
178        "worker_task"
179    }
180
181    fn objective(&self) -> &str {
182        "Learn complete worker task sequences from start to done"
183    }
184
185    fn evaluate(&self, context: &EpisodeContext) -> Outcome {
186        // 空のコンテキストは評価不能
187        if context.is_empty() {
188            return Outcome::failure("Empty context: no actions to evaluate");
189        }
190
191        // 最後のアクションが done かつ success なら成功
192        let last_action = context.iter::<ActionRecord>().last();
193
194        match last_action {
195            Some(action) if action.is_terminal() => {
196                if action.success {
197                    Outcome::success_binary()
198                } else {
199                    Outcome::failure("Task failed")
200                }
201            }
202            _ => Outcome::Unknown,
203        }
204    }
205
206    fn build_episodes(&self, records: &[Record]) -> Vec<Episode> {
207        use std::collections::HashMap;
208
209        let stream = RecordStream::new(records);
210
211        // Worker ID ごとにグルーピング(Action Records のみ)
212        let mut worker_actions: HashMap<usize, Vec<&ActionRecord>> = HashMap::new();
213        for record in stream.actions() {
214            worker_actions
215                .entry(record.worker_id)
216                .or_default()
217                .push(record);
218        }
219
220        let mut episodes = Vec::new();
221
222        for (_worker_id, worker_records) in worker_actions {
223            // done で終わるシーケンスを探す
224            let mut current_sequence: Vec<&ActionRecord> = Vec::new();
225
226            for record in worker_records {
227                current_sequence.push(record);
228
229                if record.is_terminal() {
230                    // シーケンス完了
231                    if current_sequence.len() >= self.min_actions {
232                        // context を構築
233                        let mut context = EpisodeContext::new();
234                        for r in &current_sequence {
235                            context.push((*r).clone());
236                        }
237
238                        // evaluate で outcome を判定
239                        let outcome = self.evaluate(&context);
240
241                        let episode = Episode::builder()
242                            .learn_model("worker_task")
243                            .context(context)
244                            .outcome(outcome)
245                            .build();
246
247                        episodes.push(episode);
248                    }
249
250                    current_sequence.clear();
251                }
252            }
253        }
254
255        episodes
256    }
257
258    fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError> {
259        if !episode.outcome.is_success() {
260            return Err(LearnError::InvalidEpisode(
261                "Episode is not successful".into(),
262            ));
263        }
264
265        let action_count = episode.context.iter::<ActionRecord>().count();
266        if action_count < self.min_actions {
267            return Err(LearnError::InvalidEpisode(format!(
268                "Too few actions: {} < {}",
269                action_count, self.min_actions
270            )));
271        }
272
273        let actions: Vec<&str> = episode
274            .context
275            .iter::<ActionRecord>()
276            .filter(|a| !a.is_terminal())
277            .map(|a| a.action.as_str())
278            .collect();
279
280        let prompt = format!(
281            "Diagnose and resolve the issue.\nAvailable actions: {}",
282            actions.join(", ")
283        );
284
285        let response = format!("Execute the following sequence: {}", actions.join(" -> "));
286
287        Ok(TrainingData::sft(&self.system_prompt, &prompt, &response)
288            .with_episode_id(episode.id.to_string())
289            .with_outcome_score(episode.outcome.score()))
290    }
291}
292
293/// Worker Decision 学習モデル(シーケンスベース)
294///
295/// **ターゲット**: Worker の Decision(アクション選択)を改善
296/// **手法**: 成功シーケンスを使った学習
297///
298/// ## 学習目的
299///
300/// Worker の Decider(アクション決定LLM)を fine-tuning するためのデータ生成。
301/// 成功したアクションシーケンスから「どの順序でアクションを実行すべきか」を学習。
302///
303/// ## プロンプト形式
304///
305/// - **入力**: コンテキスト + 利用可能なアクション一覧
306/// - **出力**: 最適なアクションシーケンス
307///
308/// ## 用途
309///
310/// - LoRA fine-tuning 用のデータ生成(Decider向け)
311/// - 成功パターンの抽出と再利用
312pub struct WorkerDecisionSequenceLearn {
313    /// システムプロンプト
314    system_prompt: String,
315    /// 最小アクション数(これ以下は無視)
316    min_actions: usize,
317    /// 利用可能なアクション一覧(プロンプト生成用)
318    available_actions: Vec<String>,
319    /// システムイベント(フィルタ対象)
320    system_events: Vec<String>,
321}
322
323impl WorkerDecisionSequenceLearn {
324    pub fn new() -> Self {
325        Self {
326            system_prompt: "You are an intelligent agent that diagnoses and resolves system issues. \
327                           Given a context and available actions, determine the optimal action sequence.".to_string(),
328            min_actions: 3,
329            available_actions: vec![
330                "CheckStatus".to_string(),
331                "ReadLogs".to_string(),
332                "AnalyzeMetrics".to_string(),
333                "Diagnose".to_string(),
334                "Restart".to_string(),
335            ],
336            system_events: system_events::DEFAULT_SYSTEM_EVENTS
337                .iter()
338                .map(|s| s.to_string())
339                .collect(),
340        }
341    }
342
343    /// システムプロンプトを設定
344    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
345        self.system_prompt = prompt.into();
346        self
347    }
348
349    /// 最小アクション数を設定
350    pub fn with_min_actions(mut self, min: usize) -> Self {
351        self.min_actions = min;
352        self
353    }
354
355    /// 利用可能なアクション一覧を設定
356    pub fn with_available_actions(mut self, actions: Vec<String>) -> Self {
357        self.available_actions = actions;
358        self
359    }
360
361    /// システムイベント(フィルタ対象)を追加
362    pub fn with_system_event(mut self, event: impl Into<String>) -> Self {
363        self.system_events.push(event.into());
364        self
365    }
366
367    /// アクションがシステムイベントかどうか判定
368    fn is_system_event(&self, action: &str) -> bool {
369        self.system_events.iter().any(|e| e == action)
370    }
371}
372
373impl Default for WorkerDecisionSequenceLearn {
374    fn default() -> Self {
375        Self::new()
376    }
377}
378
379impl LearnModel for WorkerDecisionSequenceLearn {
380    fn name(&self) -> &str {
381        "worker_decision_sequence"
382    }
383
384    fn objective(&self) -> &str {
385        "Learn successful action sequences for problem resolution"
386    }
387
388    fn evaluate(&self, context: &EpisodeContext) -> Outcome {
389        // 空のコンテキストは評価不能
390        if context.is_empty() {
391            return Outcome::failure("Empty context: no actions to evaluate");
392        }
393
394        // 成功したアクション(システムイベント除く)をカウント
395        let successful_actions: Vec<_> = context
396            .iter::<ActionRecord>()
397            .filter(|a| a.success && !self.is_system_event(&a.action))
398            .collect();
399
400        if successful_actions.len() >= self.min_actions {
401            Outcome::success(1.0)
402        } else {
403            Outcome::failure(format!(
404                "Insufficient successful actions: {} < {}",
405                successful_actions.len(),
406                self.min_actions
407            ))
408        }
409    }
410
411    fn build_episodes(&self, records: &[Record]) -> Vec<Episode> {
412        // 成功したアクション(システムイベント除く)のみ抽出
413        let successful_actions: Vec<&ActionRecord> = records
414            .iter()
415            .filter_map(Record::as_action)
416            .filter(|a| a.success && !self.is_system_event(&a.action))
417            .collect();
418
419        if successful_actions.len() < self.min_actions {
420            return vec![];
421        }
422
423        // 全成功アクションを1つのEpisodeとして構築
424        let mut context = EpisodeContext::new();
425        for action in &successful_actions {
426            context.push((*action).clone());
427        }
428
429        let outcome = self.evaluate(&context);
430
431        let episode = Episode::builder()
432            .learn_model("worker_decision_sequence")
433            .context(context)
434            .outcome(outcome)
435            .build();
436
437        vec![episode]
438    }
439
440    fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError> {
441        if !episode.outcome.is_success() {
442            return Err(LearnError::InvalidEpisode(
443                "Episode is not successful".into(),
444            ));
445        }
446
447        let actions: Vec<&str> = episode
448            .context
449            .iter::<ActionRecord>()
450            .map(|a| a.action.as_str())
451            .collect();
452
453        if actions.len() < self.min_actions {
454            return Err(LearnError::InvalidEpisode(format!(
455                "Too few actions: {} < {}",
456                actions.len(),
457                self.min_actions
458            )));
459        }
460
461        // プロンプト生成
462        let available = self.available_actions.join(", ");
463        let prompt = format!(
464            "Current context: default\n\
465             Available actions: {}\n\n\
466             What is the best sequence of actions to resolve this issue?",
467            available
468        );
469
470        // レスポンス生成
471        let action_sequence = actions.join(" -> ");
472        let response = format!(
473            "Based on the context, the optimal action sequence is: {}",
474            action_sequence
475        );
476
477        Ok(TrainingData::sft(&self.system_prompt, &prompt, &response)
478            .with_episode_id(episode.id.to_string())
479            .with_outcome_score(episode.outcome.score()))
480    }
481}
482
483// ============================================================================
484// Tests
485// ============================================================================
486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionContext, ActionEventBuilder, ActionEventResult};
    use crate::types::WorkerId;
    use std::time::Duration;

    /// Builds an `ActionEvent` for the given tick, worker and action name,
    /// with a success or failure result, a 10 ms duration and a fresh context.
    fn make_action(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
        let result = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("error")
        };

        ActionEventBuilder::new(tick, WorkerId(worker_id), action)
            .result(result)
            .duration(Duration::from_millis(10))
            .context(ActionContext::new())
            .build()
    }

    /// Converts action events into records, preserving order.
    fn make_records(actions: &[ActionEvent]) -> Vec<Record> {
        actions.iter().map(Record::from).collect()
    }

    // A record built from an action event reports itself as an action and
    // exposes the worker id.
    #[test]
    fn test_record_accessors() {
        let action = make_action(1, 5, "CheckStatus", true);
        let record = Record::from(&action);

        assert!(record.is_action());
        assert!(!record.is_llm());
        assert_eq!(record.worker_id(), Some(5));
        assert!(record.as_action().is_some());
        assert!(record.as_llm().is_none());
    }

    // Grouping interleaved records by worker yields one group per worker.
    #[test]
    fn test_record_stream_group_by_worker() {
        let actions = vec![
            make_action(1, 0, "A", true),
            make_action(2, 1, "B", true),
            make_action(3, 0, "C", true),
            make_action(4, 1, "D", true),
        ];
        let records = make_records(&actions);
        let stream = RecordStream::new(&records);

        let groups = stream.group_by_worker();
        assert_eq!(groups.len(), 2);
        assert_eq!(groups.get(&0).map(|v| v.len()), Some(2));
        assert_eq!(groups.get(&1).map(|v| v.len()), Some(2));
    }

    // Each worker's sequence ending in `done` becomes one episode whose
    // outcome follows the result of the `done` action.
    #[test]
    fn test_worker_task_learn_build_episodes() {
        let learn = WorkerTaskLearn::new().with_min_actions(2);

        let actions = vec![
            make_action(1, 0, "CheckStatus", true),
            make_action(2, 0, "ReadLogs", true),
            make_action(3, 0, "done", true),
            make_action(4, 1, "Grep", true),
            make_action(5, 1, "done", false),
        ];
        let records = make_records(&actions);

        let episodes = learn.build_episodes(&records);

        // Worker 0: success, Worker 1: failure
        assert_eq!(episodes.len(), 2);

        let worker0_ep = episodes.iter().find(|ep| ep.worker_id() == Some(0));
        assert!(worker0_ep.is_some());
        assert!(worker0_ep.unwrap().outcome.is_success());

        let worker1_ep = episodes.iter().find(|ep| ep.worker_id() == Some(1));
        assert!(worker1_ep.is_some());
        assert!(worker1_ep.unwrap().outcome.is_failure());
    }

    // Only successful episodes can be converted to training data.
    #[test]
    fn test_worker_task_learn_convert_success_only() {
        let learn = WorkerTaskLearn::new();

        // A failed episode is rejected with a conversion error.
        let failed_ep = Episode::builder()
            .learn_model("worker_task")
            .outcome(Outcome::failure("test"))
            .build();

        assert!(learn.convert(&failed_ep).is_err());

        // A successful episode can be converted.
        let success_ep = Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::new(1, 0, "Check").success(true))
            .record(ActionRecord::new(2, 0, "Fix").success(true))
            .record(ActionRecord::new(3, 0, "done").success(true))
            .outcome(Outcome::success_binary())
            .build();

        assert!(learn.convert(&success_ep).is_ok());
    }

    // ========================================================================
    // WorkerDecisionSequenceLearn Tests
    // ========================================================================

    // System events are filtered out before the single episode is built.
    #[test]
    fn test_worker_decision_sequence_learn_build_episodes() {
        let learn = WorkerDecisionSequenceLearn::new().with_min_actions(3);

        let actions = vec![
            make_action(1, 0, "CheckStatus", true),
            make_action(2, 0, "ReadLogs", true),
            make_action(3, 0, "Restart", true),
            make_action(4, 0, "tick_end", true), // system event, should be filtered
            make_action(5, 0, "done", true),     // system event, should be filtered
        ];
        let records = make_records(&actions);

        let episodes = learn.build_episodes(&records);

        // One episode (three successful actions: CheckStatus, ReadLogs, Restart)
        assert_eq!(episodes.len(), 1);
        assert!(episodes[0].outcome.is_success());

        // System events have been excluded.
        let action_count = episodes[0].context.iter::<ActionRecord>().count();
        assert_eq!(action_count, 3);
    }

    // Failed actions do not count towards the minimum.
    #[test]
    fn test_worker_decision_sequence_learn_filters_failed_actions() {
        let learn = WorkerDecisionSequenceLearn::new().with_min_actions(3);

        let actions = vec![
            make_action(1, 0, "CheckStatus", true),
            make_action(2, 0, "ReadLogs", false), // failed, should be filtered
            make_action(3, 0, "Restart", true),
        ];
        let records = make_records(&actions);

        let episodes = learn.build_episodes(&records);

        // Only two successful actions remain, so no episode is produced.
        assert_eq!(episodes.len(), 0);
    }

    // The prompt lists the configured actions and the response joins the
    // episode's actions with " -> ".
    #[test]
    fn test_worker_decision_sequence_learn_convert() {
        let learn = WorkerDecisionSequenceLearn::new().with_available_actions(vec![
            "A".to_string(),
            "B".to_string(),
            "C".to_string(),
        ]);

        let episode = Episode::builder()
            .learn_model("worker_decision_sequence")
            .record(ActionRecord::new(1, 0, "A").success(true))
            .record(ActionRecord::new(2, 0, "B").success(true))
            .record(ActionRecord::new(3, 0, "C").success(true))
            .outcome(Outcome::success(1.0))
            .build();

        let result = learn.convert(&episode);
        assert!(result.is_ok());

        let training_data = result.unwrap();
        assert!(training_data.prompt.contains("Available actions: A, B, C"));
        assert!(training_data.chosen.contains("A -> B -> C"));
    }
}