Skip to main content

swarm_engine_core/online_stats/
swarm.rs

1//! SwarmStats - オンライン統計
2//!
3//! Swarm 内でリアルタイムに更新される基本統計。
4//! Selection(UCB1/Thompson)が参照する visits/success_rate を提供。
5//!
6//! # 責務
7//!
8//! - アクション別統計(visits, successes, failures)
9//! - アクション×ターゲット別統計
10//! - グローバル統計
11//!
12//! # 分離された責務
13//!
14//! - 学習統計(EpisodeTransitions, NgramStats 等)→ `learn::LearnStats`
15
16use std::collections::HashMap;
17use std::time::Duration;
18
19use serde::{Deserialize, Serialize};
20
21use crate::events::ActionEvent;
22
23/// Swarm オンライン統計
24///
25/// ActionEvent を受け取り、Selection が参照する基本統計を提供。
26/// UCB1/Thompson が必要とする visits/success_rate をリアルタイムに更新。
27#[derive(Debug, Clone, Default)]
28pub struct SwarmStats {
29    /// アクション別統計
30    action_stats: HashMap<String, ActionStats>,
31    /// アクション × ターゲット別統計
32    action_target_stats: HashMap<(String, String), ActionStats>,
33    /// グローバル統計(Worker アクションのみ)
34    global: GlobalStats,
35}
36
37/// アクション単位の統計
38#[derive(Debug, Clone, Default, Serialize, Deserialize)]
39pub struct ActionStats {
40    /// 訪問回数
41    pub visits: u32,
42    /// 成功回数
43    pub successes: u32,
44    /// 失敗回数
45    pub failures: u32,
46    /// 発見数の合計
47    pub discoveries: u32,
48    /// 総実行時間
49    #[serde(
50        serialize_with = "serialize_duration",
51        deserialize_with = "deserialize_duration"
52    )]
53    pub total_duration: Duration,
54}
55
56fn serialize_duration<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
57where
58    S: serde::Serializer,
59{
60    serializer.serialize_u64(duration.as_millis() as u64)
61}
62
63fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
64where
65    D: serde::Deserializer<'de>,
66{
67    let millis = u64::deserialize(deserializer)?;
68    Ok(Duration::from_millis(millis))
69}
70
71impl ActionStats {
72    /// 成功率(訪問なしの場合は 0.5)
73    pub fn success_rate(&self) -> f64 {
74        if self.visits == 0 {
75            0.5
76        } else {
77            self.successes as f64 / self.visits as f64
78        }
79    }
80
81    /// 平均実行時間
82    pub fn avg_duration(&self) -> Duration {
83        if self.visits == 0 {
84            Duration::ZERO
85        } else {
86            self.total_duration / self.visits
87        }
88    }
89
90    /// 平均発見数
91    pub fn avg_discoveries(&self) -> f64 {
92        if self.visits == 0 {
93            0.0
94        } else {
95            self.discoveries as f64 / self.visits as f64
96        }
97    }
98}
99
/// Swarm-wide totals.
#[derive(Clone, Debug, Default)]
pub struct GlobalStats {
    /// Number of recorded actions.
    pub total_visits: u32,
    /// Number of recorded actions that succeeded.
    pub total_successes: u32,
    /// Number of recorded actions that failed.
    pub total_failures: u32,
    /// Total number of discoveries.
    pub total_discoveries: u32,
    /// Cumulative execution time.
    pub total_duration: Duration,
}

impl GlobalStats {
    /// Fraction of recorded actions that succeeded.
    ///
    /// Returns `1.0` while nothing has been recorded.
    pub fn success_rate(&self) -> f64 {
        match self.total_visits {
            0 => 1.0,
            visits => f64::from(self.total_successes) / f64::from(visits),
        }
    }

    /// Fraction of recorded actions that failed.
    ///
    /// Returns `0.0` while nothing has been recorded.
    pub fn failure_rate(&self) -> f64 {
        match self.total_visits {
            0 => 0.0,
            visits => f64::from(self.total_failures) / f64::from(visits),
        }
    }
}
134
135impl SwarmStats {
136    pub fn new() -> Self {
137        Self::default()
138    }
139
140    /// イベントを記録
141    ///
142    /// Selection 用の統計のみ更新。
143    /// - LLM イベント、Manager イベントはスキップ
144    /// - Worker の通常アクションのみ統計に反映
145    pub fn record(&mut self, event: &ActionEvent) {
146        let action = &event.action;
147        let success = event.result.success;
148        let duration = event.duration;
149
150        // LLM 呼び出しイベントはスキップ
151        if action == "llm_invoke" {
152            return;
153        }
154
155        // Manager イベントはスキップ
156        if event.worker_id.is_manager() {
157            return;
158        }
159
160        // tick_start/tick_end はスキップ
161        if action == "tick_start" || action == "tick_end" {
162            return;
163        }
164
165        let target = event.target.as_deref();
166        let discoveries = event.result.discoveries;
167
168        // グローバル統計を更新
169        self.global.total_visits += 1;
170        self.global.total_duration += duration;
171        if success {
172            self.global.total_successes += 1;
173            self.global.total_discoveries += discoveries;
174        } else {
175            self.global.total_failures += 1;
176        }
177
178        // アクション統計を更新
179        let action_stat = self.action_stats.entry(action.clone()).or_default();
180        action_stat.visits += 1;
181        action_stat.total_duration += duration;
182        if success {
183            action_stat.successes += 1;
184            action_stat.discoveries += discoveries;
185        } else {
186            action_stat.failures += 1;
187        }
188
189        // アクション × ターゲット統計を更新
190        if let Some(t) = target {
191            let at_stat = self
192                .action_target_stats
193                .entry((action.clone(), t.to_string()))
194                .or_default();
195            at_stat.visits += 1;
196            at_stat.total_duration += duration;
197            if success {
198                at_stat.successes += 1;
199                at_stat.discoveries += discoveries;
200            } else {
201                at_stat.failures += 1;
202            }
203        }
204    }
205
206    /// アクションの統計を取得
207    pub fn get_action_stats(&self, action: &str) -> ActionStats {
208        self.action_stats.get(action).cloned().unwrap_or_default()
209    }
210
211    /// アクション × ターゲットの統計を取得
212    pub fn get_action_target_stats(&self, action: &str, target: &str) -> ActionStats {
213        self.action_target_stats
214            .get(&(action.to_string(), target.to_string()))
215            .cloned()
216            .unwrap_or_default()
217    }
218
219    /// グローバル統計を取得
220    pub fn global(&self) -> &GlobalStats {
221        &self.global
222    }
223
224    /// 総訪問回数
225    pub fn total_visits(&self) -> u32 {
226        self.global.total_visits
227    }
228
229    /// 総成功回数
230    pub fn total_successes(&self) -> u32 {
231        self.global.total_successes
232    }
233
234    /// 総失敗回数
235    pub fn total_failures(&self) -> u32 {
236        self.global.total_failures
237    }
238
239    /// 全体の成功率
240    pub fn success_rate(&self) -> f64 {
241        self.global.success_rate()
242    }
243
244    /// 全体の失敗率
245    pub fn failure_rate(&self) -> f64 {
246        self.global.failure_rate()
247    }
248
249    /// 全アクションの統計一覧
250    pub fn all_action_stats(&self) -> impl Iterator<Item = (&String, &ActionStats)> {
251        self.action_stats.iter()
252    }
253
254    /// 統計をリセット
255    pub fn reset(&mut self) {
256        self.action_stats.clear();
257        self.action_target_stats.clear();
258        self.global = GlobalStats::default();
259    }
260}
261
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult};
    use crate::types::WorkerId;

    /// Builds a Worker-0 event lasting 100 ms, optionally with a target.
    fn make_event(action: &str, target: Option<&str>, success: bool) -> ActionEvent {
        let mut builder =
            ActionEventBuilder::new(1, WorkerId(0), action).duration(Duration::from_millis(100));

        if let Some(t) = target {
            builder = builder.target(t);
        }

        let result = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("error")
        };

        builder.result(result).build()
    }

    /// Builds a successful event for `worker` with no target or duration set.
    fn bare_event(worker: WorkerId, action: &str) -> ActionEvent {
        ActionEventBuilder::new(1, worker, action)
            .result(ActionEventResult::success())
            .build()
    }

    #[test]
    fn test_swarm_stats_basic() {
        let mut stats = SwarmStats::new();

        stats.record(&make_event("CheckStatus", Some("svc1"), true));
        stats.record(&make_event("CheckStatus", Some("svc1"), true));
        stats.record(&make_event("CheckStatus", Some("svc2"), false));

        let action_stats = stats.get_action_stats("CheckStatus");
        assert_eq!(action_stats.visits, 3);
        assert_eq!(action_stats.successes, 2);
        assert_eq!(action_stats.failures, 1);
        assert!((action_stats.success_rate() - 0.666).abs() < 0.01);

        let at_stats = stats.get_action_target_stats("CheckStatus", "svc1");
        assert_eq!(at_stats.visits, 2);
        assert_eq!(at_stats.successes, 2);

        let at_stats = stats.get_action_target_stats("CheckStatus", "svc2");
        assert_eq!(at_stats.visits, 1);
        assert_eq!(at_stats.failures, 1);
    }

    #[test]
    fn test_swarm_stats_global() {
        let mut stats = SwarmStats::new();

        stats.record(&make_event("A", None, true));
        stats.record(&make_event("B", None, true));
        stats.record(&make_event("C", None, false));

        assert_eq!(stats.total_visits(), 3);
        assert_eq!(stats.total_successes(), 2);
        assert_eq!(stats.total_failures(), 1);
        assert!((stats.success_rate() - 0.666).abs() < 0.01);
        assert!((stats.failure_rate() - 0.333).abs() < 0.01);
        assert_eq!(stats.all_action_stats().count(), 3);
    }

    #[test]
    fn test_avg_duration() {
        let mut stats = SwarmStats::new();

        stats.record(&make_event("A", None, true));
        stats.record(&make_event("A", None, false));

        let a = stats.get_action_stats("A");
        assert_eq!(a.total_duration, Duration::from_millis(200));
        assert_eq!(a.avg_duration(), Duration::from_millis(100));
        assert_eq!(stats.global().total_duration, Duration::from_millis(200));
    }

    #[test]
    fn test_untargeted_event_has_no_target_stats() {
        let mut stats = SwarmStats::new();

        stats.record(&make_event("A", None, true));

        assert_eq!(stats.get_action_stats("A").visits, 1);
        assert_eq!(stats.get_action_target_stats("A", "").visits, 0);
    }

    #[test]
    fn test_unknown_action_defaults() {
        let stats = SwarmStats::new();

        let missing = stats.get_action_stats("missing");
        assert_eq!(missing.visits, 0);
        assert_eq!(missing.success_rate(), 0.5);
        assert_eq!(stats.get_action_target_stats("missing", "x").visits, 0);
        assert_eq!(stats.success_rate(), 1.0);
        assert_eq!(stats.failure_rate(), 0.0);
    }

    #[test]
    fn test_llm_invoke_skipped() {
        let mut stats = SwarmStats::new();

        stats.record(&bare_event(WorkerId(0), "llm_invoke"));

        assert_eq!(stats.total_visits(), 0);
    }

    #[test]
    fn test_tick_markers_skipped() {
        let mut stats = SwarmStats::new();

        stats.record(&bare_event(WorkerId(0), "tick_start"));
        stats.record(&bare_event(WorkerId(0), "tick_end"));

        assert_eq!(stats.total_visits(), 0);
        assert_eq!(stats.all_action_stats().count(), 0);
    }

    #[test]
    fn test_manager_skipped() {
        let mut stats = SwarmStats::new();

        stats.record(&bare_event(WorkerId::MANAGER, "decide"));

        assert_eq!(stats.total_visits(), 0);
    }

    #[test]
    fn test_reset() {
        let mut stats = SwarmStats::new();

        stats.record(&make_event("A", Some("t"), true));
        stats.reset();

        assert_eq!(stats.total_visits(), 0);
        assert_eq!(stats.get_action_stats("A").visits, 0);
        assert_eq!(stats.get_action_target_stats("A", "t").visits, 0);
        assert_eq!(stats.all_action_stats().count(), 0);
    }
}