Skip to main content

swarm_engine_core/learn/
learned_component.rs

1//! LearnedComponent - 学習結果の型安全な抽象化
2//!
3//! ## 設計思想
4//!
5//! ComponentLearner(学習プロセス)と LearnedComponent(学習結果)をペアで定義。
6//! 各学習対象ごとに専用の型を持ち、Map/Any を使わず型安全性を確保する。
7//!
8//! ## 背景
9//!
10//! - 従来の ML: 口調、好み程度 → Map でも許容
11//! - Swarm Learning: Control の知識そのものが Domain → Typed 必須
12//!
13//! ## 使用例
14//!
15//! ```ignore
16//! // ComponentLearner と LearnedComponent はペアで定義
17//! struct DepGraphLearner;
18//!
19//! impl ComponentLearner for DepGraphLearner {
20//!     type Output = LearnedDepGraph;
21//!
22//!     fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError> {
23//!         // Episodes から LearnedDepGraph を生成
24//!     }
25//! }
26//!
27//! impl LearnedComponent for LearnedDepGraph {
28//!     fn component_id() -> &'static str { "dep_graph" }
29//!     // ...
30//! }
31//! ```
32//!
33//! ## LearnModel vs ComponentLearner
34//!
35//! - `LearnModel`: Episode → TrainingData (LoRA fine-tuning 用)
36//! - `ComponentLearner`: Episodes → LearnedComponent (ScenarioProfile 用)
37
38use serde::{de::DeserializeOwned, Serialize};
39
40use super::episode::Episode;
41use super::learn_model::LearnError;
42
43// ============================================================================
44// ComponentLearner Trait
45// ============================================================================
46
47/// ScenarioProfile コンポーネントの学習プロセス
48///
49/// Episode の集合から LearnedComponent を生成する。
50/// LearnModel(LoRA用)とは異なり、ScenarioProfile の各コンポーネントを
51/// 型安全に学習する。
52pub trait ComponentLearner: Send + Sync {
53    /// 学習結果の型
54    type Output: LearnedComponent;
55
56    /// 学習器の名前
57    fn name(&self) -> &str;
58
59    /// 目的の説明
60    fn objective(&self) -> &str;
61
62    /// Episode から学習結果を生成
63    fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError>;
64
65    /// 既存のコンポーネントを更新(増分学習)
66    fn update(
67        &self,
68        existing: &Self::Output,
69        new_episodes: &[Episode],
70    ) -> Result<Self::Output, LearnError> {
71        let mut learned = self.learn(new_episodes)?;
72        learned.merge(existing);
73        Ok(learned)
74    }
75}
76
77// ============================================================================
78// LearnedComponent Trait
79// ============================================================================
80
81/// 学習結果コンポーネントの共通 trait
82///
83/// 各学習対象(DepGraph, Strategy, Exploration 等)の学習結果が実装する。
84/// Typed で管理することで、型安全性と IDE サポートを確保。
85pub trait LearnedComponent: Send + Sync + Serialize + DeserializeOwned + Clone {
86    /// コンポーネント識別子(ファイル名等に使用)
87    fn component_id() -> &'static str
88    where
89        Self: Sized;
90
91    /// 信頼度スコア (0.0 - 1.0)
92    ///
93    /// 学習データ量や品質に基づく信頼度。
94    /// 低い場合は Bootstrap 追加実行を検討。
95    fn confidence(&self) -> f64;
96
97    /// 学習に使用したセッション数
98    fn session_count(&self) -> usize;
99
100    /// 最終更新タイムスタンプ (Unix epoch seconds)
101    fn updated_at(&self) -> u64;
102
103    /// 他のコンポーネントとマージ(増分学習用)
104    ///
105    /// デフォルト実装: 信頼度の高い方を優先
106    fn merge(&mut self, other: &Self)
107    where
108        Self: Sized,
109    {
110        if other.confidence() > self.confidence() {
111            *self = other.clone();
112        }
113    }
114
115    /// バージョン番号(互換性チェック用)
116    fn version() -> u32
117    where
118        Self: Sized,
119    {
120        1
121    }
122}
123
124// ============================================================================
125// LearnedDepGraph - 学習済み依存グラフ
126// ============================================================================
127
128use crate::exploration::DependencyGraph;
129
130// Re-use existing RecommendedPath from offline module
131pub use super::offline::RecommendedPath;
132
133/// 学習済み依存グラフ
134///
135/// Bootstrap フェーズで正解グラフから学習し、
136/// Release フェーズで LLM なしで即座にアクション順序を決定。
137#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
138pub struct LearnedDepGraph {
139    /// 依存グラフ本体
140    pub graph: DependencyGraph,
141
142    /// 学習済みアクション順序(トポロジカルソート済み)
143    /// Deprecated: discover_order + not_discover_order を使用
144    pub action_order: Vec<String>,
145
146    /// Discover アクションの順序(NodeExpand 系)
147    #[serde(default)]
148    pub discover_order: Vec<String>,
149
150    /// Not-Discover アクションの順序(NodeStateChange 系)
151    #[serde(default)]
152    pub not_discover_order: Vec<String>,
153
154    /// 推奨パス(成功率順)
155    #[serde(default)]
156    pub recommended_paths: Vec<RecommendedPath>,
157
158    /// 信頼度 (0.0 - 1.0)
159    pub confidence: f64,
160
161    /// 学習に使用したセッション ID
162    pub learned_from: Vec<String>,
163
164    /// 最終更新タイムスタンプ
165    pub updated_at: u64,
166}
167
168impl LearnedDepGraph {
169    /// 新規作成
170    pub fn new(graph: DependencyGraph, action_order: Vec<String>) -> Self {
171        Self {
172            graph,
173            action_order,
174            discover_order: Vec::new(),
175            not_discover_order: Vec::new(),
176            recommended_paths: Vec::new(),
177            confidence: 0.0,
178            learned_from: Vec::new(),
179            updated_at: std::time::SystemTime::now()
180                .duration_since(std::time::UNIX_EPOCH)
181                .map(|d| d.as_secs())
182                .unwrap_or(0),
183        }
184    }
185
186    /// discover/not_discover を個別に設定して作成
187    pub fn with_orders(
188        graph: DependencyGraph,
189        discover_order: Vec<String>,
190        not_discover_order: Vec<String>,
191    ) -> Self {
192        let mut all_actions = discover_order.clone();
193        all_actions.extend(not_discover_order.clone());
194        Self {
195            graph,
196            action_order: all_actions,
197            discover_order,
198            not_discover_order,
199            recommended_paths: Vec::new(),
200            confidence: 0.0,
201            learned_from: Vec::new(),
202            updated_at: std::time::SystemTime::now()
203                .duration_since(std::time::UNIX_EPOCH)
204                .map(|d| d.as_secs())
205                .unwrap_or(0),
206        }
207    }
208
209    /// 信頼度を設定
210    pub fn with_confidence(mut self, confidence: f64) -> Self {
211        self.confidence = confidence;
212        self
213    }
214
215    /// 学習元セッションを追加
216    pub fn with_sessions(mut self, session_ids: Vec<String>) -> Self {
217        self.learned_from = session_ids;
218        self
219    }
220
221    /// 推奨パスを追加
222    pub fn with_recommended_paths(mut self, paths: Vec<RecommendedPath>) -> Self {
223        self.recommended_paths = paths;
224        self
225    }
226}
227
228impl LearnedComponent for LearnedDepGraph {
229    fn component_id() -> &'static str {
230        "dep_graph"
231    }
232
233    fn confidence(&self) -> f64 {
234        self.confidence
235    }
236
237    fn session_count(&self) -> usize {
238        self.learned_from.len()
239    }
240
241    fn updated_at(&self) -> u64 {
242        self.updated_at
243    }
244
245    fn merge(&mut self, other: &Self) {
246        // セッション数と信頼度を考慮してマージ
247        if other.learned_from.len() > self.learned_from.len() || other.confidence > self.confidence
248        {
249            self.graph = other.graph.clone();
250            self.action_order = other.action_order.clone();
251            self.confidence = other.confidence;
252        }
253        // セッション ID は結合
254        for id in &other.learned_from {
255            if !self.learned_from.contains(id) {
256                self.learned_from.push(id.clone());
257            }
258        }
259        // 推奨パスはマージ
260        for path in &other.recommended_paths {
261            if !self
262                .recommended_paths
263                .iter()
264                .any(|p| p.actions == path.actions)
265            {
266                self.recommended_paths.push(path.clone());
267            }
268        }
269        self.updated_at = other.updated_at.max(self.updated_at);
270    }
271}
272
273// ============================================================================
274// LearnedExploration - 学習済み探索パラメータ
275// ============================================================================
276
277/// 学習済み探索パラメータ
278///
279/// UCB1 の探索係数、学習重みなど、探索戦略のパラメータ。
280#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
281pub struct LearnedExploration {
282    /// UCB1 探索係数
283    pub ucb1_c: f64,
284
285    /// 学習重み (0.0 - 1.0)
286    pub learning_weight: f64,
287
288    /// N-gram 重み
289    pub ngram_weight: f64,
290
291    /// 信頼度
292    pub confidence: f64,
293
294    /// 学習セッション数
295    pub session_count: usize,
296
297    /// 最終更新
298    pub updated_at: u64,
299}
300
301impl Default for LearnedExploration {
302    fn default() -> Self {
303        Self {
304            ucb1_c: 1.414,
305            learning_weight: 0.3,
306            ngram_weight: 1.0,
307            confidence: 0.0,
308            session_count: 0,
309            updated_at: 0,
310        }
311    }
312}
313
314impl LearnedExploration {
315    /// 新規作成
316    pub fn new(ucb1_c: f64, learning_weight: f64, ngram_weight: f64) -> Self {
317        Self {
318            ucb1_c,
319            learning_weight,
320            ngram_weight,
321            confidence: 0.0,
322            session_count: 0,
323            updated_at: std::time::SystemTime::now()
324                .duration_since(std::time::UNIX_EPOCH)
325                .map(|d| d.as_secs())
326                .unwrap_or(0),
327        }
328    }
329}
330
331impl LearnedComponent for LearnedExploration {
332    fn component_id() -> &'static str {
333        "exploration"
334    }
335
336    fn confidence(&self) -> f64 {
337        self.confidence
338    }
339
340    fn session_count(&self) -> usize {
341        self.session_count
342    }
343
344    fn updated_at(&self) -> u64 {
345        self.updated_at
346    }
347}
348
349// ============================================================================
350// LearnedStrategy - 学習済み戦略設定
351// ============================================================================
352
353/// 学習済み戦略設定
354///
355/// 初期戦略の選択、戦略切り替えの閾値など。
356#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
357pub struct LearnedStrategy {
358    /// 初期選択戦略
359    pub initial_strategy: String,
360
361    /// 成熟度閾値(何回実行後に戦略切り替えを検討するか)
362    pub maturity_threshold: usize,
363
364    /// エラー率閾値(これを超えたら戦略切り替え)
365    pub error_rate_threshold: f64,
366
367    /// 信頼度
368    pub confidence: f64,
369
370    /// 学習セッション数
371    pub session_count: usize,
372
373    /// 最終更新
374    pub updated_at: u64,
375}
376
377impl Default for LearnedStrategy {
378    fn default() -> Self {
379        Self {
380            initial_strategy: "ucb1".to_string(),
381            maturity_threshold: 5,
382            error_rate_threshold: 0.45,
383            confidence: 0.0,
384            session_count: 0,
385            updated_at: 0,
386        }
387    }
388}
389
390impl LearnedComponent for LearnedStrategy {
391    fn component_id() -> &'static str {
392        "strategy"
393    }
394
395    fn confidence(&self) -> f64 {
396        self.confidence
397    }
398
399    fn session_count(&self) -> usize {
400        self.session_count
401    }
402
403    fn updated_at(&self) -> u64 {
404        self.updated_at
405    }
406}
407
408// ============================================================================
409// Tests
410// ============================================================================
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::exploration::DependencyGraph;

    /// Builds an owned `Vec<String>` from string literals.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|&item| String::from(item)).collect()
    }

    /// Builder methods set confidence and sessions; the component id is fixed.
    #[test]
    fn test_learned_dep_graph_creation() {
        let dep_graph = LearnedDepGraph::new(DependencyGraph::new(), strings(&["A", "B"]))
            .with_confidence(0.8)
            .with_sessions(strings(&["s1", "s2"]));

        assert_eq!(dep_graph.confidence(), 0.8);
        assert_eq!(dep_graph.session_count(), 2);
        assert_eq!(LearnedDepGraph::component_id(), "dep_graph");
    }

    /// Merging adopts the stronger side's order and unions the session IDs.
    #[test]
    fn test_learned_dep_graph_merge() {
        let base_graph = DependencyGraph::new();

        let mut target = LearnedDepGraph::new(base_graph.clone(), strings(&["A"]))
            .with_confidence(0.5)
            .with_sessions(strings(&["s1"]));
        let incoming = LearnedDepGraph::new(base_graph, strings(&["A", "B"]))
            .with_confidence(0.8)
            .with_sessions(strings(&["s2", "s3"]));

        target.merge(&incoming);

        // Higher confidence wins for graph/order.
        assert_eq!(target.confidence, 0.8);
        assert_eq!(target.action_order.len(), 2);
        // Sessions are combined.
        assert_eq!(target.learned_from.len(), 3);
    }

    /// Default exploration parameters use the expected UCB1 coefficient.
    #[test]
    fn test_learned_exploration_default() {
        let defaults = LearnedExploration::default();

        assert_eq!(defaults.ucb1_c, 1.414);
        assert_eq!(LearnedExploration::component_id(), "exploration");
    }

    /// Default strategy settings start from the "ucb1" strategy.
    #[test]
    fn test_learned_strategy_default() {
        let defaults = LearnedStrategy::default();

        assert_eq!(defaults.initial_strategy, "ucb1");
        assert_eq!(LearnedStrategy::component_id(), "strategy");
    }

    /// Exploration parameters survive a JSON round trip.
    #[test]
    fn test_serialization() {
        let original = LearnedExploration::new(2.0, 0.5, 1.5);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: LearnedExploration = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.ucb1_c, 2.0);
    }
}