Skip to main content

swarm_engine_core/online_stats/
swarm.rs

1//! SwarmStats - オンライン統計
2//!
3//! Swarm 内でリアルタイムに更新される基本統計。
4//! Selection(UCB1/Thompson)が参照する visits/success_rate を提供。
5//!
6//! # 責務
7//!
8//! - アクション別統計(visits, successes, failures)
9//! - アクション×ターゲット別統計
10//! - グローバル統計
11//!
12//! # 分離された責務
13//!
14//! - 学習統計(EpisodeTransitions, NgramStats 等)→ `learn::LearnStats`
15
16use std::collections::HashMap;
17use std::time::Duration;
18
19use serde::{Deserialize, Serialize};
20
21use crate::events::ActionEvent;
22
23/// Swarm オンライン統計
24///
25/// ActionEvent を受け取り、Selection が参照する基本統計を提供。
26/// UCB1/Thompson が必要とする visits/success_rate をリアルタイムに更新。
27#[derive(Debug, Clone, Default)]
28pub struct SwarmStats {
29    /// アクション別統計
30    action_stats: HashMap<String, ActionStats>,
31    /// アクション × ターゲット別統計
32    action_target_stats: HashMap<(String, String), ActionStats>,
33    /// グローバル統計(Worker アクションのみ)
34    global: GlobalStats,
35}
36
37/// アクション単位の統計
38#[derive(Debug, Clone, Default, Serialize, Deserialize)]
39pub struct ActionStats {
40    /// 訪問回数
41    pub visits: u32,
42    /// 成功回数
43    pub successes: u32,
44    /// 失敗回数
45    pub failures: u32,
46    /// 発見数の合計
47    pub discoveries: u32,
48    /// 総実行時間
49    #[serde(
50        serialize_with = "serialize_duration",
51        deserialize_with = "deserialize_duration"
52    )]
53    pub total_duration: Duration,
54    /// KPI貢献度の累計
55    #[serde(default)]
56    pub kpi_total: f64,
57    /// KPI貢献度が設定されたイベント数
58    #[serde(default)]
59    pub kpi_count: u32,
60}
61
62fn serialize_duration<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
63where
64    S: serde::Serializer,
65{
66    serializer.serialize_u64(duration.as_millis() as u64)
67}
68
69fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
70where
71    D: serde::Deserializer<'de>,
72{
73    let millis = u64::deserialize(deserializer)?;
74    Ok(Duration::from_millis(millis))
75}
76
77impl ActionStats {
78    /// 成功率(訪問なしの場合は 0.5)
79    pub fn success_rate(&self) -> f64 {
80        if self.visits == 0 {
81            0.5
82        } else {
83            self.successes as f64 / self.visits as f64
84        }
85    }
86
87    /// 平均実行時間
88    pub fn avg_duration(&self) -> Duration {
89        if self.visits == 0 {
90            Duration::ZERO
91        } else {
92            self.total_duration / self.visits
93        }
94    }
95
96    /// 平均発見数
97    pub fn avg_discoveries(&self) -> f64 {
98        if self.visits == 0 {
99            0.0
100        } else {
101            self.discoveries as f64 / self.visits as f64
102        }
103    }
104
105    /// 平均KPI貢献度(KPIが設定されたイベントのみ)
106    pub fn avg_kpi(&self) -> f64 {
107        if self.kpi_count == 0 {
108            0.0
109        } else {
110            self.kpi_total / self.kpi_count as f64
111        }
112    }
113}
114
115/// グローバル統計
116#[derive(Debug, Clone, Default)]
117pub struct GlobalStats {
118    /// 総訪問回数
119    pub total_visits: u32,
120    /// 総成功回数
121    pub total_successes: u32,
122    /// 総失敗回数
123    pub total_failures: u32,
124    /// 総発見数
125    pub total_discoveries: u32,
126    /// 総実行時間
127    pub total_duration: Duration,
128    /// KPI貢献度の累計
129    pub total_kpi: f64,
130    /// KPI貢献度が設定されたイベント数
131    pub kpi_count: u32,
132}
133
134impl GlobalStats {
135    /// 成功率(訪問なしの場合は 1.0)
136    pub fn success_rate(&self) -> f64 {
137        if self.total_visits == 0 {
138            1.0
139        } else {
140            self.total_successes as f64 / self.total_visits as f64
141        }
142    }
143
144    /// 失敗率(訪問なしの場合は 0.0)
145    pub fn failure_rate(&self) -> f64 {
146        if self.total_visits == 0 {
147            0.0
148        } else {
149            self.total_failures as f64 / self.total_visits as f64
150        }
151    }
152
153    /// 平均KPI貢献度(KPIが設定されたイベントのみ)
154    pub fn avg_kpi(&self) -> f64 {
155        if self.kpi_count == 0 {
156            0.0
157        } else {
158            self.total_kpi / self.kpi_count as f64
159        }
160    }
161}
162
163impl SwarmStats {
164    pub fn new() -> Self {
165        Self::default()
166    }
167
168    /// イベントを記録
169    ///
170    /// Selection 用の統計のみ更新。
171    /// - LLM イベント、Manager イベントはスキップ
172    /// - Worker の通常アクションのみ統計に反映
173    pub fn record(&mut self, event: &ActionEvent) {
174        let action = &event.action;
175        let success = event.result.success;
176        let duration = event.duration;
177
178        // LLM 呼び出しイベントはスキップ
179        if action == "llm_invoke" {
180            return;
181        }
182
183        // Manager イベントはスキップ
184        if event.worker_id.is_manager() {
185            return;
186        }
187
188        // tick_start/tick_end はスキップ
189        if action == "tick_start" || action == "tick_end" {
190            return;
191        }
192
193        let target = event.target.as_deref();
194        let discoveries = event.result.discoveries;
195        let kpi = event.result.kpi_contribution;
196
197        // グローバル統計を更新
198        self.global.total_visits += 1;
199        self.global.total_duration += duration;
200        if success {
201            self.global.total_successes += 1;
202            self.global.total_discoveries += discoveries;
203        } else {
204            self.global.total_failures += 1;
205        }
206        if let Some(k) = kpi {
207            self.global.total_kpi += k;
208            self.global.kpi_count += 1;
209        }
210
211        // アクション統計を更新
212        let action_stat = self.action_stats.entry(action.clone()).or_default();
213        action_stat.visits += 1;
214        action_stat.total_duration += duration;
215        if success {
216            action_stat.successes += 1;
217            action_stat.discoveries += discoveries;
218        } else {
219            action_stat.failures += 1;
220        }
221        if let Some(k) = kpi {
222            action_stat.kpi_total += k;
223            action_stat.kpi_count += 1;
224        }
225
226        // アクション × ターゲット統計を更新
227        if let Some(t) = target {
228            let at_stat = self
229                .action_target_stats
230                .entry((action.clone(), t.to_string()))
231                .or_default();
232            at_stat.visits += 1;
233            at_stat.total_duration += duration;
234            if success {
235                at_stat.successes += 1;
236                at_stat.discoveries += discoveries;
237            } else {
238                at_stat.failures += 1;
239            }
240            if let Some(k) = kpi {
241                at_stat.kpi_total += k;
242                at_stat.kpi_count += 1;
243            }
244        }
245    }
246
247    /// アクションの統計を取得
248    pub fn get_action_stats(&self, action: &str) -> ActionStats {
249        self.action_stats.get(action).cloned().unwrap_or_default()
250    }
251
252    /// アクション × ターゲットの統計を取得
253    pub fn get_action_target_stats(&self, action: &str, target: &str) -> ActionStats {
254        self.action_target_stats
255            .get(&(action.to_string(), target.to_string()))
256            .cloned()
257            .unwrap_or_default()
258    }
259
260    /// グローバル統計を取得
261    pub fn global(&self) -> &GlobalStats {
262        &self.global
263    }
264
265    /// 総訪問回数
266    pub fn total_visits(&self) -> u32 {
267        self.global.total_visits
268    }
269
270    /// 総成功回数
271    pub fn total_successes(&self) -> u32 {
272        self.global.total_successes
273    }
274
275    /// 総失敗回数
276    pub fn total_failures(&self) -> u32 {
277        self.global.total_failures
278    }
279
280    /// 全体の成功率
281    pub fn success_rate(&self) -> f64 {
282        self.global.success_rate()
283    }
284
285    /// 全体の失敗率
286    pub fn failure_rate(&self) -> f64 {
287        self.global.failure_rate()
288    }
289
290    /// 全アクションの統計一覧
291    pub fn all_action_stats(&self) -> impl Iterator<Item = (&String, &ActionStats)> {
292        self.action_stats.iter()
293    }
294
295    /// 統計をリセット
296    pub fn reset(&mut self) {
297        self.action_stats.clear();
298        self.action_target_stats.clear();
299        self.global = GlobalStats::default();
300    }
301}
302
303#[cfg(test)]
304mod tests {
305    use std::time::Duration;
306
307    use super::*;
308    use crate::events::{ActionEventBuilder, ActionEventResult};
309    use crate::types::WorkerId;
310
311    fn make_event(action: &str, target: Option<&str>, success: bool) -> ActionEvent {
312        let mut builder =
313            ActionEventBuilder::new(1, WorkerId(0), action).duration(Duration::from_millis(100));
314
315        if let Some(t) = target {
316            builder = builder.target(t);
317        }
318
319        let result = if success {
320            ActionEventResult::success()
321        } else {
322            ActionEventResult::failure("error")
323        };
324
325        builder.result(result).build()
326    }
327
328    #[test]
329    fn test_swarm_stats_basic() {
330        let mut stats = SwarmStats::new();
331
332        stats.record(&make_event("CheckStatus", Some("svc1"), true));
333        stats.record(&make_event("CheckStatus", Some("svc1"), true));
334        stats.record(&make_event("CheckStatus", Some("svc2"), false));
335
336        let action_stats = stats.get_action_stats("CheckStatus");
337        assert_eq!(action_stats.visits, 3);
338        assert_eq!(action_stats.successes, 2);
339        assert_eq!(action_stats.failures, 1);
340        assert!((action_stats.success_rate() - 0.666).abs() < 0.01);
341
342        let at_stats = stats.get_action_target_stats("CheckStatus", "svc1");
343        assert_eq!(at_stats.visits, 2);
344        assert_eq!(at_stats.successes, 2);
345    }
346
347    #[test]
348    fn test_swarm_stats_global() {
349        let mut stats = SwarmStats::new();
350
351        stats.record(&make_event("A", None, true));
352        stats.record(&make_event("B", None, true));
353        stats.record(&make_event("C", None, false));
354
355        assert_eq!(stats.total_visits(), 3);
356        assert_eq!(stats.total_successes(), 2);
357        assert_eq!(stats.total_failures(), 1);
358        assert!((stats.success_rate() - 0.666).abs() < 0.01);
359    }
360
361    #[test]
362    fn test_llm_invoke_skipped() {
363        let mut stats = SwarmStats::new();
364
365        let e = ActionEventBuilder::new(1, WorkerId(0), "llm_invoke")
366            .result(ActionEventResult::success())
367            .build();
368        stats.record(&e);
369
370        assert_eq!(stats.total_visits(), 0);
371    }
372
373    #[test]
374    fn test_manager_skipped() {
375        let mut stats = SwarmStats::new();
376
377        let e = ActionEventBuilder::new(1, WorkerId::MANAGER, "decide")
378            .result(ActionEventResult::success())
379            .build();
380        stats.record(&e);
381
382        assert_eq!(stats.total_visits(), 0);
383    }
384
385    #[test]
386    fn test_kpi_contribution() {
387        let mut stats = SwarmStats::new();
388
389        // KPIなしのイベント
390        stats.record(&make_event("A", None, true));
391
392        // KPIありのイベント
393        let e1 = ActionEventBuilder::new(1, WorkerId(0), "B")
394            .result(ActionEventResult::success().with_kpi(0.5))
395            .build();
396        stats.record(&e1);
397
398        let e2 = ActionEventBuilder::new(2, WorkerId(0), "B")
399            .result(ActionEventResult::success().with_kpi(0.8))
400            .build();
401        stats.record(&e2);
402
403        // グローバル統計
404        assert_eq!(stats.global().kpi_count, 2);
405        assert!((stats.global().total_kpi - 1.3).abs() < 0.01);
406        assert!((stats.global().avg_kpi() - 0.65).abs() < 0.01);
407
408        // アクション別統計
409        let a_stats = stats.get_action_stats("A");
410        assert_eq!(a_stats.kpi_count, 0);
411        assert_eq!(a_stats.avg_kpi(), 0.0);
412
413        let b_stats = stats.get_action_stats("B");
414        assert_eq!(b_stats.kpi_count, 2);
415        assert!((b_stats.kpi_total - 1.3).abs() < 0.01);
416        assert!((b_stats.avg_kpi() - 0.65).abs() < 0.01);
417    }
418}