Skip to main content

swarm_engine_core/learn/
stats.rs

//! Learn Stats - statistics used for learning.
//!
//! Learning patterns extracted from ActionEvent.
//! The Engine (Core) keeps only basic statistics; learning patterns are split out into this module.
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10// ============================================================================
11// Type Aliases (clippy::type_complexity)
12// ============================================================================
13
/// `HashMap` whose key is a 3-tuple of `String`s.
type Tuple3Map<V> = HashMap<(String, String, String), V>;
16
/// `HashMap` whose key is a 4-tuple of `String`s.
type Tuple4Map<V> = HashMap<(String, String, String, String), V>;
19
20// ============================================================================
21// LearnStats - 学習用統計の集約
22// ============================================================================
23
/// Aggregated learning statistics.
///
/// Learning-only statistics kept separate from Core's SwarmStats.
/// Learning patterns are extracted by analyzing ActionEvent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LearnStats {
    /// Episode transition statistics (transition patterns in successful/failed episodes).
    pub episode_transitions: EpisodeTransitions,
    /// N-gram statistics (3-gram and 4-gram patterns).
    pub ngram_stats: NgramStats,
    /// Selection strategy effectiveness measurements.
    pub selection_performance: SelectionPerformance,
    /// Context-conditioned statistics (success/failure of prev_action -> action).
    ///
    /// Serialized through `serialize_tuple2_map`, i.e. as a string-keyed map
    /// whose keys have the form `prev_action:action`.
    #[serde(
        serialize_with = "serialize_tuple2_map",
        deserialize_with = "deserialize_tuple2_map"
    )]
    pub contextual_stats: HashMap<(String, String), ContextualActionStats>,
}
43
44impl LearnStats {
45    /// LearningSnapshot から prior をロード
46    pub fn load_prior(&mut self, snapshot: &crate::learn::LearningSnapshot) {
47        self.episode_transitions = snapshot.episode_transitions.clone();
48        self.ngram_stats = snapshot.ngram_stats.clone();
49        self.selection_performance = snapshot.selection_performance.clone();
50        // contextual_stats は ActionStats -> ContextualActionStats の変換が必要
51        for ((prev, action), stats) in &snapshot.contextual_stats {
52            let contextual = ContextualActionStats {
53                visits: stats.visits,
54                successes: stats.successes,
55                failures: stats.failures,
56            };
57            self.contextual_stats
58                .insert((prev.clone(), action.clone()), contextual);
59        }
60    }
61}
62
/// Context-conditioned action statistics.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ContextualActionStats {
    /// Visit count; the denominator used by `success_rate`.
    pub visits: u32,
    /// Success count; the numerator used by `success_rate`.
    pub successes: u32,
    /// Failure count.
    pub failures: u32,
}
70
71impl ContextualActionStats {
72    pub fn success_rate(&self) -> f64 {
73        if self.visits == 0 {
74            0.5
75        } else {
76            self.successes as f64 / self.visits as f64
77        }
78    }
79}
80
81// ============================================================================
82// EpisodeTransitions
83// ============================================================================
84
/// Episode transition statistics.
///
/// Records transition patterns separately for successful and failed episodes.
/// Can be used for Q-learning-style value function approximation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EpisodeTransitions {
    /// Transition counts in successful episodes: (prev_action, action) -> count.
    #[serde(
        serialize_with = "serialize_tuple2_map",
        deserialize_with = "deserialize_tuple2_map"
    )]
    pub success_transitions: HashMap<(String, String), u32>,
    /// Transition counts in failed episodes: (prev_action, action) -> count.
    #[serde(
        serialize_with = "serialize_tuple2_map",
        deserialize_with = "deserialize_tuple2_map"
    )]
    pub failure_transitions: HashMap<(String, String), u32>,
    /// Number of successful episodes.
    pub success_episodes: u32,
    /// Number of failed episodes.
    pub failure_episodes: u32,
}
108
109impl EpisodeTransitions {
110    /// 成功遷移確率を計算
111    pub fn success_transition_rate(&self, from: &str, to: &str) -> f64 {
112        let key = (from.to_string(), to.to_string());
113        let success_count = self.success_transitions.get(&key).copied().unwrap_or(0);
114        let failure_count = self.failure_transitions.get(&key).copied().unwrap_or(0);
115        let total = success_count + failure_count;
116
117        if total == 0 {
118            0.5
119        } else {
120            success_count as f64 / total as f64
121        }
122    }
123
124    /// 遷移の価値スコアを計算 [-1, 1]
125    pub fn transition_value(&self, from: &str, to: &str) -> f64 {
126        let key = (from.to_string(), to.to_string());
127        let success_count = self.success_transitions.get(&key).copied().unwrap_or(0) as f64;
128        let failure_count = self.failure_transitions.get(&key).copied().unwrap_or(0) as f64;
129
130        let total_success = self.success_transitions.values().sum::<u32>() as f64;
131        let total_failure = self.failure_transitions.values().sum::<u32>() as f64;
132
133        let success_rate = if total_success > 0.0 {
134            success_count / total_success
135        } else {
136            0.0
137        };
138        let failure_rate = if total_failure > 0.0 {
139            failure_count / total_failure
140        } else {
141            0.0
142        };
143
144        success_rate - failure_rate
145    }
146
147    /// 特定のアクションからの推奨次アクションを取得
148    pub fn recommended_next_actions(&self, from: &str) -> Vec<(String, f64)> {
149        let mut candidates: Vec<_> = self
150            .success_transitions
151            .iter()
152            .filter(|((f, _), _)| f == from)
153            .map(|((_, to), _)| {
154                let value = self.transition_value(from, to);
155                (to.clone(), value)
156            })
157            .collect();
158
159        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
160        candidates
161    }
162}
163
164// ============================================================================
165// NgramStats
166// ============================================================================
167
/// N-gram statistics (3-gram and 4-gram pattern learning).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct NgramStats {
    /// 3-grams: (a1, a2, a3) -> (success_count, failure_count).
    #[serde(
        serialize_with = "serialize_tuple3_map",
        deserialize_with = "deserialize_tuple3_map"
    )]
    pub trigrams: HashMap<(String, String, String), (u32, u32)>,
    /// 4-grams: (a1, a2, a3, a4) -> (success_count, failure_count).
    #[serde(
        serialize_with = "serialize_tuple4_map",
        deserialize_with = "deserialize_tuple4_map"
    )]
    pub quadgrams: HashMap<(String, String, String, String), (u32, u32)>,
}
184
185impl NgramStats {
186    /// 3-gram の成功率を計算
187    pub fn trigram_success_rate(&self, a1: &str, a2: &str, a3: &str) -> f64 {
188        let key = (a1.to_string(), a2.to_string(), a3.to_string());
189        match self.trigrams.get(&key) {
190            Some(&(success, failure)) => {
191                let total = success + failure;
192                if total == 0 {
193                    0.5
194                } else {
195                    success as f64 / total as f64
196                }
197            }
198            None => 0.5,
199        }
200    }
201
202    /// 4-gram の成功率を計算
203    pub fn quadgram_success_rate(&self, a1: &str, a2: &str, a3: &str, a4: &str) -> f64 {
204        let key = (
205            a1.to_string(),
206            a2.to_string(),
207            a3.to_string(),
208            a4.to_string(),
209        );
210        match self.quadgrams.get(&key) {
211            Some(&(success, failure)) => {
212                let total = success + failure;
213                if total == 0 {
214                    0.5
215                } else {
216                    success as f64 / total as f64
217                }
218            }
219            None => 0.5,
220        }
221    }
222
223    /// 3-gram の価値スコア [-1, 1]
224    pub fn trigram_value(&self, a1: &str, a2: &str, a3: &str) -> f64 {
225        let key = (a1.to_string(), a2.to_string(), a3.to_string());
226        match self.trigrams.get(&key) {
227            Some(&(success, failure)) => {
228                let total = success + failure;
229                if total == 0 {
230                    0.0
231                } else {
232                    (success as f64 / total as f64) * 2.0 - 1.0
233                }
234            }
235            None => 0.0,
236        }
237    }
238
239    /// 2つのアクション列の後に推奨されるアクション一覧
240    pub fn recommended_after(&self, a1: &str, a2: &str) -> Vec<(String, f64)> {
241        let mut candidates: Vec<_> = self
242            .trigrams
243            .iter()
244            .filter(|((x1, x2, _), _)| x1 == a1 && x2 == a2)
245            .map(|((_, _, a3), &(success, failure))| {
246                let total = success + failure;
247                let score = if total == 0 {
248                    0.0
249                } else {
250                    (success as f64 / total as f64) * 2.0 - 1.0
251                };
252                (a3.clone(), score)
253            })
254            .collect();
255
256        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
257        candidates
258    }
259
260    /// 3つのアクション列の後に推奨されるアクション一覧
261    pub fn recommended_after_three(&self, a1: &str, a2: &str, a3: &str) -> Vec<(String, f64)> {
262        let mut candidates: Vec<_> = self
263            .quadgrams
264            .iter()
265            .filter(|((x1, x2, x3, _), _)| x1 == a1 && x2 == a2 && x3 == a3)
266            .map(|((_, _, _, a4), &(success, failure))| {
267                let total = success + failure;
268                let score = if total == 0 {
269                    0.0
270                } else {
271                    (success as f64 / total as f64) * 2.0 - 1.0
272                };
273                (a4.clone(), score)
274            })
275            .collect();
276
277        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
278        candidates
279    }
280
281    pub fn trigram_count(&self) -> usize {
282        self.trigrams.len()
283    }
284
285    pub fn quadgram_count(&self) -> usize {
286        self.quadgrams.len()
287    }
288}
289
290// ============================================================================
291// SelectionPerformance
292// ============================================================================
293
/// Selection strategy effectiveness measurements (meta-learning).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SelectionPerformance {
    /// Per-strategy statistics.
    pub strategy_stats: HashMap<String, StrategyStats>,
    /// History of strategy switches.
    pub switch_history: Vec<StrategySwitchEvent>,
    /// Current strategy.
    pub current_strategy: Option<String>,
    /// Visits at the time the current strategy was started.
    pub strategy_start_visits: u32,
    /// Success rate at the time the current strategy was started.
    pub strategy_start_success_rate: f64,
}
308
/// Per-strategy statistics.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StrategyStats {
    /// Number of actions recorded while this strategy was current.
    pub visits: u32,
    /// Number of recorded actions that succeeded.
    pub successes: u32,
    /// Number of recorded actions that failed.
    pub failures: u32,
    /// Number of times the strategy was started via `start_strategy`.
    pub usage_count: u32,
    /// Number of episodes that ended successfully under this strategy.
    pub episodes_success: u32,
    /// Number of episodes that ended in failure under this strategy.
    pub episodes_failure: u32,
}
319
320impl StrategyStats {
321    pub fn success_rate(&self) -> f64 {
322        if self.visits == 0 {
323            0.5
324        } else {
325            self.successes as f64 / self.visits as f64
326        }
327    }
328
329    pub fn episode_success_rate(&self) -> f64 {
330        let total = self.episodes_success + self.episodes_failure;
331        if total == 0 {
332            0.5
333        } else {
334            self.episodes_success as f64 / total as f64
335        }
336    }
337}
338
/// Strategy switch event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrategySwitchEvent {
    /// Strategy that was current before the switch.
    pub from: String,
    /// Strategy that was switched to.
    pub to: String,
    /// `current_visits` value passed to `start_strategy` at the switch.
    pub visits_at_switch: u32,
    /// `current_success_rate` value passed to `start_strategy` at the switch.
    pub success_rate_at_switch: f64,
    /// `success_rate()` of the `from` strategy's statistics at the switch.
    pub from_strategy_success_rate: f64,
}
348
349impl SelectionPerformance {
350    /// 戦略を記録開始
351    pub fn start_strategy(
352        &mut self,
353        strategy: &str,
354        current_visits: u32,
355        current_success_rate: f64,
356    ) {
357        if let Some(ref current) = self.current_strategy {
358            if current != strategy {
359                let from_stats = self
360                    .strategy_stats
361                    .get(current)
362                    .cloned()
363                    .unwrap_or_default();
364                self.switch_history.push(StrategySwitchEvent {
365                    from: current.clone(),
366                    to: strategy.to_string(),
367                    visits_at_switch: current_visits,
368                    success_rate_at_switch: current_success_rate,
369                    from_strategy_success_rate: from_stats.success_rate(),
370                });
371            }
372        }
373
374        self.current_strategy = Some(strategy.to_string());
375        self.strategy_start_visits = current_visits;
376        self.strategy_start_success_rate = current_success_rate;
377
378        self.strategy_stats
379            .entry(strategy.to_string())
380            .or_default()
381            .usage_count += 1;
382    }
383
384    /// アクション結果を記録
385    pub fn record_action(&mut self, success: bool) {
386        if let Some(ref strategy) = self.current_strategy {
387            let stats = self.strategy_stats.entry(strategy.clone()).or_default();
388            stats.visits += 1;
389            if success {
390                stats.successes += 1;
391            } else {
392                stats.failures += 1;
393            }
394        }
395    }
396
397    /// エピソード終了を記録
398    pub fn record_episode_end(&mut self, success: bool) {
399        if let Some(ref strategy) = self.current_strategy {
400            let stats = self.strategy_stats.entry(strategy.clone()).or_default();
401            if success {
402                stats.episodes_success += 1;
403            } else {
404                stats.episodes_failure += 1;
405            }
406        }
407    }
408
409    /// 戦略の効果を取得
410    pub fn strategy_effectiveness(&self, strategy: &str) -> Option<f64> {
411        self.strategy_stats.get(strategy).map(|s| s.success_rate())
412    }
413
414    /// 最も効果的な戦略を取得
415    pub fn best_strategy(&self) -> Option<(&str, f64)> {
416        self.strategy_stats
417            .iter()
418            .filter(|(_, stats)| stats.visits >= 10)
419            .max_by(|(_, a), (_, b)| {
420                a.success_rate()
421                    .partial_cmp(&b.success_rate())
422                    .unwrap_or(std::cmp::Ordering::Equal)
423            })
424            .map(|(name, stats)| (name.as_str(), stats.success_rate()))
425    }
426
427    /// 状況に応じた推奨戦略を取得
428    pub fn recommended_strategy(&self, failure_rate: f64, visits: u32) -> &str {
429        let ucb1_score = self.strategy_score_for_context("UCB1", failure_rate, visits);
430        let greedy_score = self.strategy_score_for_context("Greedy", failure_rate, visits);
431        let thompson_score = self.strategy_score_for_context("Thompson", failure_rate, visits);
432
433        if ucb1_score >= greedy_score && ucb1_score >= thompson_score {
434            "UCB1"
435        } else if greedy_score >= thompson_score {
436            "Greedy"
437        } else {
438            "Thompson"
439        }
440    }
441
442    fn strategy_score_for_context(&self, strategy: &str, failure_rate: f64, _visits: u32) -> f64 {
443        let base_score = self
444            .strategy_stats
445            .get(strategy)
446            .map(|s| s.success_rate())
447            .unwrap_or(0.5);
448
449        match strategy {
450            "UCB1" => base_score + failure_rate * 0.2,
451            "Greedy" => base_score + (1.0 - failure_rate) * 0.2,
452            "Thompson" => {
453                let distance_from_middle = (failure_rate - 0.5).abs();
454                base_score + (0.5 - distance_from_middle) * 0.2
455            }
456            _ => base_score,
457        }
458    }
459}
460
461// ============================================================================
462// Tuple Key HashMap Serialization
463// ============================================================================
464
465fn serialize_tuple2_map<V, S>(
466    map: &HashMap<(String, String), V>,
467    serializer: S,
468) -> Result<S::Ok, S::Error>
469where
470    V: Serialize,
471    S: serde::Serializer,
472{
473    use serde::ser::SerializeMap;
474    let mut ser_map = serializer.serialize_map(Some(map.len()))?;
475    for ((a, b), v) in map {
476        ser_map.serialize_entry(&format!("{}:{}", a, b), v)?;
477    }
478    ser_map.end()
479}
480
481fn deserialize_tuple2_map<'de, V, D>(
482    deserializer: D,
483) -> Result<HashMap<(String, String), V>, D::Error>
484where
485    V: Deserialize<'de>,
486    D: serde::Deserializer<'de>,
487{
488    use serde::de::Error;
489    let string_map: HashMap<String, V> = HashMap::deserialize(deserializer)?;
490    let mut result = HashMap::new();
491    for (k, v) in string_map {
492        let parts: Vec<&str> = k.splitn(2, ':').collect();
493        if parts.len() != 2 {
494            return Err(D::Error::custom(format!("invalid tuple2 key: {}", k)));
495        }
496        result.insert((parts[0].to_string(), parts[1].to_string()), v);
497    }
498    Ok(result)
499}
500
501fn serialize_tuple3_map<V, S>(
502    map: &HashMap<(String, String, String), V>,
503    serializer: S,
504) -> Result<S::Ok, S::Error>
505where
506    V: Serialize,
507    S: serde::Serializer,
508{
509    use serde::ser::SerializeMap;
510    let mut ser_map = serializer.serialize_map(Some(map.len()))?;
511    for ((a, b, c), v) in map {
512        ser_map.serialize_entry(&format!("{}:{}:{}", a, b, c), v)?;
513    }
514    ser_map.end()
515}
516
517fn deserialize_tuple3_map<'de, V, D>(deserializer: D) -> Result<Tuple3Map<V>, D::Error>
518where
519    V: Deserialize<'de>,
520    D: serde::Deserializer<'de>,
521{
522    use serde::de::Error;
523    let string_map: HashMap<String, V> = HashMap::deserialize(deserializer)?;
524    let mut result = HashMap::new();
525    for (k, v) in string_map {
526        let parts: Vec<&str> = k.splitn(3, ':').collect();
527        if parts.len() != 3 {
528            return Err(D::Error::custom(format!("invalid tuple3 key: {}", k)));
529        }
530        result.insert(
531            (
532                parts[0].to_string(),
533                parts[1].to_string(),
534                parts[2].to_string(),
535            ),
536            v,
537        );
538    }
539    Ok(result)
540}
541
542fn serialize_tuple4_map<V, S>(
543    map: &HashMap<(String, String, String, String), V>,
544    serializer: S,
545) -> Result<S::Ok, S::Error>
546where
547    V: Serialize,
548    S: serde::Serializer,
549{
550    use serde::ser::SerializeMap;
551    let mut ser_map = serializer.serialize_map(Some(map.len()))?;
552    for ((a, b, c, d), v) in map {
553        ser_map.serialize_entry(&format!("{}:{}:{}:{}", a, b, c, d), v)?;
554    }
555    ser_map.end()
556}
557
558fn deserialize_tuple4_map<'de, V, D>(deserializer: D) -> Result<Tuple4Map<V>, D::Error>
559where
560    V: Deserialize<'de>,
561    D: serde::Deserializer<'de>,
562{
563    use serde::de::Error;
564    let string_map: HashMap<String, V> = HashMap::deserialize(deserializer)?;
565    let mut result = HashMap::new();
566    for (k, v) in string_map {
567        let parts: Vec<&str> = k.splitn(4, ':').collect();
568        if parts.len() != 4 {
569            return Err(D::Error::custom(format!("invalid tuple4 key: {}", k)));
570        }
571        result.insert(
572            (
573                parts[0].to_string(),
574                parts[1].to_string(),
575                parts[2].to_string(),
576                parts[3].to_string(),
577            ),
578            v,
579        );
580    }
581    Ok(result)
582}