Skip to main content

swarm_engine_core/learn/
component_learners.rs

1//! ComponentLearner 具体実装
2//!
3//! ScenarioProfile の各コンポーネントを学習する具体的な実装。
4//!
5//! ## 実装一覧
6//!
7//! | Learner | Output | 学習元 |
8//! |---------|--------|--------|
9//! | `DepGraphLearner` | `LearnedDepGraph` | 成功 Episode のアクション系列 |
10//! | `ExplorationLearner` | `LearnedExploration` | セッション統計 |
11//! | `StrategyLearner` | `LearnedStrategy` | 戦略切り替え履歴 |
12
13use std::collections::HashMap;
14
15use super::episode::Episode;
16use super::learn_model::LearnError;
17use super::learned_component::{
18    ComponentLearner, LearnedDepGraph, LearnedExploration, LearnedStrategy,
19};
20use super::record::ActionRecord;
21use super::RecommendedPath;
22use crate::exploration::DependencyGraph;
23
24// ============================================================================
25// DepGraphLearner - 依存グラフ学習
26// ============================================================================
27
/// Dependency-graph learner.
///
/// Learns action-ordering information from the action sequences of
/// successful episodes.
///
/// ## Learning logic
///
/// 1. Extract the action sequence of every successful episode
/// 2. Count how often each action appears before another
/// 3. Derive a global action order from those ordering relations
///    (building the `DependencyGraph` itself is not implemented yet; an
///    empty graph is returned)
/// 4. Rank recommended paths by success rate
///
/// `Default` yields the same thresholds as [`DepGraphLearner::new`]
/// (`min_episodes = 3`, `min_order_count = 2`). A derived `Default` would
/// zero both thresholds and silently disagree with `new()`.
#[derive(Debug, Clone)]
pub struct DepGraphLearner {
    /// Minimum number of successful episodes; fewer than this yields low confidence.
    pub min_episodes: usize,
    /// Minimum occurrence count for an ordering relation to be considered.
    ///
    /// NOTE(review): no method in this file reads this field yet.
    pub min_order_count: usize,
}

impl Default for DepGraphLearner {
    fn default() -> Self {
        Self {
            min_episodes: 3,
            min_order_count: 2,
        }
    }
}
45
46impl DepGraphLearner {
47    /// 新規作成
48    pub fn new() -> Self {
49        Self {
50            min_episodes: 3,
51            min_order_count: 2,
52        }
53    }
54
55    /// 最小 Episode 数を設定
56    pub fn with_min_episodes(mut self, n: usize) -> Self {
57        self.min_episodes = n;
58        self
59    }
60
61    /// アクション系列から順序関係を抽出
62    fn extract_order_relations(
63        &self,
64        action_sequences: &[Vec<String>],
65    ) -> HashMap<(String, String), usize> {
66        let mut relations: HashMap<(String, String), usize> = HashMap::new();
67
68        for sequence in action_sequences {
69            // 各ペアの順序関係をカウント
70            for i in 0..sequence.len() {
71                for j in (i + 1)..sequence.len() {
72                    let key = (sequence[i].clone(), sequence[j].clone());
73                    *relations.entry(key).or_insert(0) += 1;
74                }
75            }
76        }
77
78        relations
79    }
80
81    /// 順序関係からアクション順序を計算
82    ///
83    /// 各アクションの「先行回数」でソートすることで、
84    /// 頻繁に先に来るアクションを前に配置
85    fn compute_action_order(&self, relations: &HashMap<(String, String), usize>) -> Vec<String> {
86        // 各アクションの先行スコアを計算
87        let mut scores: HashMap<String, i64> = HashMap::new();
88
89        for ((from, to), &count) in relations {
90            // from は to より count 回先に来た
91            *scores.entry(from.clone()).or_insert(0) += count as i64;
92            *scores.entry(to.clone()).or_insert(0) -= count as i64;
93        }
94
95        // スコア降順でソート
96        let mut actions: Vec<_> = scores.into_iter().collect();
97        actions.sort_by(|a, b| b.1.cmp(&a.1));
98
99        actions.into_iter().map(|(action, _)| action).collect()
100    }
101
102    /// 推奨パスを計算
103    fn compute_recommended_paths(
104        &self,
105        success_count: &HashMap<Vec<String>, usize>,
106        total_success: usize,
107    ) -> Vec<RecommendedPath> {
108        let mut paths: Vec<_> = success_count
109            .iter()
110            .map(|(actions, &count)| {
111                let success_rate = count as f64 / total_success.max(1) as f64;
112                RecommendedPath {
113                    actions: actions.clone(),
114                    success_rate,
115                    observations: count as u32,
116                }
117            })
118            .collect();
119
120        // 成功率でソート(降順)
121        paths.sort_by(|a, b| {
122            b.success_rate
123                .partial_cmp(&a.success_rate)
124                .unwrap_or(std::cmp::Ordering::Equal)
125        });
126
127        // 上位 10 パスに制限
128        paths.truncate(10);
129        paths
130    }
131}
132
133impl ComponentLearner for DepGraphLearner {
134    type Output = LearnedDepGraph;
135
136    fn name(&self) -> &str {
137        "dep_graph_learner"
138    }
139
140    fn objective(&self) -> &str {
141        "Learn action dependency graph from successful execution traces"
142    }
143
144    fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError> {
145        // 成功 Episode のみ抽出
146        let success_episodes: Vec<_> = episodes.iter().filter(|e| e.outcome.is_success()).collect();
147
148        if success_episodes.is_empty() {
149            return Err(LearnError::InsufficientData(
150                "No successful episodes to learn from".into(),
151            ));
152        }
153
154        // アクション系列を抽出
155        let mut action_sequences: Vec<Vec<String>> = Vec::new();
156        let mut success_count: HashMap<Vec<String>, usize> = HashMap::new();
157        let mut session_ids: Vec<String> = Vec::new();
158
159        for episode in &success_episodes {
160            // Episode の context から ActionRecord を取得
161            let actions: Vec<String> = episode
162                .context
163                .iter::<ActionRecord>()
164                .map(|r| r.action.clone())
165                .collect();
166
167            if !actions.is_empty() {
168                *success_count.entry(actions.clone()).or_insert(0) += 1;
169                action_sequences.push(actions);
170            }
171
172            // Episode ID をセッション ID として使用
173            let episode_id = episode.id.to_string();
174            if !session_ids.contains(&episode_id) {
175                session_ids.push(episode_id);
176            }
177        }
178
179        // 順序関係を抽出
180        let relations = self.extract_order_relations(&action_sequences);
181
182        // アクション順序を計算
183        let action_order = self.compute_action_order(&relations);
184
185        // 推奨パスを計算
186        let recommended_paths =
187            self.compute_recommended_paths(&success_count, success_episodes.len());
188
189        // 信頼度を計算(成功 Episode 数に基づく)
190        let confidence = if success_episodes.len() >= self.min_episodes {
191            (success_episodes.len() as f64 / (self.min_episodes as f64 * 2.0)).min(1.0)
192        } else {
193            success_episodes.len() as f64 / self.min_episodes as f64
194        };
195
196        // 空の DependencyGraph を作成(実際のグラフは別途構築)
197        let graph = DependencyGraph::new();
198
199        Ok(LearnedDepGraph::new(graph, action_order)
200            .with_confidence(confidence)
201            .with_sessions(session_ids)
202            .with_recommended_paths(recommended_paths))
203    }
204}
205
206// ============================================================================
207// ExplorationLearner - 探索パラメータ学習
208// ============================================================================
209
/// Exploration-parameter learner.
///
/// Learns exploration parameters (the UCB1 coefficient) from episode
/// outcome statistics.
///
/// `Default` matches [`ExplorationLearner::new`] (`initial_ucb1_c = 1.414`).
/// A derived `Default` would have set the coefficient to 0.0.
#[derive(Debug, Clone)]
pub struct ExplorationLearner {
    /// Initial UCB1 exploration coefficient (1.414, approximately √2, by default).
    ///
    /// NOTE(review): the `learn` implementation in this file does not read
    /// this field; it uses fixed coefficients per success-rate band.
    pub initial_ucb1_c: f64,
}

impl Default for ExplorationLearner {
    fn default() -> Self {
        Self {
            initial_ucb1_c: 1.414,
        }
    }
}

impl ExplorationLearner {
    /// Creates a learner with `initial_ucb1_c = 1.414`.
    pub fn new() -> Self {
        Self::default()
    }
}
227
228impl ComponentLearner for ExplorationLearner {
229    type Output = LearnedExploration;
230
231    fn name(&self) -> &str {
232        "exploration_learner"
233    }
234
235    fn objective(&self) -> &str {
236        "Optimize exploration parameters from session statistics"
237    }
238
239    fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError> {
240        if episodes.is_empty() {
241            return Err(LearnError::InsufficientData(
242                "No episodes to learn from".into(),
243            ));
244        }
245
246        // 成功/失敗率を計算
247        let total = episodes.len();
248        let success = episodes.iter().filter(|e| e.outcome.is_success()).count();
249        let success_rate = success as f64 / total as f64;
250
251        // 成功率に基づいて UCB1 係数を調整
252        // 低成功率 → より探索的に(ucb1_c 大きく)
253        // 高成功率 → より搾取的に(ucb1_c 小さく)
254        let ucb1_c = if success_rate < 0.3 {
255            2.0 // 探索重視
256        } else if success_rate < 0.7 {
257            1.414 // バランス
258        } else {
259            1.0 // 搾取重視
260        };
261
262        // 信頼度(サンプル数に基づく)
263        let confidence = (total as f64 / 10.0).min(1.0);
264
265        Ok(LearnedExploration {
266            ucb1_c,
267            learning_weight: 0.3,
268            ngram_weight: 1.0,
269            confidence,
270            session_count: total,
271            updated_at: std::time::SystemTime::now()
272                .duration_since(std::time::UNIX_EPOCH)
273                .map(|d| d.as_secs())
274                .unwrap_or(0),
275        })
276    }
277}
278
279// ============================================================================
280// StrategyLearner - 戦略設定学習
281// ============================================================================
282
/// Strategy-settings learner.
///
/// Derives strategy-selection settings (initial strategy and switching
/// thresholds) from the success rate of episodes. The type carries no state.
#[derive(Clone, Debug, Default)]
pub struct StrategyLearner;

impl StrategyLearner {
    /// Creates a new (stateless) strategy learner.
    pub fn new() -> Self {
        Self::default()
    }
}
295
296impl ComponentLearner for StrategyLearner {
297    type Output = LearnedStrategy;
298
299    fn name(&self) -> &str {
300        "strategy_learner"
301    }
302
303    fn objective(&self) -> &str {
304        "Determine optimal strategy selection settings"
305    }
306
307    fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError> {
308        if episodes.is_empty() {
309            return Err(LearnError::InsufficientData(
310                "No episodes to learn from".into(),
311            ));
312        }
313
314        let total = episodes.len();
315        let success = episodes.iter().filter(|e| e.outcome.is_success()).count();
316        let success_rate = success as f64 / total as f64;
317
318        // 成功率に基づいて戦略を選択
319        let initial_strategy = if success_rate < 0.5 {
320            "ucb1".to_string() // 探索重視
321        } else {
322            "greedy".to_string() // 搾取重視
323        };
324
325        // エラー率閾値を調整
326        let error_rate_threshold = if success_rate < 0.3 {
327            0.6 // 緩め(切り替えにくい)
328        } else {
329            0.45 // 標準
330        };
331
332        let confidence = (total as f64 / 10.0).min(1.0);
333
334        Ok(LearnedStrategy {
335            initial_strategy,
336            maturity_threshold: 5,
337            error_rate_threshold,
338            confidence,
339            session_count: total,
340            updated_at: std::time::SystemTime::now()
341                .duration_since(std::time::UNIX_EPOCH)
342                .map(|d| d.as_secs())
343                .unwrap_or(0),
344        })
345    }
346}
347
348// ============================================================================
349// Tests
350// ============================================================================
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::episode::{Episode, EpisodeContext, Outcome};

    /// Builds a successful episode.
    ///
    /// NOTE: `_actions` is currently ignored and the episode context is left
    /// empty, so `DepGraphLearner::learn` sees no `ActionRecord`s from these
    /// episodes. The ordering logic is instead covered directly by the
    /// `extract_order_relations` / `compute_*` tests below.
    fn make_success_episode(_actions: Vec<&str>) -> Episode {
        let context = EpisodeContext::new();
        Episode::builder()
            .learn_model("test")
            .context(context)
            .outcome(Outcome::success(1.0))
            .build()
    }

    /// Builds a failed episode with an empty context.
    fn make_failure_episode() -> Episode {
        Episode::builder()
            .learn_model("test")
            .context(EpisodeContext::new())
            .outcome(Outcome::failure("test failure"))
            .build()
    }

    /// Converts string literals into an owned action sequence.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_dep_graph_learner_empty() {
        let learner = DepGraphLearner::new();
        let result = learner.learn(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_dep_graph_learner_no_success() {
        let learner = DepGraphLearner::new();
        let episodes = vec![make_failure_episode(), make_failure_episode()];
        let result = learner.learn(&episodes);
        assert!(result.is_err());
    }

    #[test]
    fn test_dep_graph_learner_with_success() {
        let learner = DepGraphLearner::new();
        let episodes = vec![
            make_success_episode(vec!["A", "B", "C"]),
            make_success_episode(vec!["A", "B", "C"]),
            make_success_episode(vec!["A", "B", "C"]),
        ];
        let result = learner.learn(&episodes);
        assert!(result.is_ok());

        let learned = result.unwrap();
        assert!(learned.confidence > 0.0);
    }

    #[test]
    fn test_exploration_learner() {
        let learner = ExplorationLearner::new();
        let episodes = vec![
            make_success_episode(vec![]),
            make_success_episode(vec![]),
            make_failure_episode(),
        ];
        let result = learner.learn(&episodes);
        assert!(result.is_ok());

        let learned = result.unwrap();
        assert!(learned.ucb1_c > 0.0);
        assert_eq!(learned.session_count, 3);
    }

    #[test]
    fn test_strategy_learner() {
        let learner = StrategyLearner::new();
        let episodes = vec![make_success_episode(vec![]), make_failure_episode()];
        let result = learner.learn(&episodes);
        assert!(result.is_ok());

        let learned = result.unwrap();
        assert!(!learned.initial_strategy.is_empty());
    }

    #[test]
    fn test_extract_order_relations() {
        let learner = DepGraphLearner::new().with_min_episodes(1);

        let sequences = vec![strings(&["A", "B", "C"]), strings(&["A", "B", "C"])];

        let relations = learner.extract_order_relations(&sequences);

        // A→B, A→C and B→C each occur twice.
        assert_eq!(relations.get(&("A".to_string(), "B".to_string())), Some(&2));
        assert_eq!(relations.get(&("A".to_string(), "C".to_string())), Some(&2));
        assert_eq!(relations.get(&("B".to_string(), "C".to_string())), Some(&2));
    }

    #[test]
    fn test_compute_action_order() {
        let learner = DepGraphLearner::new();
        let relations = learner.extract_order_relations(&[strings(&["A", "B", "C"])]);

        // A precedes two actions, B precedes one and follows one, C follows two.
        assert_eq!(
            learner.compute_action_order(&relations),
            strings(&["A", "B", "C"])
        );
    }

    #[test]
    fn test_compute_recommended_paths() {
        let learner = DepGraphLearner::new();
        let mut success_count: HashMap<Vec<String>, usize> = HashMap::new();
        success_count.insert(strings(&["A", "B"]), 3);
        success_count.insert(strings(&["A"]), 1);

        let paths = learner.compute_recommended_paths(&success_count, 4);

        assert_eq!(paths.len(), 2);
        assert_eq!(paths[0].actions, strings(&["A", "B"]));
        assert_eq!(paths[0].observations, 3);
        assert!((paths[0].success_rate - 0.75).abs() < 1e-12);
        assert_eq!(paths[1].actions, strings(&["A"]));
        assert!((paths[1].success_rate - 0.25).abs() < 1e-12);
    }

    #[test]
    fn test_compute_recommended_paths_keeps_at_most_ten() {
        let learner = DepGraphLearner::new();
        let success_count: HashMap<Vec<String>, usize> = (0..12)
            .map(|i| (vec![format!("P{i}")], i + 1))
            .collect();
        let total: usize = success_count.values().sum();

        let paths = learner.compute_recommended_paths(&success_count, total);

        assert_eq!(paths.len(), 10);
        // The most frequent path (12 observations) ranks first.
        assert_eq!(paths[0].actions, vec!["P11".to_string()]);
        assert_eq!(paths[0].observations, 12);
    }
}