Skip to main content

swarm_engine_core/learn/
stats.rs

1//! Learn Stats - 学習用統計
2//!
3//! ActionEvent から抽出された学習パターン。
4//! Engine (Core) は基本統計のみを持ち、学習パターンはこのモジュールに分離。
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10// ============================================================================
11// LearnStats - 学習用統計の集約
12// ============================================================================
13
14/// 学習用統計の集約
15///
16/// Core の SwarmStats から分離された学習専用統計。
17/// ActionEvent を分析して学習パターンを抽出する。
18#[derive(Debug, Clone, Default, Serialize, Deserialize)]
19pub struct LearnStats {
20    /// エピソード遷移統計(成功/失敗エピソードでの遷移パターン)
21    pub episode_transitions: EpisodeTransitions,
22    /// N-gram 統計(3-gram, 4-gram パターン)
23    pub ngram_stats: NgramStats,
24    /// Selection 戦略効果測定
25    pub selection_performance: SelectionPerformance,
26    /// コンテキスト条件付き統計(prev_action → action の成功/失敗)
27    #[serde(
28        serialize_with = "serialize_tuple2_map",
29        deserialize_with = "deserialize_tuple2_map"
30    )]
31    pub contextual_stats: HashMap<(String, String), ContextualActionStats>,
32}
33
34impl LearnStats {
35    /// LearningSnapshot から prior をロード
36    pub fn load_prior(&mut self, snapshot: &crate::learn::LearningSnapshot) {
37        self.episode_transitions = snapshot.episode_transitions.clone();
38        self.ngram_stats = snapshot.ngram_stats.clone();
39        self.selection_performance = snapshot.selection_performance.clone();
40        // contextual_stats は ActionStats -> ContextualActionStats の変換が必要
41        for ((prev, action), stats) in &snapshot.contextual_stats {
42            let contextual = ContextualActionStats {
43                visits: stats.visits,
44                successes: stats.successes,
45                failures: stats.failures,
46            };
47            self.contextual_stats
48                .insert((prev.clone(), action.clone()), contextual);
49        }
50    }
51}
52
53/// コンテキスト条件付きアクション統計
54#[derive(Debug, Clone, Default, Serialize, Deserialize)]
55pub struct ContextualActionStats {
56    pub visits: u32,
57    pub successes: u32,
58    pub failures: u32,
59}
60
61impl ContextualActionStats {
62    pub fn success_rate(&self) -> f64 {
63        if self.visits == 0 {
64            0.5
65        } else {
66            self.successes as f64 / self.visits as f64
67        }
68    }
69}
70
71// ============================================================================
72// EpisodeTransitions
73// ============================================================================
74
75/// エピソード遷移統計
76///
77/// 成功エピソードと失敗エピソードでの遷移パターンを分離して記録。
78/// Q-learning風の価値関数近似に活用可能。
79#[derive(Debug, Clone, Default, Serialize, Deserialize)]
80pub struct EpisodeTransitions {
81    /// 成功エピソードでの遷移カウント (prev_action, action) → count
82    #[serde(
83        serialize_with = "serialize_tuple2_map",
84        deserialize_with = "deserialize_tuple2_map"
85    )]
86    pub success_transitions: HashMap<(String, String), u32>,
87    /// 失敗エピソードでの遷移カウント (prev_action, action) → count
88    #[serde(
89        serialize_with = "serialize_tuple2_map",
90        deserialize_with = "deserialize_tuple2_map"
91    )]
92    pub failure_transitions: HashMap<(String, String), u32>,
93    /// 成功エピソード数
94    pub success_episodes: u32,
95    /// 失敗エピソード数
96    pub failure_episodes: u32,
97}
98
99impl EpisodeTransitions {
100    /// 成功遷移確率を計算
101    pub fn success_transition_rate(&self, from: &str, to: &str) -> f64 {
102        let key = (from.to_string(), to.to_string());
103        let success_count = self.success_transitions.get(&key).copied().unwrap_or(0);
104        let failure_count = self.failure_transitions.get(&key).copied().unwrap_or(0);
105        let total = success_count + failure_count;
106
107        if total == 0 {
108            0.5
109        } else {
110            success_count as f64 / total as f64
111        }
112    }
113
114    /// 遷移の価値スコアを計算 [-1, 1]
115    pub fn transition_value(&self, from: &str, to: &str) -> f64 {
116        let key = (from.to_string(), to.to_string());
117        let success_count = self.success_transitions.get(&key).copied().unwrap_or(0) as f64;
118        let failure_count = self.failure_transitions.get(&key).copied().unwrap_or(0) as f64;
119
120        let total_success = self.success_transitions.values().sum::<u32>() as f64;
121        let total_failure = self.failure_transitions.values().sum::<u32>() as f64;
122
123        let success_rate = if total_success > 0.0 {
124            success_count / total_success
125        } else {
126            0.0
127        };
128        let failure_rate = if total_failure > 0.0 {
129            failure_count / total_failure
130        } else {
131            0.0
132        };
133
134        success_rate - failure_rate
135    }
136
137    /// 特定のアクションからの推奨次アクションを取得
138    pub fn recommended_next_actions(&self, from: &str) -> Vec<(String, f64)> {
139        let mut candidates: Vec<_> = self
140            .success_transitions
141            .iter()
142            .filter(|((f, _), _)| f == from)
143            .map(|((_, to), _)| {
144                let value = self.transition_value(from, to);
145                (to.clone(), value)
146            })
147            .collect();
148
149        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
150        candidates
151    }
152}
153
154// ============================================================================
155// NgramStats
156// ============================================================================
157
158/// N-gram 統計(3-gram, 4-gram パターン学習)
159#[derive(Debug, Clone, Default, Serialize, Deserialize)]
160pub struct NgramStats {
161    /// 3-gram: (a1, a2, a3) → (success_count, failure_count)
162    #[serde(
163        serialize_with = "serialize_tuple3_map",
164        deserialize_with = "deserialize_tuple3_map"
165    )]
166    pub trigrams: HashMap<(String, String, String), (u32, u32)>,
167    /// 4-gram: (a1, a2, a3, a4) → (success_count, failure_count)
168    #[serde(
169        serialize_with = "serialize_tuple4_map",
170        deserialize_with = "deserialize_tuple4_map"
171    )]
172    pub quadgrams: HashMap<(String, String, String, String), (u32, u32)>,
173}
174
175impl NgramStats {
176    /// 3-gram の成功率を計算
177    pub fn trigram_success_rate(&self, a1: &str, a2: &str, a3: &str) -> f64 {
178        let key = (a1.to_string(), a2.to_string(), a3.to_string());
179        match self.trigrams.get(&key) {
180            Some(&(success, failure)) => {
181                let total = success + failure;
182                if total == 0 {
183                    0.5
184                } else {
185                    success as f64 / total as f64
186                }
187            }
188            None => 0.5,
189        }
190    }
191
192    /// 4-gram の成功率を計算
193    pub fn quadgram_success_rate(&self, a1: &str, a2: &str, a3: &str, a4: &str) -> f64 {
194        let key = (
195            a1.to_string(),
196            a2.to_string(),
197            a3.to_string(),
198            a4.to_string(),
199        );
200        match self.quadgrams.get(&key) {
201            Some(&(success, failure)) => {
202                let total = success + failure;
203                if total == 0 {
204                    0.5
205                } else {
206                    success as f64 / total as f64
207                }
208            }
209            None => 0.5,
210        }
211    }
212
213    /// 3-gram の価値スコア [-1, 1]
214    pub fn trigram_value(&self, a1: &str, a2: &str, a3: &str) -> f64 {
215        let key = (a1.to_string(), a2.to_string(), a3.to_string());
216        match self.trigrams.get(&key) {
217            Some(&(success, failure)) => {
218                let total = success + failure;
219                if total == 0 {
220                    0.0
221                } else {
222                    (success as f64 / total as f64) * 2.0 - 1.0
223                }
224            }
225            None => 0.0,
226        }
227    }
228
229    /// 2つのアクション列の後に推奨されるアクション一覧
230    pub fn recommended_after(&self, a1: &str, a2: &str) -> Vec<(String, f64)> {
231        let mut candidates: Vec<_> = self
232            .trigrams
233            .iter()
234            .filter(|((x1, x2, _), _)| x1 == a1 && x2 == a2)
235            .map(|((_, _, a3), &(success, failure))| {
236                let total = success + failure;
237                let score = if total == 0 {
238                    0.0
239                } else {
240                    (success as f64 / total as f64) * 2.0 - 1.0
241                };
242                (a3.clone(), score)
243            })
244            .collect();
245
246        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
247        candidates
248    }
249
250    /// 3つのアクション列の後に推奨されるアクション一覧
251    pub fn recommended_after_three(&self, a1: &str, a2: &str, a3: &str) -> Vec<(String, f64)> {
252        let mut candidates: Vec<_> = self
253            .quadgrams
254            .iter()
255            .filter(|((x1, x2, x3, _), _)| x1 == a1 && x2 == a2 && x3 == a3)
256            .map(|((_, _, _, a4), &(success, failure))| {
257                let total = success + failure;
258                let score = if total == 0 {
259                    0.0
260                } else {
261                    (success as f64 / total as f64) * 2.0 - 1.0
262                };
263                (a4.clone(), score)
264            })
265            .collect();
266
267        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
268        candidates
269    }
270
271    pub fn trigram_count(&self) -> usize {
272        self.trigrams.len()
273    }
274
275    pub fn quadgram_count(&self) -> usize {
276        self.quadgrams.len()
277    }
278}
279
280// ============================================================================
281// SelectionPerformance
282// ============================================================================
283
284/// Selection 戦略効果測定(Meta-learning)
285#[derive(Debug, Clone, Default, Serialize, Deserialize)]
286pub struct SelectionPerformance {
287    /// 戦略ごとの統計
288    pub strategy_stats: HashMap<String, StrategyStats>,
289    /// 戦略切り替え履歴
290    pub switch_history: Vec<StrategySwitchEvent>,
291    /// 現在の戦略
292    pub current_strategy: Option<String>,
293    /// 現在の戦略開始時の visits
294    pub strategy_start_visits: u32,
295    /// 現在の戦略開始時の success_rate
296    pub strategy_start_success_rate: f64,
297}
298
299/// 戦略ごとの統計
300#[derive(Debug, Clone, Default, Serialize, Deserialize)]
301pub struct StrategyStats {
302    pub visits: u32,
303    pub successes: u32,
304    pub failures: u32,
305    pub usage_count: u32,
306    pub episodes_success: u32,
307    pub episodes_failure: u32,
308}
309
310impl StrategyStats {
311    pub fn success_rate(&self) -> f64 {
312        if self.visits == 0 {
313            0.5
314        } else {
315            self.successes as f64 / self.visits as f64
316        }
317    }
318
319    pub fn episode_success_rate(&self) -> f64 {
320        let total = self.episodes_success + self.episodes_failure;
321        if total == 0 {
322            0.5
323        } else {
324            self.episodes_success as f64 / total as f64
325        }
326    }
327}
328
329/// 戦略切り替えイベント
330#[derive(Debug, Clone, Serialize, Deserialize)]
331pub struct StrategySwitchEvent {
332    pub from: String,
333    pub to: String,
334    pub visits_at_switch: u32,
335    pub success_rate_at_switch: f64,
336    pub from_strategy_success_rate: f64,
337}
338
339impl SelectionPerformance {
340    /// 戦略を記録開始
341    pub fn start_strategy(
342        &mut self,
343        strategy: &str,
344        current_visits: u32,
345        current_success_rate: f64,
346    ) {
347        if let Some(ref current) = self.current_strategy {
348            if current != strategy {
349                let from_stats = self
350                    .strategy_stats
351                    .get(current)
352                    .cloned()
353                    .unwrap_or_default();
354                self.switch_history.push(StrategySwitchEvent {
355                    from: current.clone(),
356                    to: strategy.to_string(),
357                    visits_at_switch: current_visits,
358                    success_rate_at_switch: current_success_rate,
359                    from_strategy_success_rate: from_stats.success_rate(),
360                });
361            }
362        }
363
364        self.current_strategy = Some(strategy.to_string());
365        self.strategy_start_visits = current_visits;
366        self.strategy_start_success_rate = current_success_rate;
367
368        self.strategy_stats
369            .entry(strategy.to_string())
370            .or_default()
371            .usage_count += 1;
372    }
373
374    /// アクション結果を記録
375    pub fn record_action(&mut self, success: bool) {
376        if let Some(ref strategy) = self.current_strategy {
377            let stats = self.strategy_stats.entry(strategy.clone()).or_default();
378            stats.visits += 1;
379            if success {
380                stats.successes += 1;
381            } else {
382                stats.failures += 1;
383            }
384        }
385    }
386
387    /// エピソード終了を記録
388    pub fn record_episode_end(&mut self, success: bool) {
389        if let Some(ref strategy) = self.current_strategy {
390            let stats = self.strategy_stats.entry(strategy.clone()).or_default();
391            if success {
392                stats.episodes_success += 1;
393            } else {
394                stats.episodes_failure += 1;
395            }
396        }
397    }
398
399    /// 戦略の効果を取得
400    pub fn strategy_effectiveness(&self, strategy: &str) -> Option<f64> {
401        self.strategy_stats.get(strategy).map(|s| s.success_rate())
402    }
403
404    /// 最も効果的な戦略を取得
405    pub fn best_strategy(&self) -> Option<(&str, f64)> {
406        self.strategy_stats
407            .iter()
408            .filter(|(_, stats)| stats.visits >= 10)
409            .max_by(|(_, a), (_, b)| {
410                a.success_rate()
411                    .partial_cmp(&b.success_rate())
412                    .unwrap_or(std::cmp::Ordering::Equal)
413            })
414            .map(|(name, stats)| (name.as_str(), stats.success_rate()))
415    }
416
417    /// 状況に応じた推奨戦略を取得
418    pub fn recommended_strategy(&self, failure_rate: f64, visits: u32) -> &str {
419        let ucb1_score = self.strategy_score_for_context("UCB1", failure_rate, visits);
420        let greedy_score = self.strategy_score_for_context("Greedy", failure_rate, visits);
421        let thompson_score = self.strategy_score_for_context("Thompson", failure_rate, visits);
422
423        if ucb1_score >= greedy_score && ucb1_score >= thompson_score {
424            "UCB1"
425        } else if greedy_score >= thompson_score {
426            "Greedy"
427        } else {
428            "Thompson"
429        }
430    }
431
432    fn strategy_score_for_context(&self, strategy: &str, failure_rate: f64, _visits: u32) -> f64 {
433        let base_score = self
434            .strategy_stats
435            .get(strategy)
436            .map(|s| s.success_rate())
437            .unwrap_or(0.5);
438
439        match strategy {
440            "UCB1" => base_score + failure_rate * 0.2,
441            "Greedy" => base_score + (1.0 - failure_rate) * 0.2,
442            "Thompson" => {
443                let distance_from_middle = (failure_rate - 0.5).abs();
444                base_score + (0.5 - distance_from_middle) * 0.2
445            }
446            _ => base_score,
447        }
448    }
449}
450
451// ============================================================================
452// Tuple Key HashMap Serialization
453// ============================================================================
454
455fn serialize_tuple2_map<V, S>(
456    map: &HashMap<(String, String), V>,
457    serializer: S,
458) -> Result<S::Ok, S::Error>
459where
460    V: Serialize,
461    S: serde::Serializer,
462{
463    use serde::ser::SerializeMap;
464    let mut ser_map = serializer.serialize_map(Some(map.len()))?;
465    for ((a, b), v) in map {
466        ser_map.serialize_entry(&format!("{}:{}", a, b), v)?;
467    }
468    ser_map.end()
469}
470
471fn deserialize_tuple2_map<'de, V, D>(
472    deserializer: D,
473) -> Result<HashMap<(String, String), V>, D::Error>
474where
475    V: Deserialize<'de>,
476    D: serde::Deserializer<'de>,
477{
478    use serde::de::Error;
479    let string_map: HashMap<String, V> = HashMap::deserialize(deserializer)?;
480    let mut result = HashMap::new();
481    for (k, v) in string_map {
482        let parts: Vec<&str> = k.splitn(2, ':').collect();
483        if parts.len() != 2 {
484            return Err(D::Error::custom(format!("invalid tuple2 key: {}", k)));
485        }
486        result.insert((parts[0].to_string(), parts[1].to_string()), v);
487    }
488    Ok(result)
489}
490
491fn serialize_tuple3_map<V, S>(
492    map: &HashMap<(String, String, String), V>,
493    serializer: S,
494) -> Result<S::Ok, S::Error>
495where
496    V: Serialize,
497    S: serde::Serializer,
498{
499    use serde::ser::SerializeMap;
500    let mut ser_map = serializer.serialize_map(Some(map.len()))?;
501    for ((a, b, c), v) in map {
502        ser_map.serialize_entry(&format!("{}:{}:{}", a, b, c), v)?;
503    }
504    ser_map.end()
505}
506
507fn deserialize_tuple3_map<'de, V, D>(
508    deserializer: D,
509) -> Result<HashMap<(String, String, String), V>, D::Error>
510where
511    V: Deserialize<'de>,
512    D: serde::Deserializer<'de>,
513{
514    use serde::de::Error;
515    let string_map: HashMap<String, V> = HashMap::deserialize(deserializer)?;
516    let mut result = HashMap::new();
517    for (k, v) in string_map {
518        let parts: Vec<&str> = k.splitn(3, ':').collect();
519        if parts.len() != 3 {
520            return Err(D::Error::custom(format!("invalid tuple3 key: {}", k)));
521        }
522        result.insert(
523            (
524                parts[0].to_string(),
525                parts[1].to_string(),
526                parts[2].to_string(),
527            ),
528            v,
529        );
530    }
531    Ok(result)
532}
533
534fn serialize_tuple4_map<V, S>(
535    map: &HashMap<(String, String, String, String), V>,
536    serializer: S,
537) -> Result<S::Ok, S::Error>
538where
539    V: Serialize,
540    S: serde::Serializer,
541{
542    use serde::ser::SerializeMap;
543    let mut ser_map = serializer.serialize_map(Some(map.len()))?;
544    for ((a, b, c, d), v) in map {
545        ser_map.serialize_entry(&format!("{}:{}:{}:{}", a, b, c, d), v)?;
546    }
547    ser_map.end()
548}
549
550fn deserialize_tuple4_map<'de, V, D>(
551    deserializer: D,
552) -> Result<HashMap<(String, String, String, String), V>, D::Error>
553where
554    V: Deserialize<'de>,
555    D: serde::Deserializer<'de>,
556{
557    use serde::de::Error;
558    let string_map: HashMap<String, V> = HashMap::deserialize(deserializer)?;
559    let mut result = HashMap::new();
560    for (k, v) in string_map {
561        let parts: Vec<&str> = k.splitn(4, ':').collect();
562        if parts.len() != 4 {
563            return Err(D::Error::custom(format!("invalid tuple4 key: {}", k)));
564        }
565        result.insert(
566            (
567                parts[0].to_string(),
568                parts[1].to_string(),
569                parts[2].to_string(),
570                parts[3].to_string(),
571            ),
572            v,
573        );
574    }
575    Ok(result)
576}