Skip to main content

swarm_engine_core/learn/
offline.rs

1//! Offline Learning - セッション間学習の分析・最適化
2//!
3//! 複数セッションの統計データを分析し、最適なパラメータや方針を導出する。
4//!
5//! # アーキテクチャ
6//!
7//! ```text
8//! LearningStore (sessions/*.json)
9//!      ↓
10//! OfflineAnalyzer
11//!  ├── analyze_parameters() → OptimalParameters
12//!  ├── extract_paths() → RecommendedPaths
13//!  └── evaluate_strategies() → StrategyConfig
14//!      ↓
15//! OfflineModel (保存)
16//!      ↓
17//! 次回セッション開始時に読み込み → Orchestrator/Provider に反映
18//! ```
19//!
20//! # 使用例
21//!
22//! ```ignore
23//! use swarm_engine_core::learn::{LearningStore, OfflineAnalyzer, OfflineModel};
24//!
25//! // 履歴データを分析
26//! let store = LearningStore::new("./learning")?;
27//! let snapshots = store.query_latest("my-scenario", 10)?;
28//! let analyzer = OfflineAnalyzer::new(&snapshots);
29//!
30//! // 最適パラメータを算出
31//! let model = analyzer.analyze();
32//!
33//! // 保存
34//! store.save_offline_model("my-scenario", &model)?;
35//!
36//! // 次回セッションで読み込み
37//! let model = store.load_offline_model("my-scenario")?;
38//! builder.with_offline_model(model)
39//! ```
40
41use std::collections::HashMap;
42
43use serde::{Deserialize, Serialize};
44
45use super::snapshot::LearningSnapshot;
46
47/// Offline 学習モデル
48///
49/// 複数セッションの分析結果を保持し、次回セッションに適用する。
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct OfflineModel {
52    /// モデルバージョン
53    pub version: u32,
54    /// 最適化されたパラメータ
55    pub parameters: OptimalParameters,
56    /// 推奨アクションパス(成功率順)
57    pub recommended_paths: Vec<RecommendedPath>,
58    /// Selection 戦略設定
59    pub strategy_config: StrategyConfig,
60    /// 分析に使用したセッション数
61    pub analyzed_sessions: usize,
62    /// 最終更新タイムスタンプ
63    pub updated_at: u64,
64    /// 学習済みアクション順序(DependencyGraph キャッシュ用)
65    #[serde(default)]
66    pub action_order: Option<LearnedActionOrder>,
67}
68
69/// 学習済みアクション順序
70///
71/// DependencyGraph を LLM なしで即座に構築するためのキャッシュ。
72/// 同じアクション集合であれば、LLM を呼ばずにグラフを生成できる。
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct LearnedActionOrder {
75    /// Discover(NodeExpand)アクションの順序
76    pub discover: Vec<String>,
77    /// NotDiscover(NodeStateChange)アクションの順序
78    pub not_discover: Vec<String>,
79    /// アクション集合のハッシュ(キャッシュヒット判定用)
80    ///
81    /// ハッシュが一致すれば同じアクション集合とみなす。
82    pub action_set_hash: u64,
83    /// 生成元の情報(デバッグ用)
84    #[serde(default)]
85    pub source: ActionOrderSource,
86}
87
88/// アクション順序の生成元
89#[derive(Debug, Clone, Default, Serialize, Deserialize)]
90pub enum ActionOrderSource {
91    /// LLM により生成
92    #[default]
93    Llm,
94    /// 静的パターン
95    Static,
96    /// ユーザー定義
97    Manual,
98}
99
100impl LearnedActionOrder {
101    /// 新しい LearnedActionOrder を作成
102    pub fn new(discover: Vec<String>, not_discover: Vec<String>, actions: &[String]) -> Self {
103        Self {
104            discover,
105            not_discover,
106            action_set_hash: Self::compute_hash(actions),
107            source: ActionOrderSource::Llm,
108        }
109    }
110
111    /// アクション集合のハッシュを計算
112    ///
113    /// アクション名をソートしてハッシュすることで、順序に依存しないハッシュを生成。
114    pub fn compute_hash(actions: &[String]) -> u64 {
115        use std::collections::hash_map::DefaultHasher;
116        use std::hash::{Hash, Hasher};
117
118        let mut sorted: Vec<&str> = actions.iter().map(|s| s.as_str()).collect();
119        sorted.sort();
120
121        let mut hasher = DefaultHasher::new();
122        for action in sorted {
123            action.hash(&mut hasher);
124        }
125        hasher.finish()
126    }
127
128    /// アクション集合が一致するか判定
129    pub fn matches_actions(&self, actions: &[String]) -> bool {
130        self.action_set_hash == Self::compute_hash(actions)
131    }
132}
133
134/// 最適化されたパラメータ
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct OptimalParameters {
137    /// UCB1 の探索係数
138    pub ucb1_c: f64,
139    /// 学習ボーナス係数
140    pub learning_weight: f64,
141    /// N-gram ボーナス係数(trigram の重み)
142    pub ngram_weight: f64,
143}
144
145impl Default for OptimalParameters {
146    fn default() -> Self {
147        Self {
148            ucb1_c: std::f64::consts::SQRT_2,
149            learning_weight: 0.3,
150            ngram_weight: 1.0,
151        }
152    }
153}
154
155/// 推奨アクションパス
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct RecommendedPath {
158    /// アクションシーケンス
159    pub actions: Vec<String>,
160    /// 成功率
161    pub success_rate: f64,
162    /// 観測回数
163    pub observations: u32,
164}
165
166/// Selection 戦略設定
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct StrategyConfig {
169    /// 成熟判定の閾値(これ以上の訪問で成熟)
170    pub maturity_threshold: u32,
171    /// エラー率の閾値(これ以上なら Thompson)
172    pub error_rate_threshold: f64,
173    /// 推奨初期戦略
174    pub initial_strategy: String,
175}
176
177impl Default for StrategyConfig {
178    fn default() -> Self {
179        Self {
180            maturity_threshold: 10,
181            error_rate_threshold: 0.3,
182            initial_strategy: "ucb1".to_string(),
183        }
184    }
185}
186
187impl Default for OfflineModel {
188    fn default() -> Self {
189        Self {
190            version: 1,
191            parameters: OptimalParameters::default(),
192            recommended_paths: Vec::new(),
193            strategy_config: StrategyConfig::default(),
194            analyzed_sessions: 0,
195            updated_at: 0,
196            action_order: None,
197        }
198    }
199}
200
201/// Offline 分析器
202///
203/// 複数の LearningSnapshot を分析し、最適なパラメータや方針を導出する。
204pub struct OfflineAnalyzer<'a> {
205    snapshots: &'a [LearningSnapshot],
206}
207
208impl<'a> OfflineAnalyzer<'a> {
209    /// 新しい分析器を作成
210    pub fn new(snapshots: &'a [LearningSnapshot]) -> Self {
211        Self { snapshots }
212    }
213
214    /// 全ての分析を実行して OfflineModel を生成
215    pub fn analyze(&self) -> OfflineModel {
216        let now = std::time::SystemTime::now()
217            .duration_since(std::time::UNIX_EPOCH)
218            .map(|d| d.as_secs())
219            .unwrap_or(0);
220
221        OfflineModel {
222            version: 1,
223            parameters: self.analyze_parameters(),
224            recommended_paths: self.extract_paths(),
225            strategy_config: self.analyze_strategy(),
226            analyzed_sessions: self.snapshots.len(),
227            updated_at: now,
228            action_order: None, // 別途設定される
229        }
230    }
231
232    /// パラメータ最適化
233    ///
234    /// 履歴データから最適な UCB1 c, learning_weight 等を算出。
235    /// 現在は統計ベースのヒューリスティックを使用。
236    pub fn analyze_parameters(&self) -> OptimalParameters {
237        if self.snapshots.is_empty() {
238            return OptimalParameters::default();
239        }
240
241        // 成功率を計算
242        let (total_success, total_failure) = self.snapshots.iter().fold((0u32, 0u32), |acc, s| {
243            (
244                acc.0 + s.episode_transitions.success_episodes,
245                acc.1 + s.episode_transitions.failure_episodes,
246            )
247        });
248
249        let success_rate = if total_success + total_failure > 0 {
250            total_success as f64 / (total_success + total_failure) as f64
251        } else {
252            0.5
253        };
254
255        // 成功率に基づいて UCB1 c を調整
256        // - 成功率が高い(>0.8): 活用重視 → c を下げる
257        // - 成功率が低い(<0.5): 探索重視 → c を上げる
258        let ucb1_c = if success_rate > 0.8 {
259            1.0 // 活用重視
260        } else if success_rate < 0.5 {
261            2.0 // 探索重視
262        } else {
263            std::f64::consts::SQRT_2 // バランス
264        };
265
266        // N-gram データの有効性を評価
267        let ngram_effectiveness = self.evaluate_ngram_effectiveness();
268        let ngram_weight = if ngram_effectiveness > 0.7 {
269            1.5 // N-gram が有効なら重みを上げる
270        } else if ngram_effectiveness < 0.3 {
271            0.5 // N-gram が効かないなら重みを下げる
272        } else {
273            1.0
274        };
275
276        OptimalParameters {
277            ucb1_c,
278            learning_weight: 0.3, // 現状は固定
279            ngram_weight,
280        }
281    }
282
283    /// N-gram の有効性を評価
284    ///
285    /// trigram の成功率分散が大きいほど、N-gram が選択に有効。
286    fn evaluate_ngram_effectiveness(&self) -> f64 {
287        let mut all_rates: Vec<f64> = Vec::new();
288
289        for snapshot in self.snapshots {
290            for &(success, failure) in snapshot.ngram_stats.trigrams.values() {
291                let total = success + failure;
292                if total >= 3 {
293                    // 最低3回以上の観測
294                    all_rates.push(success as f64 / total as f64);
295                }
296            }
297        }
298
299        if all_rates.is_empty() {
300            return 0.5; // データ不足
301        }
302
303        // 分散を計算(大きいほど識別力がある)
304        let mean = all_rates.iter().sum::<f64>() / all_rates.len() as f64;
305        let variance =
306            all_rates.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / all_rates.len() as f64;
307
308        // 分散を [0, 1] にスケール(0.25 が最大分散)
309        (variance / 0.25).min(1.0)
310    }
311
312    /// 推奨パスを抽出
313    ///
314    /// 成功エピソードで頻出するアクションシーケンスを抽出。
315    pub fn extract_paths(&self) -> Vec<RecommendedPath> {
316        // trigram から成功率の高いパスを抽出
317        let mut path_stats: HashMap<Vec<String>, (u32, u32)> = HashMap::new();
318
319        for snapshot in self.snapshots {
320            for (key, &(success, failure)) in &snapshot.ngram_stats.trigrams {
321                let path = vec![key.0.clone(), key.1.clone(), key.2.clone()];
322                let entry = path_stats.entry(path).or_insert((0, 0));
323                entry.0 += success;
324                entry.1 += failure;
325            }
326        }
327
328        // 成功率でソートして上位を返す
329        let mut paths: Vec<RecommendedPath> = path_stats
330            .into_iter()
331            .filter(|(_, (s, f))| s + f >= 5) // 最低5回以上の観測
332            .map(|(actions, (success, failure))| {
333                let total = success + failure;
334                RecommendedPath {
335                    actions,
336                    success_rate: success as f64 / total as f64,
337                    observations: total,
338                }
339            })
340            .collect();
341
342        paths.sort_by(|a, b| {
343            b.success_rate
344                .partial_cmp(&a.success_rate)
345                .unwrap_or(std::cmp::Ordering::Equal)
346        });
347
348        paths.into_iter().take(10).collect() // 上位10パス
349    }
350
351    /// 戦略設定を分析
352    ///
353    /// 履歴データから最適な AdaptiveProvider 設定を算出。
354    pub fn analyze_strategy(&self) -> StrategyConfig {
355        if self.snapshots.is_empty() {
356            return StrategyConfig::default();
357        }
358
359        // エラー率の平均を計算
360        let (total_success, total_failure) = self.snapshots.iter().fold((0u32, 0u32), |acc, s| {
361            (
362                acc.0 + s.episode_transitions.success_episodes,
363                acc.1 + s.episode_transitions.failure_episodes,
364            )
365        });
366
367        let avg_error_rate = if total_success + total_failure > 0 {
368            total_failure as f64 / (total_success + total_failure) as f64
369        } else {
370            0.3
371        };
372
373        // 総アクション数から成熟閾値を推定
374        let total_actions: u64 = self
375            .snapshots
376            .iter()
377            .map(|s| s.metadata.total_actions as u64)
378            .sum();
379        let avg_actions = total_actions as f64 / self.snapshots.len().max(1) as f64;
380
381        // 平均アクション数の 10% を成熟閾値に
382        let maturity_threshold = ((avg_actions * 0.1) as u32).clamp(5, 50);
383
384        // 初期戦略の決定
385        let initial_strategy = if avg_error_rate > 0.4 {
386            "thompson" // エラー率高 → 探索重視
387        } else if avg_error_rate < 0.1 {
388            "greedy" // エラー率低 → 活用重視
389        } else {
390            "ucb1" // バランス
391        };
392
393        StrategyConfig {
394            maturity_threshold,
395            error_rate_threshold: (avg_error_rate * 1.5).min(0.5), // 平均の1.5倍を閾値に
396            initial_strategy: initial_strategy.to_string(),
397        }
398    }
399}
400
401#[cfg(test)]
402mod tests {
403    use super::*;
404
405    fn create_test_snapshot(success: u32, failure: u32) -> LearningSnapshot {
406        let mut snapshot = LearningSnapshot::empty();
407        snapshot.episode_transitions.success_episodes = success;
408        snapshot.episode_transitions.failure_episodes = failure;
409        snapshot.metadata.total_actions = (success + failure) * 5;
410        snapshot
411    }
412
413    #[test]
414    fn test_analyzer_empty_snapshots() {
415        let snapshots: Vec<LearningSnapshot> = vec![];
416        let analyzer = OfflineAnalyzer::new(&snapshots);
417        let model = analyzer.analyze();
418
419        assert_eq!(model.analyzed_sessions, 0);
420        assert!((model.parameters.ucb1_c - std::f64::consts::SQRT_2).abs() < 0.01);
421    }
422
423    #[test]
424    fn test_analyzer_high_success_rate() {
425        let snapshots = vec![
426            create_test_snapshot(9, 1),
427            create_test_snapshot(8, 2),
428            create_test_snapshot(10, 0),
429        ];
430        let analyzer = OfflineAnalyzer::new(&snapshots);
431        let params = analyzer.analyze_parameters();
432
433        // 成功率が高い → ucb1_c は低め(活用重視)
434        assert!(params.ucb1_c < std::f64::consts::SQRT_2);
435    }
436
437    #[test]
438    fn test_analyzer_low_success_rate() {
439        let snapshots = vec![
440            create_test_snapshot(3, 7),
441            create_test_snapshot(4, 6),
442            create_test_snapshot(2, 8),
443        ];
444        let analyzer = OfflineAnalyzer::new(&snapshots);
445        let params = analyzer.analyze_parameters();
446
447        // 成功率が低い → ucb1_c は高め(探索重視)
448        assert!(params.ucb1_c > std::f64::consts::SQRT_2);
449    }
450
451    #[test]
452    fn test_strategy_config_high_error() {
453        let snapshots = vec![create_test_snapshot(3, 7), create_test_snapshot(4, 6)];
454        let analyzer = OfflineAnalyzer::new(&snapshots);
455        let config = analyzer.analyze_strategy();
456
457        assert_eq!(config.initial_strategy, "thompson");
458    }
459
460    #[test]
461    fn test_strategy_config_low_error() {
462        let snapshots = vec![create_test_snapshot(19, 1), create_test_snapshot(18, 2)];
463        let analyzer = OfflineAnalyzer::new(&snapshots);
464        let config = analyzer.analyze_strategy();
465
466        assert_eq!(config.initial_strategy, "greedy");
467    }
468}