Skip to main content

swarm_engine_core/learn/
offline.rs

1//! Offline Learning - セッション間学習の分析・最適化
2//!
3//! 複数セッションの統計データを分析し、最適なパラメータや方針を導出する。
4//!
5//! # アーキテクチャ
6//!
7//! ```text
8//! LearningStore (sessions/*.json)
9//!      ↓
10//! OfflineAnalyzer
11//!  ├── analyze_parameters() → OptimalParameters
12//!  ├── extract_paths() → RecommendedPaths
13//!  └── evaluate_strategies() → StrategyConfig
14//!      ↓
15//! OfflineModel (保存)
16//!      ↓
17//! 次回セッション開始時に読み込み → Orchestrator/Provider に反映
18//! ```
19//!
20//! # 使用例
21//!
22//! ```ignore
23//! use swarm_engine_core::learn::{LearningStore, OfflineAnalyzer, OfflineModel};
24//!
25//! // 履歴データを分析
26//! let store = LearningStore::new("./learning")?;
27//! let snapshots = store.query_latest("my-scenario", 10)?;
28//! let analyzer = OfflineAnalyzer::new(&snapshots);
29//!
30//! // 最適パラメータを算出
31//! let model = analyzer.analyze();
32//!
33//! // 保存
34//! store.save_offline_model("my-scenario", &model)?;
35//!
36//! // 次回セッションで読み込み
37//! let model = store.load_offline_model("my-scenario")?;
38//! builder.with_offline_model(model)
39//! ```
40
41use std::collections::HashMap;
42
43use serde::{Deserialize, Serialize};
44
45use super::snapshot::LearningSnapshot;
46
47/// Offline 学習モデル
48///
49/// 複数セッションの分析結果を保持し、次回セッションに適用する。
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct OfflineModel {
52    /// モデルバージョン
53    pub version: u32,
54    /// 最適化されたパラメータ
55    pub parameters: OptimalParameters,
56    /// 推奨アクションパス(成功率順)
57    pub recommended_paths: Vec<RecommendedPath>,
58    /// Selection 戦略設定
59    pub strategy_config: StrategyConfig,
60    /// 分析に使用したセッション数
61    pub analyzed_sessions: usize,
62    /// 最終更新タイムスタンプ
63    pub updated_at: u64,
64}
65
66/// 最適化されたパラメータ
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct OptimalParameters {
69    /// UCB1 の探索係数
70    pub ucb1_c: f64,
71    /// 学習ボーナス係数
72    pub learning_weight: f64,
73    /// N-gram ボーナス係数(trigram の重み)
74    pub ngram_weight: f64,
75}
76
77impl Default for OptimalParameters {
78    fn default() -> Self {
79        Self {
80            ucb1_c: std::f64::consts::SQRT_2,
81            learning_weight: 0.3,
82            ngram_weight: 1.0,
83        }
84    }
85}
86
87/// 推奨アクションパス
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct RecommendedPath {
90    /// アクションシーケンス
91    pub actions: Vec<String>,
92    /// 成功率
93    pub success_rate: f64,
94    /// 観測回数
95    pub observations: u32,
96}
97
98/// Selection 戦略設定
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct StrategyConfig {
101    /// 成熟判定の閾値(これ以上の訪問で成熟)
102    pub maturity_threshold: u32,
103    /// エラー率の閾値(これ以上なら Thompson)
104    pub error_rate_threshold: f64,
105    /// 推奨初期戦略
106    pub initial_strategy: String,
107}
108
109impl Default for StrategyConfig {
110    fn default() -> Self {
111        Self {
112            maturity_threshold: 10,
113            error_rate_threshold: 0.3,
114            initial_strategy: "ucb1".to_string(),
115        }
116    }
117}
118
119impl Default for OfflineModel {
120    fn default() -> Self {
121        Self {
122            version: 1,
123            parameters: OptimalParameters::default(),
124            recommended_paths: Vec::new(),
125            strategy_config: StrategyConfig::default(),
126            analyzed_sessions: 0,
127            updated_at: 0,
128        }
129    }
130}
131
132/// Offline 分析器
133///
134/// 複数の LearningSnapshot を分析し、最適なパラメータや方針を導出する。
135pub struct OfflineAnalyzer<'a> {
136    snapshots: &'a [LearningSnapshot],
137}
138
139impl<'a> OfflineAnalyzer<'a> {
140    /// 新しい分析器を作成
141    pub fn new(snapshots: &'a [LearningSnapshot]) -> Self {
142        Self { snapshots }
143    }
144
145    /// 全ての分析を実行して OfflineModel を生成
146    pub fn analyze(&self) -> OfflineModel {
147        let now = std::time::SystemTime::now()
148            .duration_since(std::time::UNIX_EPOCH)
149            .map(|d| d.as_secs())
150            .unwrap_or(0);
151
152        OfflineModel {
153            version: 1,
154            parameters: self.analyze_parameters(),
155            recommended_paths: self.extract_paths(),
156            strategy_config: self.analyze_strategy(),
157            analyzed_sessions: self.snapshots.len(),
158            updated_at: now,
159        }
160    }
161
162    /// パラメータ最適化
163    ///
164    /// 履歴データから最適な UCB1 c, learning_weight 等を算出。
165    /// 現在は統計ベースのヒューリスティックを使用。
166    pub fn analyze_parameters(&self) -> OptimalParameters {
167        if self.snapshots.is_empty() {
168            return OptimalParameters::default();
169        }
170
171        // 成功率を計算
172        let (total_success, total_failure) = self.snapshots.iter().fold((0u32, 0u32), |acc, s| {
173            (
174                acc.0 + s.episode_transitions.success_episodes,
175                acc.1 + s.episode_transitions.failure_episodes,
176            )
177        });
178
179        let success_rate = if total_success + total_failure > 0 {
180            total_success as f64 / (total_success + total_failure) as f64
181        } else {
182            0.5
183        };
184
185        // 成功率に基づいて UCB1 c を調整
186        // - 成功率が高い(>0.8): 活用重視 → c を下げる
187        // - 成功率が低い(<0.5): 探索重視 → c を上げる
188        let ucb1_c = if success_rate > 0.8 {
189            1.0 // 活用重視
190        } else if success_rate < 0.5 {
191            2.0 // 探索重視
192        } else {
193            std::f64::consts::SQRT_2 // バランス
194        };
195
196        // N-gram データの有効性を評価
197        let ngram_effectiveness = self.evaluate_ngram_effectiveness();
198        let ngram_weight = if ngram_effectiveness > 0.7 {
199            1.5 // N-gram が有効なら重みを上げる
200        } else if ngram_effectiveness < 0.3 {
201            0.5 // N-gram が効かないなら重みを下げる
202        } else {
203            1.0
204        };
205
206        OptimalParameters {
207            ucb1_c,
208            learning_weight: 0.3, // 現状は固定
209            ngram_weight,
210        }
211    }
212
213    /// N-gram の有効性を評価
214    ///
215    /// trigram の成功率分散が大きいほど、N-gram が選択に有効。
216    fn evaluate_ngram_effectiveness(&self) -> f64 {
217        let mut all_rates: Vec<f64> = Vec::new();
218
219        for snapshot in self.snapshots {
220            for &(success, failure) in snapshot.ngram_stats.trigrams.values() {
221                let total = success + failure;
222                if total >= 3 {
223                    // 最低3回以上の観測
224                    all_rates.push(success as f64 / total as f64);
225                }
226            }
227        }
228
229        if all_rates.is_empty() {
230            return 0.5; // データ不足
231        }
232
233        // 分散を計算(大きいほど識別力がある)
234        let mean = all_rates.iter().sum::<f64>() / all_rates.len() as f64;
235        let variance =
236            all_rates.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / all_rates.len() as f64;
237
238        // 分散を [0, 1] にスケール(0.25 が最大分散)
239        (variance / 0.25).min(1.0)
240    }
241
242    /// 推奨パスを抽出
243    ///
244    /// 成功エピソードで頻出するアクションシーケンスを抽出。
245    pub fn extract_paths(&self) -> Vec<RecommendedPath> {
246        // trigram から成功率の高いパスを抽出
247        let mut path_stats: HashMap<Vec<String>, (u32, u32)> = HashMap::new();
248
249        for snapshot in self.snapshots {
250            for (key, &(success, failure)) in &snapshot.ngram_stats.trigrams {
251                let path = vec![key.0.clone(), key.1.clone(), key.2.clone()];
252                let entry = path_stats.entry(path).or_insert((0, 0));
253                entry.0 += success;
254                entry.1 += failure;
255            }
256        }
257
258        // 成功率でソートして上位を返す
259        let mut paths: Vec<RecommendedPath> = path_stats
260            .into_iter()
261            .filter(|(_, (s, f))| s + f >= 5) // 最低5回以上の観測
262            .map(|(actions, (success, failure))| {
263                let total = success + failure;
264                RecommendedPath {
265                    actions,
266                    success_rate: success as f64 / total as f64,
267                    observations: total,
268                }
269            })
270            .collect();
271
272        paths.sort_by(|a, b| {
273            b.success_rate
274                .partial_cmp(&a.success_rate)
275                .unwrap_or(std::cmp::Ordering::Equal)
276        });
277
278        paths.into_iter().take(10).collect() // 上位10パス
279    }
280
281    /// 戦略設定を分析
282    ///
283    /// 履歴データから最適な AdaptiveProvider 設定を算出。
284    pub fn analyze_strategy(&self) -> StrategyConfig {
285        if self.snapshots.is_empty() {
286            return StrategyConfig::default();
287        }
288
289        // エラー率の平均を計算
290        let (total_success, total_failure) = self.snapshots.iter().fold((0u32, 0u32), |acc, s| {
291            (
292                acc.0 + s.episode_transitions.success_episodes,
293                acc.1 + s.episode_transitions.failure_episodes,
294            )
295        });
296
297        let avg_error_rate = if total_success + total_failure > 0 {
298            total_failure as f64 / (total_success + total_failure) as f64
299        } else {
300            0.3
301        };
302
303        // 総アクション数から成熟閾値を推定
304        let total_actions: u64 = self
305            .snapshots
306            .iter()
307            .map(|s| s.metadata.total_actions as u64)
308            .sum();
309        let avg_actions = total_actions as f64 / self.snapshots.len().max(1) as f64;
310
311        // 平均アクション数の 10% を成熟閾値に
312        let maturity_threshold = ((avg_actions * 0.1) as u32).clamp(5, 50);
313
314        // 初期戦略の決定
315        let initial_strategy = if avg_error_rate > 0.4 {
316            "thompson" // エラー率高 → 探索重視
317        } else if avg_error_rate < 0.1 {
318            "greedy" // エラー率低 → 活用重視
319        } else {
320            "ucb1" // バランス
321        };
322
323        StrategyConfig {
324            maturity_threshold,
325            error_rate_threshold: (avg_error_rate * 1.5).min(0.5), // 平均の1.5倍を閾値に
326            initial_strategy: initial_strategy.to_string(),
327        }
328    }
329}
330
331#[cfg(test)]
332mod tests {
333    use super::*;
334
335    fn create_test_snapshot(success: u32, failure: u32) -> LearningSnapshot {
336        let mut snapshot = LearningSnapshot::empty();
337        snapshot.episode_transitions.success_episodes = success;
338        snapshot.episode_transitions.failure_episodes = failure;
339        snapshot.metadata.total_actions = (success + failure) * 5;
340        snapshot
341    }
342
343    #[test]
344    fn test_analyzer_empty_snapshots() {
345        let snapshots: Vec<LearningSnapshot> = vec![];
346        let analyzer = OfflineAnalyzer::new(&snapshots);
347        let model = analyzer.analyze();
348
349        assert_eq!(model.analyzed_sessions, 0);
350        assert!((model.parameters.ucb1_c - std::f64::consts::SQRT_2).abs() < 0.01);
351    }
352
353    #[test]
354    fn test_analyzer_high_success_rate() {
355        let snapshots = vec![
356            create_test_snapshot(9, 1),
357            create_test_snapshot(8, 2),
358            create_test_snapshot(10, 0),
359        ];
360        let analyzer = OfflineAnalyzer::new(&snapshots);
361        let params = analyzer.analyze_parameters();
362
363        // 成功率が高い → ucb1_c は低め(活用重視)
364        assert!(params.ucb1_c < std::f64::consts::SQRT_2);
365    }
366
367    #[test]
368    fn test_analyzer_low_success_rate() {
369        let snapshots = vec![
370            create_test_snapshot(3, 7),
371            create_test_snapshot(4, 6),
372            create_test_snapshot(2, 8),
373        ];
374        let analyzer = OfflineAnalyzer::new(&snapshots);
375        let params = analyzer.analyze_parameters();
376
377        // 成功率が低い → ucb1_c は高め(探索重視)
378        assert!(params.ucb1_c > std::f64::consts::SQRT_2);
379    }
380
381    #[test]
382    fn test_strategy_config_high_error() {
383        let snapshots = vec![create_test_snapshot(3, 7), create_test_snapshot(4, 6)];
384        let analyzer = OfflineAnalyzer::new(&snapshots);
385        let config = analyzer.analyze_strategy();
386
387        assert_eq!(config.initial_strategy, "thompson");
388    }
389
390    #[test]
391    fn test_strategy_config_low_error() {
392        let snapshots = vec![create_test_snapshot(19, 1), create_test_snapshot(18, 2)];
393        let analyzer = OfflineAnalyzer::new(&snapshots);
394        let config = analyzer.analyze_strategy();
395
396        assert_eq!(config.initial_strategy, "greedy");
397    }
398}