Skip to main content

swarm_engine_core/learn/
learned_component.rs

1//! LearnedComponent - 学習結果の型安全な抽象化
2//!
3//! ## 設計思想
4//!
5//! ComponentLearner(学習プロセス)と LearnedComponent(学習結果)をペアで定義。
6//! 各学習対象ごとに専用の型を持ち、Map/Any を使わず型安全性を確保する。
7//!
8//! ## 背景
9//!
10//! - 従来の ML: 口調、好み程度 → Map でも許容
11//! - Swarm Learning: Control の知識そのものが Domain → Typed 必須
12//!
13//! ## 使用例
14//!
15//! ```ignore
16//! // ComponentLearner と LearnedComponent はペアで定義
17//! struct DepGraphLearner;
18//!
19//! impl ComponentLearner for DepGraphLearner {
20//!     type Output = LearnedDepGraph;
21//!
22//!     fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError> {
23//!         // Episodes から LearnedDepGraph を生成
24//!     }
25//! }
26//!
27//! impl LearnedComponent for LearnedDepGraph {
28//!     fn component_id() -> &'static str { "dep_graph" }
29//!     // ...
30//! }
31//! ```
32//!
33//! ## LearnModel vs ComponentLearner
34//!
35//! - `LearnModel`: Episode → TrainingData (LoRA fine-tuning 用)
36//! - `ComponentLearner`: Episodes → LearnedComponent (ScenarioProfile 用)
37
38use serde::{de::DeserializeOwned, Serialize};
39
40use super::episode::Episode;
41use super::learn_model::LearnError;
42
43// ============================================================================
44// ComponentLearner Trait
45// ============================================================================
46
47/// ScenarioProfile コンポーネントの学習プロセス
48///
49/// Episode の集合から LearnedComponent を生成する。
50/// LearnModel(LoRA用)とは異なり、ScenarioProfile の各コンポーネントを
51/// 型安全に学習する。
52pub trait ComponentLearner: Send + Sync {
53    /// 学習結果の型
54    type Output: LearnedComponent;
55
56    /// 学習器の名前
57    fn name(&self) -> &str;
58
59    /// 目的の説明
60    fn objective(&self) -> &str;
61
62    /// Episode から学習結果を生成
63    fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError>;
64
65    /// 既存のコンポーネントを更新(増分学習)
66    fn update(
67        &self,
68        existing: &Self::Output,
69        new_episodes: &[Episode],
70    ) -> Result<Self::Output, LearnError> {
71        let mut learned = self.learn(new_episodes)?;
72        learned.merge(existing);
73        Ok(learned)
74    }
75}
76
77// ============================================================================
78// LearnedComponent Trait
79// ============================================================================
80
81/// 学習結果コンポーネントの共通 trait
82///
83/// 各学習対象(DepGraph, Strategy, Exploration 等)の学習結果が実装する。
84/// Typed で管理することで、型安全性と IDE サポートを確保。
85pub trait LearnedComponent: Send + Sync + Serialize + DeserializeOwned + Clone {
86    /// コンポーネント識別子(ファイル名等に使用)
87    fn component_id() -> &'static str
88    where
89        Self: Sized;
90
91    /// 信頼度スコア (0.0 - 1.0)
92    ///
93    /// 学習データ量や品質に基づく信頼度。
94    /// 低い場合は Bootstrap 追加実行を検討。
95    fn confidence(&self) -> f64;
96
97    /// 学習に使用したセッション数
98    fn session_count(&self) -> usize;
99
100    /// 最終更新タイムスタンプ (Unix epoch seconds)
101    fn updated_at(&self) -> u64;
102
103    /// 他のコンポーネントとマージ(増分学習用)
104    ///
105    /// デフォルト実装: 信頼度の高い方を優先
106    fn merge(&mut self, other: &Self)
107    where
108        Self: Sized,
109    {
110        if other.confidence() > self.confidence() {
111            *self = other.clone();
112        }
113    }
114
115    /// バージョン番号(互換性チェック用)
116    fn version() -> u32
117    where
118        Self: Sized,
119    {
120        1
121    }
122}
123
124// ============================================================================
125// LearnedDepGraph - 学習済み依存グラフ
126// ============================================================================
127
128use crate::exploration::DependencyGraph;
129
130// Re-use existing RecommendedPath from offline module
131pub use super::offline::RecommendedPath;
132
/// Learned dependency graph.
///
/// Learned from ground-truth graphs during the Bootstrap phase so that, in the
/// Release phase, the action order can be decided immediately without an LLM.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LearnedDepGraph {
    /// The dependency graph itself.
    pub graph: DependencyGraph,

    /// Learned action order (topologically sorted by the learner; this type
    /// does not verify the ordering).
    pub action_order: Vec<String>,

    /// Recommended paths, intended to be ordered by success rate (merging
    /// appends new paths and does not re-sort).
    ///
    /// `#[serde(default)]` lets data serialized without this field load with
    /// an empty list.
    #[serde(default)]
    pub recommended_paths: Vec<RecommendedPath>,

    /// Confidence, nominally 0.0 - 1.0 (stored as given, not clamped).
    pub confidence: f64,

    /// IDs of the sessions used for learning; its length is what
    /// `session_count()` reports.
    pub learned_from: Vec<String>,

    /// Last-update timestamp; `new()` sets it to Unix epoch seconds.
    pub updated_at: u64,
}
158
159impl LearnedDepGraph {
160    /// 新規作成
161    pub fn new(graph: DependencyGraph, action_order: Vec<String>) -> Self {
162        Self {
163            graph,
164            action_order,
165            recommended_paths: Vec::new(),
166            confidence: 0.0,
167            learned_from: Vec::new(),
168            updated_at: std::time::SystemTime::now()
169                .duration_since(std::time::UNIX_EPOCH)
170                .map(|d| d.as_secs())
171                .unwrap_or(0),
172        }
173    }
174
175    /// 信頼度を設定
176    pub fn with_confidence(mut self, confidence: f64) -> Self {
177        self.confidence = confidence;
178        self
179    }
180
181    /// 学習元セッションを追加
182    pub fn with_sessions(mut self, session_ids: Vec<String>) -> Self {
183        self.learned_from = session_ids;
184        self
185    }
186
187    /// 推奨パスを追加
188    pub fn with_recommended_paths(mut self, paths: Vec<RecommendedPath>) -> Self {
189        self.recommended_paths = paths;
190        self
191    }
192}
193
194impl LearnedComponent for LearnedDepGraph {
195    fn component_id() -> &'static str {
196        "dep_graph"
197    }
198
199    fn confidence(&self) -> f64 {
200        self.confidence
201    }
202
203    fn session_count(&self) -> usize {
204        self.learned_from.len()
205    }
206
207    fn updated_at(&self) -> u64 {
208        self.updated_at
209    }
210
211    fn merge(&mut self, other: &Self) {
212        // セッション数と信頼度を考慮してマージ
213        if other.learned_from.len() > self.learned_from.len() || other.confidence > self.confidence
214        {
215            self.graph = other.graph.clone();
216            self.action_order = other.action_order.clone();
217            self.confidence = other.confidence;
218        }
219        // セッション ID は結合
220        for id in &other.learned_from {
221            if !self.learned_from.contains(id) {
222                self.learned_from.push(id.clone());
223            }
224        }
225        // 推奨パスはマージ
226        for path in &other.recommended_paths {
227            if !self
228                .recommended_paths
229                .iter()
230                .any(|p| p.actions == path.actions)
231            {
232                self.recommended_paths.push(path.clone());
233            }
234        }
235        self.updated_at = other.updated_at.max(self.updated_at);
236    }
237}
238
239// ============================================================================
240// LearnedExploration - 学習済み探索パラメータ
241// ============================================================================
242
243/// 学習済み探索パラメータ
244///
245/// UCB1 の探索係数、学習重みなど、探索戦略のパラメータ。
246#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
247pub struct LearnedExploration {
248    /// UCB1 探索係数
249    pub ucb1_c: f64,
250
251    /// 学習重み (0.0 - 1.0)
252    pub learning_weight: f64,
253
254    /// N-gram 重み
255    pub ngram_weight: f64,
256
257    /// 信頼度
258    pub confidence: f64,
259
260    /// 学習セッション数
261    pub session_count: usize,
262
263    /// 最終更新
264    pub updated_at: u64,
265}
266
267impl Default for LearnedExploration {
268    fn default() -> Self {
269        Self {
270            ucb1_c: 1.414,
271            learning_weight: 0.3,
272            ngram_weight: 1.0,
273            confidence: 0.0,
274            session_count: 0,
275            updated_at: 0,
276        }
277    }
278}
279
280impl LearnedExploration {
281    /// 新規作成
282    pub fn new(ucb1_c: f64, learning_weight: f64, ngram_weight: f64) -> Self {
283        Self {
284            ucb1_c,
285            learning_weight,
286            ngram_weight,
287            confidence: 0.0,
288            session_count: 0,
289            updated_at: std::time::SystemTime::now()
290                .duration_since(std::time::UNIX_EPOCH)
291                .map(|d| d.as_secs())
292                .unwrap_or(0),
293        }
294    }
295}
296
297impl LearnedComponent for LearnedExploration {
298    fn component_id() -> &'static str {
299        "exploration"
300    }
301
302    fn confidence(&self) -> f64 {
303        self.confidence
304    }
305
306    fn session_count(&self) -> usize {
307        self.session_count
308    }
309
310    fn updated_at(&self) -> u64 {
311        self.updated_at
312    }
313}
314
315// ============================================================================
316// LearnedStrategy - 学習済み戦略設定
317// ============================================================================
318
319/// 学習済み戦略設定
320///
321/// 初期戦略の選択、戦略切り替えの閾値など。
322#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
323pub struct LearnedStrategy {
324    /// 初期選択戦略
325    pub initial_strategy: String,
326
327    /// 成熟度閾値(何回実行後に戦略切り替えを検討するか)
328    pub maturity_threshold: usize,
329
330    /// エラー率閾値(これを超えたら戦略切り替え)
331    pub error_rate_threshold: f64,
332
333    /// 信頼度
334    pub confidence: f64,
335
336    /// 学習セッション数
337    pub session_count: usize,
338
339    /// 最終更新
340    pub updated_at: u64,
341}
342
343impl Default for LearnedStrategy {
344    fn default() -> Self {
345        Self {
346            initial_strategy: "ucb1".to_string(),
347            maturity_threshold: 5,
348            error_rate_threshold: 0.45,
349            confidence: 0.0,
350            session_count: 0,
351            updated_at: 0,
352        }
353    }
354}
355
356impl LearnedComponent for LearnedStrategy {
357    fn component_id() -> &'static str {
358        "strategy"
359    }
360
361    fn confidence(&self) -> f64 {
362        self.confidence
363    }
364
365    fn session_count(&self) -> usize {
366        self.session_count
367    }
368
369    fn updated_at(&self) -> u64 {
370        self.updated_at
371    }
372}
373
374// ============================================================================
375// Tests
376// ============================================================================
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::exploration::DependencyGraph;

    /// Collects string literals into an owned `Vec<String>`.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_learned_dep_graph_creation() {
        let learned = LearnedDepGraph::new(DependencyGraph::new(), owned(&["A", "B"]))
            .with_confidence(0.8)
            .with_sessions(owned(&["s1", "s2"]));

        assert_eq!(LearnedDepGraph::component_id(), "dep_graph");
        assert_eq!(learned.session_count(), 2);
        assert_eq!(learned.confidence(), 0.8);
    }

    #[test]
    fn test_learned_dep_graph_merge() {
        let shared_graph = DependencyGraph::new();
        let mut target = LearnedDepGraph::new(shared_graph.clone(), owned(&["A"]))
            .with_confidence(0.5)
            .with_sessions(owned(&["s1"]));
        let incoming = LearnedDepGraph::new(shared_graph, owned(&["A", "B"]))
            .with_confidence(0.8)
            .with_sessions(owned(&["s2", "s3"]));

        target.merge(&incoming);

        // The higher-confidence side supplies graph / order / confidence ...
        assert_eq!(target.confidence, 0.8);
        assert_eq!(target.action_order.len(), 2);
        // ... while the session IDs of both sides are unioned.
        assert_eq!(target.learned_from.len(), 3);
    }

    #[test]
    fn test_learned_exploration_default() {
        assert_eq!(LearnedExploration::component_id(), "exploration");
        assert_eq!(LearnedExploration::default().ucb1_c, 1.414);
    }

    #[test]
    fn test_learned_strategy_default() {
        assert_eq!(LearnedStrategy::component_id(), "strategy");
        assert_eq!(LearnedStrategy::default().initial_strategy, "ucb1");
    }

    #[test]
    fn test_serialization() {
        let original = LearnedExploration::new(2.0, 0.5, 1.5);
        let json = serde_json::to_string(&original).unwrap();
        let restored = serde_json::from_str::<LearnedExploration>(&json).unwrap();
        assert_eq!(restored.ucb1_c, 2.0);
    }
}