Skip to main content

swarm_engine_core/learn/
offline.rs

1//! Offline Learning - セッション間学習の分析・最適化
2//!
3//! 複数セッションの統計データを分析し、最適なパラメータや方針を導出する。
4//!
5//! # アーキテクチャ
6//!
7//! ```text
8//! LearningStore (sessions/*.json)
9//!      ↓
10//! OfflineAnalyzer
11//!  ├── analyze_parameters() → OptimalParameters
12//!  ├── extract_paths() → RecommendedPaths
13//!  └── evaluate_strategies() → StrategyConfig
14//!      ↓
15//! OfflineModel (保存)
16//!      ↓
17//! 次回セッション開始時に読み込み → Orchestrator/Provider に反映
18//! ```
19//!
20//! # 使用例
21//!
22//! ```ignore
23//! use swarm_engine_core::learn::{LearningStore, OfflineAnalyzer, OfflineModel};
24//!
25//! // 履歴データを分析
26//! let store = LearningStore::new("./learning")?;
27//! let snapshots = store.query_latest("my-scenario", 10)?;
28//! let analyzer = OfflineAnalyzer::new(&snapshots);
29//!
30//! // 最適パラメータを算出
31//! let model = analyzer.analyze();
32//!
33//! // 保存
34//! store.save_offline_model("my-scenario", &model)?;
35//!
36//! // 次回セッションで読み込み
37//! let model = store.load_offline_model("my-scenario")?;
38//! builder.with_offline_model(model)
39//! ```
40
41use std::collections::HashMap;
42
43use serde::{Deserialize, Serialize};
44
45use super::snapshot::LearningSnapshot;
46
47/// Offline 学習モデル
48///
49/// 複数セッションの分析結果を保持し、次回セッションに適用する。
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct OfflineModel {
52    /// モデルバージョン
53    pub version: u32,
54    /// 最適化されたパラメータ
55    pub parameters: OptimalParameters,
56    /// 推奨アクションパス(成功率順)
57    pub recommended_paths: Vec<RecommendedPath>,
58    /// Selection 戦略設定
59    pub strategy_config: StrategyConfig,
60    /// 分析に使用したセッション数
61    pub analyzed_sessions: usize,
62    /// 最終更新タイムスタンプ
63    pub updated_at: u64,
64    /// 学習済みアクション順序(DependencyGraph キャッシュ用)
65    #[serde(default)]
66    pub action_order: Option<LearnedActionOrder>,
67}
68
69/// 学習済みアクション順序
70///
71/// DependencyGraph を LLM なしで即座に構築するためのキャッシュ。
72/// 同じアクション集合であれば、LLM を呼ばずにグラフを生成できる。
73///
74/// # フィールド
75///
76/// - `discover` / `not_discover`: アクション実行順序
77/// - `action_set_hash`: アクション集合の識別子(マッチング用)
78/// - `lora`: このエントリに関連付けられた LoRA 設定
79/// - `validated_accuracy`: 検証済み精度(オプション)
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct LearnedActionOrder {
82    /// Discover(NodeExpand)アクションの順序
83    pub discover: Vec<String>,
84    /// NotDiscover(NodeStateChange)アクションの順序
85    pub not_discover: Vec<String>,
86    /// アクション集合のハッシュ(キャッシュヒット判定用)
87    ///
88    /// ハッシュが一致すれば同じアクション集合とみなす。
89    pub action_set_hash: u64,
90    /// 生成元の情報(デバッグ用)
91    #[serde(default)]
92    pub source: ActionOrderSource,
93    /// 対応する LoRA 設定(オプション)
94    #[serde(default)]
95    pub lora: Option<crate::types::LoraConfig>,
96    /// 検証済み精度(オプション)
97    #[serde(default)]
98    pub validated_accuracy: Option<f64>,
99}
100
101/// アクション順序の生成元
102#[derive(Debug, Clone, Default, Serialize, Deserialize)]
103pub enum ActionOrderSource {
104    /// LLM により生成
105    #[default]
106    Llm,
107    /// 静的パターン
108    Static,
109    /// ユーザー定義
110    Manual,
111}
112
113impl LearnedActionOrder {
114    /// 新しい LearnedActionOrder を作成
115    pub fn new(discover: Vec<String>, not_discover: Vec<String>, actions: &[String]) -> Self {
116        Self {
117            discover,
118            not_discover,
119            action_set_hash: Self::compute_hash(actions),
120            source: ActionOrderSource::Llm,
121            lora: None,
122            validated_accuracy: None,
123        }
124    }
125
126    /// LoRA 設定を追加
127    pub fn with_lora(mut self, lora: crate::types::LoraConfig) -> Self {
128        self.lora = Some(lora);
129        self
130    }
131
132    /// 検証済み精度を設定
133    pub fn with_accuracy(mut self, accuracy: f64) -> Self {
134        self.validated_accuracy = Some(accuracy);
135        self
136    }
137
138    /// 生成元を設定
139    pub fn with_source(mut self, source: ActionOrderSource) -> Self {
140        self.source = source;
141        self
142    }
143
144    /// アクション集合のハッシュを計算
145    ///
146    /// アクション名をソートしてハッシュすることで、順序に依存しないハッシュを生成。
147    pub fn compute_hash(actions: &[String]) -> u64 {
148        use std::collections::hash_map::DefaultHasher;
149        use std::hash::{Hash, Hasher};
150
151        let mut sorted: Vec<&str> = actions.iter().map(|s| s.as_str()).collect();
152        sorted.sort();
153
154        let mut hasher = DefaultHasher::new();
155        for action in sorted {
156            action.hash(&mut hasher);
157        }
158        hasher.finish()
159    }
160
161    /// アクション集合が完全一致するか判定(ハッシュで判定)
162    ///
163    /// Note: `matches_actions` は `is_exact_match` のエイリアス(後方互換性)
164    pub fn is_exact_match(&self, actions: &[String]) -> bool {
165        self.action_set_hash == Self::compute_hash(actions)
166    }
167
168    /// `is_exact_match` のエイリアス(後方互換性のため維持)
169    #[inline]
170    pub fn matches_actions(&self, actions: &[String]) -> bool {
171        self.is_exact_match(actions)
172    }
173
174    /// 一致率を計算(Jaccard 係数)
175    pub fn match_rate(&self, actions: &[String]) -> f64 {
176        use std::collections::HashSet;
177
178        let mut self_actions: Vec<String> = self.discover.clone();
179        self_actions.extend(self.not_discover.clone());
180
181        if self_actions.is_empty() && actions.is_empty() {
182            return 1.0;
183        }
184        if self_actions.is_empty() || actions.is_empty() {
185            return 0.0;
186        }
187
188        let self_set: HashSet<_> = self_actions.iter().collect();
189        let other_set: HashSet<_> = actions.iter().collect();
190
191        let intersection = self_set.intersection(&other_set).count();
192        let union = self_set.union(&other_set).count();
193
194        intersection as f64 / union as f64
195    }
196}
197
198/// 最適化されたパラメータ
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct OptimalParameters {
201    /// UCB1 の探索係数
202    pub ucb1_c: f64,
203    /// 学習ボーナス係数
204    pub learning_weight: f64,
205    /// N-gram ボーナス係数(trigram の重み)
206    pub ngram_weight: f64,
207}
208
209impl Default for OptimalParameters {
210    fn default() -> Self {
211        Self {
212            ucb1_c: std::f64::consts::SQRT_2,
213            learning_weight: 0.3,
214            ngram_weight: 1.0,
215        }
216    }
217}
218
219/// 推奨アクションパス
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct RecommendedPath {
222    /// アクションシーケンス
223    pub actions: Vec<String>,
224    /// 成功率
225    pub success_rate: f64,
226    /// 観測回数
227    pub observations: u32,
228}
229
230/// Selection 戦略設定
231#[derive(Debug, Clone, Serialize, Deserialize)]
232pub struct StrategyConfig {
233    /// 成熟判定の閾値(これ以上の訪問で成熟)
234    pub maturity_threshold: u32,
235    /// エラー率の閾値(これ以上なら Thompson)
236    pub error_rate_threshold: f64,
237    /// 推奨初期戦略
238    pub initial_strategy: String,
239}
240
241impl Default for StrategyConfig {
242    fn default() -> Self {
243        Self {
244            maturity_threshold: 10,
245            error_rate_threshold: 0.3,
246            initial_strategy: "ucb1".to_string(),
247        }
248    }
249}
250
251impl Default for OfflineModel {
252    fn default() -> Self {
253        Self {
254            version: 1,
255            parameters: OptimalParameters::default(),
256            recommended_paths: Vec::new(),
257            strategy_config: StrategyConfig::default(),
258            analyzed_sessions: 0,
259            updated_at: 0,
260            action_order: None,
261        }
262    }
263}
264
265/// Offline 分析器
266///
267/// 複数の LearningSnapshot を分析し、最適なパラメータや方針を導出する。
268pub struct OfflineAnalyzer<'a> {
269    snapshots: &'a [LearningSnapshot],
270}
271
272impl<'a> OfflineAnalyzer<'a> {
273    /// 新しい分析器を作成
274    pub fn new(snapshots: &'a [LearningSnapshot]) -> Self {
275        Self { snapshots }
276    }
277
278    /// 全ての分析を実行して OfflineModel を生成
279    pub fn analyze(&self) -> OfflineModel {
280        let now = std::time::SystemTime::now()
281            .duration_since(std::time::UNIX_EPOCH)
282            .map(|d| d.as_secs())
283            .unwrap_or(0);
284
285        OfflineModel {
286            version: 1,
287            parameters: self.analyze_parameters(),
288            recommended_paths: self.extract_paths(),
289            strategy_config: self.analyze_strategy(),
290            analyzed_sessions: self.snapshots.len(),
291            updated_at: now,
292            action_order: None, // 別途設定される
293        }
294    }
295
296    /// パラメータ最適化
297    ///
298    /// 履歴データから最適な UCB1 c, learning_weight 等を算出。
299    /// 現在は統計ベースのヒューリスティックを使用。
300    pub fn analyze_parameters(&self) -> OptimalParameters {
301        if self.snapshots.is_empty() {
302            return OptimalParameters::default();
303        }
304
305        // 成功率を計算
306        let (total_success, total_failure) = self.snapshots.iter().fold((0u32, 0u32), |acc, s| {
307            (
308                acc.0 + s.episode_transitions.success_episodes,
309                acc.1 + s.episode_transitions.failure_episodes,
310            )
311        });
312
313        let success_rate = if total_success + total_failure > 0 {
314            total_success as f64 / (total_success + total_failure) as f64
315        } else {
316            0.5
317        };
318
319        // 成功率に基づいて UCB1 c を調整
320        // - 成功率が高い(>0.8): 活用重視 → c を下げる
321        // - 成功率が低い(<0.5): 探索重視 → c を上げる
322        let ucb1_c = if success_rate > 0.8 {
323            1.0 // 活用重視
324        } else if success_rate < 0.5 {
325            2.0 // 探索重視
326        } else {
327            std::f64::consts::SQRT_2 // バランス
328        };
329
330        // N-gram データの有効性を評価
331        let ngram_effectiveness = self.evaluate_ngram_effectiveness();
332        let ngram_weight = if ngram_effectiveness > 0.7 {
333            1.5 // N-gram が有効なら重みを上げる
334        } else if ngram_effectiveness < 0.3 {
335            0.5 // N-gram が効かないなら重みを下げる
336        } else {
337            1.0
338        };
339
340        OptimalParameters {
341            ucb1_c,
342            learning_weight: 0.3, // 現状は固定
343            ngram_weight,
344        }
345    }
346
347    /// N-gram の有効性を評価
348    ///
349    /// trigram の成功率分散が大きいほど、N-gram が選択に有効。
350    fn evaluate_ngram_effectiveness(&self) -> f64 {
351        let mut all_rates: Vec<f64> = Vec::new();
352
353        for snapshot in self.snapshots {
354            for &(success, failure) in snapshot.ngram_stats.trigrams.values() {
355                let total = success + failure;
356                if total >= 3 {
357                    // 最低3回以上の観測
358                    all_rates.push(success as f64 / total as f64);
359                }
360            }
361        }
362
363        if all_rates.is_empty() {
364            return 0.5; // データ不足
365        }
366
367        // 分散を計算(大きいほど識別力がある)
368        let mean = all_rates.iter().sum::<f64>() / all_rates.len() as f64;
369        let variance =
370            all_rates.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / all_rates.len() as f64;
371
372        // 分散を [0, 1] にスケール(0.25 が最大分散)
373        (variance / 0.25).min(1.0)
374    }
375
376    /// 推奨パスを抽出
377    ///
378    /// 成功エピソードで頻出するアクションシーケンスを抽出。
379    pub fn extract_paths(&self) -> Vec<RecommendedPath> {
380        // trigram から成功率の高いパスを抽出
381        let mut path_stats: HashMap<Vec<String>, (u32, u32)> = HashMap::new();
382
383        for snapshot in self.snapshots {
384            for (key, &(success, failure)) in &snapshot.ngram_stats.trigrams {
385                let path = vec![key.0.clone(), key.1.clone(), key.2.clone()];
386                let entry = path_stats.entry(path).or_insert((0, 0));
387                entry.0 += success;
388                entry.1 += failure;
389            }
390        }
391
392        // 成功率でソートして上位を返す
393        let mut paths: Vec<RecommendedPath> = path_stats
394            .into_iter()
395            .filter(|(_, (s, f))| s + f >= 5) // 最低5回以上の観測
396            .map(|(actions, (success, failure))| {
397                let total = success + failure;
398                RecommendedPath {
399                    actions,
400                    success_rate: success as f64 / total as f64,
401                    observations: total,
402                }
403            })
404            .collect();
405
406        paths.sort_by(|a, b| {
407            b.success_rate
408                .partial_cmp(&a.success_rate)
409                .unwrap_or(std::cmp::Ordering::Equal)
410        });
411
412        paths.into_iter().take(10).collect() // 上位10パス
413    }
414
415    /// 戦略設定を分析
416    ///
417    /// 履歴データから最適な AdaptiveProvider 設定を算出。
418    pub fn analyze_strategy(&self) -> StrategyConfig {
419        if self.snapshots.is_empty() {
420            return StrategyConfig::default();
421        }
422
423        // エラー率の平均を計算
424        let (total_success, total_failure) = self.snapshots.iter().fold((0u32, 0u32), |acc, s| {
425            (
426                acc.0 + s.episode_transitions.success_episodes,
427                acc.1 + s.episode_transitions.failure_episodes,
428            )
429        });
430
431        let avg_error_rate = if total_success + total_failure > 0 {
432            total_failure as f64 / (total_success + total_failure) as f64
433        } else {
434            0.3
435        };
436
437        // 総アクション数から成熟閾値を推定
438        let total_actions: u64 = self
439            .snapshots
440            .iter()
441            .map(|s| s.metadata.total_actions as u64)
442            .sum();
443        let avg_actions = total_actions as f64 / self.snapshots.len().max(1) as f64;
444
445        // 平均アクション数の 10% を成熟閾値に
446        let maturity_threshold = ((avg_actions * 0.1) as u32).clamp(5, 50);
447
448        // 初期戦略の決定
449        let initial_strategy = if avg_error_rate > 0.4 {
450            "thompson" // エラー率高 → 探索重視
451        } else if avg_error_rate < 0.1 {
452            "greedy" // エラー率低 → 活用重視
453        } else {
454            "ucb1" // バランス
455        };
456
457        StrategyConfig {
458            maturity_threshold,
459            error_rate_threshold: (avg_error_rate * 1.5).min(0.5), // 平均の1.5倍を閾値に
460            initial_strategy: initial_strategy.to_string(),
461        }
462    }
463}
464
465#[cfg(test)]
466mod tests {
467    use super::*;
468
469    fn create_test_snapshot(success: u32, failure: u32) -> LearningSnapshot {
470        let mut snapshot = LearningSnapshot::empty();
471        snapshot.episode_transitions.success_episodes = success;
472        snapshot.episode_transitions.failure_episodes = failure;
473        snapshot.metadata.total_actions = (success + failure) * 5;
474        snapshot
475    }
476
477    #[test]
478    fn test_analyzer_empty_snapshots() {
479        let snapshots: Vec<LearningSnapshot> = vec![];
480        let analyzer = OfflineAnalyzer::new(&snapshots);
481        let model = analyzer.analyze();
482
483        assert_eq!(model.analyzed_sessions, 0);
484        assert!((model.parameters.ucb1_c - std::f64::consts::SQRT_2).abs() < 0.01);
485    }
486
487    #[test]
488    fn test_analyzer_high_success_rate() {
489        let snapshots = vec![
490            create_test_snapshot(9, 1),
491            create_test_snapshot(8, 2),
492            create_test_snapshot(10, 0),
493        ];
494        let analyzer = OfflineAnalyzer::new(&snapshots);
495        let params = analyzer.analyze_parameters();
496
497        // 成功率が高い → ucb1_c は低め(活用重視)
498        assert!(params.ucb1_c < std::f64::consts::SQRT_2);
499    }
500
501    #[test]
502    fn test_analyzer_low_success_rate() {
503        let snapshots = vec![
504            create_test_snapshot(3, 7),
505            create_test_snapshot(4, 6),
506            create_test_snapshot(2, 8),
507        ];
508        let analyzer = OfflineAnalyzer::new(&snapshots);
509        let params = analyzer.analyze_parameters();
510
511        // 成功率が低い → ucb1_c は高め(探索重視)
512        assert!(params.ucb1_c > std::f64::consts::SQRT_2);
513    }
514
515    #[test]
516    fn test_strategy_config_high_error() {
517        let snapshots = vec![create_test_snapshot(3, 7), create_test_snapshot(4, 6)];
518        let analyzer = OfflineAnalyzer::new(&snapshots);
519        let config = analyzer.analyze_strategy();
520
521        assert_eq!(config.initial_strategy, "thompson");
522    }
523
524    #[test]
525    fn test_strategy_config_low_error() {
526        let snapshots = vec![create_test_snapshot(19, 1), create_test_snapshot(18, 2)];
527        let analyzer = OfflineAnalyzer::new(&snapshots);
528        let config = analyzer.analyze_strategy();
529
530        assert_eq!(config.initial_strategy, "greedy");
531    }
532}