Skip to main content

swarm_engine_eval/
aggregator.rs

1//! Aggregator - Statistical calculations
2//!
3//! N回実行の結果から統計量を計算します。
4
5use serde::{Deserialize, Serialize};
6
7use crate::run::EvalRun;
8
/// Statistical summary of a set of `f64` samples.
///
/// Built by [`Statistics::from_values`]. The `Default` value (every field
/// zero) is what `from_values` returns for an empty slice.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Statistics {
    /// Number of samples
    pub n: usize,

    /// Arithmetic mean of the samples
    pub mean: f64,

    /// Sample standard deviation (`n - 1` denominator); `0.0` for a single sample
    pub std_dev: f64,

    /// Lower bound of the 95% confidence interval for the mean
    /// (t-distribution based; equals `mean` for a single sample)
    pub ci_95_lower: f64,

    /// Upper bound of the 95% confidence interval for the mean
    /// (t-distribution based; equals `mean` for a single sample)
    pub ci_95_upper: f64,

    /// Smallest sample value
    pub min: f64,

    /// Largest sample value
    pub max: f64,
}
33
34impl Statistics {
35    /// Calculate statistics from values
36    pub fn from_values(values: &[f64]) -> Self {
37        let n = values.len();
38        if n == 0 {
39            return Self::default();
40        }
41
42        let sum: f64 = values.iter().sum();
43        let mean = sum / n as f64;
44
45        let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
46        let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
47
48        if n == 1 {
49            return Self {
50                n,
51                mean,
52                std_dev: 0.0,
53                ci_95_lower: mean,
54                ci_95_upper: mean,
55                min,
56                max,
57            };
58        }
59
60        // Sample standard deviation
61        let variance: f64 = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1) as f64;
62        let std_dev = variance.sqrt();
63
64        // 95% confidence interval using t-distribution
65        let t_value = t_value_95(n - 1);
66        let margin = t_value * std_dev / (n as f64).sqrt();
67
68        Self {
69            n,
70            mean,
71            std_dev,
72            ci_95_lower: mean - margin,
73            ci_95_upper: mean + margin,
74            min,
75            max,
76        }
77    }
78}
79
/// Critical t-value for a two-tailed 95% confidence interval (alpha = 0.05).
///
/// `df` is the number of degrees of freedom.
///
/// * `df == 0`: returns `f64::INFINITY`. The t-distribution is undefined
///   without degrees of freedom, so no finite interval can be stated. (The
///   previous implementation underflowed a `usize` subtraction here.)
/// * `1 <= df <= 120`: exact table value when `df` is tabulated, otherwise
///   linear interpolation between the two neighbouring table rows.
/// * `df > 120`: the normal approximation `1.96`.
fn t_value_95(df: usize) -> f64 {
    // T-table: (degrees_of_freedom, t_value), sorted by degrees of freedom.
    // Values from standard statistical tables.
    const T_TABLE: &[(usize, f64)] = &[
        (1, 12.706),
        (2, 4.303),
        (3, 3.182),
        (4, 2.776),
        (5, 2.571),
        (6, 2.447),
        (7, 2.365),
        (8, 2.306),
        (9, 2.262),
        (10, 2.228),
        (11, 2.201),
        (12, 2.179),
        (13, 2.160),
        (14, 2.145),
        (15, 2.131),
        (16, 2.120),
        (17, 2.110),
        (18, 2.101),
        (19, 2.093),
        (20, 2.086),
        (25, 2.060),
        (30, 2.042),
        (40, 2.021),
        (50, 2.009),
        (60, 2.000),
        (80, 1.990),
        (100, 1.984),
        (120, 1.980),
    ];

    // Guard the degenerate case explicitly instead of letting the
    // interpolation below compute `0 - 1` on a `usize`.
    if df == 0 {
        return f64::INFINITY;
    }

    // Normal approximation for very large df
    if df > 120 {
        return 1.96;
    }

    // Index of the first row whose df is >= the requested df. Because
    // 1 <= df <= 120 and the table spans exactly that range, the index is
    // always in bounds.
    let upper_idx = T_TABLE.partition_point(|&(table_df, _)| table_df < df);
    let (df_high, t_high) = T_TABLE[upper_idx];
    if df_high == df {
        return t_high;
    }

    // Not an exact hit, so df_high > df >= 1; the first row (df = 1) lies
    // strictly below and `upper_idx - 1` is therefore valid.
    let (df_low, t_low) = T_TABLE[upper_idx - 1];

    // Linear interpolation between the surrounding rows.
    let ratio = (df - df_low) as f64 / (df_high - df_low) as f64;
    t_low + (t_high - t_low) * ratio
}
144
/// Aggregated results over a set of evaluation runs.
///
/// Produced by [`Aggregator::aggregate`]; the `Default` value is returned
/// when no runs are supplied.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AggregatedResults {
    /// Total number of runs
    pub total_runs: usize,

    /// Number of successful runs
    pub successful_runs: usize,

    /// Success rate (`successful_runs / total_runs`)
    pub success_rate: f64,

    /// pass@1 (set to the same value as `success_rate`)
    pub pass_at_1: f64,

    /// pass@5
    pub pass_at_5: f64,

    /// pass@10 (`None` when fewer than 10 runs were aggregated)
    pub pass_at_10: Option<f64>,

    /// Statistics for various metrics
    pub statistics: AggregatedStatistics,
}
169
/// Aggregated statistics for various metrics.
///
/// Each [`Statistics`] field summarizes one per-run metric across all
/// aggregated runs; the two `total_*` fields are plain sums over the runs.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AggregatedStatistics {
    /// Success rate statistics
    pub success_rate: Statistics,

    /// Total ticks statistics
    pub total_ticks: Statistics,

    /// Tick latency p95 statistics
    pub tick_latency_p95_ms: Statistics,

    /// Tick latency p99 statistics
    pub tick_latency_p99_ms: Statistics,

    /// Tick miss rate statistics
    pub tick_miss_rate: Statistics,

    /// Tick jitter statistics
    pub tick_jitter: Statistics,

    /// Manager intervention rate statistics
    pub manager_intervention_rate: Statistics,

    /// Raw throughput statistics (all actions per second)
    pub raw_throughput_per_sec: Statistics,

    /// Effective throughput statistics (successful actions per second)
    pub effective_throughput_per_sec: Statistics,

    /// Total LLM invocations statistics
    pub llm_invocations: Statistics,

    /// LLM invocation errors statistics
    pub llm_invoke_errors: Statistics,

    /// LLM error rate statistics (errors / invocations)
    pub llm_error_rate: Statistics,

    /// Total LLM invocations across all runs (sum)
    pub total_llm_invocations: u64,

    /// Total LLM errors across all runs (sum)
    pub total_llm_errors: u64,
}
215
/// Aggregator for evaluation runs.
///
/// A field-less unit struct; its functionality is exposed through
/// associated functions such as [`Aggregator::aggregate`].
pub struct Aggregator;
218
219impl Aggregator {
220    /// Aggregate evaluation runs
221    pub fn aggregate(runs: &[EvalRun]) -> AggregatedResults {
222        let total_runs = runs.len();
223        if total_runs == 0 {
224            return AggregatedResults::default();
225        }
226
227        let successful_runs = runs.iter().filter(|r| r.success).count();
228        let success_rate = successful_runs as f64 / total_runs as f64;
229
230        // Collect metric values for statistics
231        let success_rates: Vec<f64> = runs.iter().map(|r| r.metrics.task.success_rate).collect();
232        let total_ticks: Vec<f64> = runs
233            .iter()
234            .map(|r| r.metrics.task.total_ticks as f64)
235            .collect();
236        let tick_latency_p95: Vec<f64> = runs
237            .iter()
238            .map(|r| r.metrics.performance.tick_latency_p95_ms)
239            .collect();
240        let tick_latency_p99: Vec<f64> = runs
241            .iter()
242            .map(|r| r.metrics.performance.tick_latency_p99_ms)
243            .collect();
244        let tick_miss_rates: Vec<f64> = runs
245            .iter()
246            .map(|r| r.metrics.performance.tick_miss_rate)
247            .collect();
248        let tick_jitters: Vec<f64> = runs
249            .iter()
250            .map(|r| r.metrics.performance.tick_jitter)
251            .collect();
252        let manager_rates: Vec<f64> = runs
253            .iter()
254            .map(|r| r.metrics.coordination.manager_intervention_rate)
255            .collect();
256        let raw_throughputs: Vec<f64> = runs
257            .iter()
258            .map(|r| r.metrics.performance.raw_throughput_per_sec)
259            .collect();
260        let effective_throughputs: Vec<f64> = runs
261            .iter()
262            .map(|r| r.metrics.performance.effective_throughput_per_sec)
263            .collect();
264        let llm_invocations: Vec<f64> = runs
265            .iter()
266            .map(|r| r.metrics.performance.llm_invocations as f64)
267            .collect();
268        let llm_errors: Vec<f64> = runs
269            .iter()
270            .map(|r| r.metrics.performance.llm_invoke_errors as f64)
271            .collect();
272        let llm_error_rates: Vec<f64> = runs
273            .iter()
274            .map(|r| r.metrics.performance.llm_error_rate)
275            .collect();
276        let total_llm_invocations: u64 = runs
277            .iter()
278            .map(|r| r.metrics.performance.llm_invocations)
279            .sum();
280        let total_llm_errors: u64 = runs
281            .iter()
282            .map(|r| r.metrics.performance.llm_invoke_errors)
283            .sum();
284
285        // Calculate pass@k
286        let pass_at_1 = success_rate;
287        let pass_at_5 = Self::calculate_pass_at_k(total_runs, successful_runs, 5);
288        let pass_at_10 = if total_runs >= 10 {
289            Some(Self::calculate_pass_at_k(total_runs, successful_runs, 10))
290        } else {
291            None
292        };
293
294        AggregatedResults {
295            total_runs,
296            successful_runs,
297            success_rate,
298            pass_at_1,
299            pass_at_5,
300            pass_at_10,
301            statistics: AggregatedStatistics {
302                success_rate: Statistics::from_values(&success_rates),
303                total_ticks: Statistics::from_values(&total_ticks),
304                tick_latency_p95_ms: Statistics::from_values(&tick_latency_p95),
305                tick_latency_p99_ms: Statistics::from_values(&tick_latency_p99),
306                tick_miss_rate: Statistics::from_values(&tick_miss_rates),
307                tick_jitter: Statistics::from_values(&tick_jitters),
308                manager_intervention_rate: Statistics::from_values(&manager_rates),
309                raw_throughput_per_sec: Statistics::from_values(&raw_throughputs),
310                effective_throughput_per_sec: Statistics::from_values(&effective_throughputs),
311                llm_invocations: Statistics::from_values(&llm_invocations),
312                llm_invoke_errors: Statistics::from_values(&llm_errors),
313                llm_error_rate: Statistics::from_values(&llm_error_rates),
314                total_llm_invocations,
315                total_llm_errors,
316            },
317        }
318    }
319
320    /// Calculate pass@k
321    ///
322    /// pass@k = 1 - C(n-c, k) / C(n, k)
323    /// where n = total runs, c = successful runs, k = sample size
324    fn calculate_pass_at_k(n: usize, c: usize, k: usize) -> f64 {
325        if k > n {
326            return if c > 0 { 1.0 } else { 0.0 };
327        }
328        if c >= n {
329            return 1.0;
330        }
331        if c == 0 {
332            return 0.0;
333        }
334        if n - c < k {
335            return 1.0; // Not enough failures to fill k samples
336        }
337
338        // Calculate in log space to avoid overflow
339        // C(n-c, k) / C(n, k) = product((n-c-i)/(n-i)) for i in 0..k
340        let mut ratio = 1.0;
341        for i in 0..k {
342            ratio *= (n - c - i) as f64 / (n - i) as f64;
343        }
344        1.0 - ratio
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    #[test]
    fn test_statistics_empty() {
        let summary = Statistics::from_values(&[]);
        assert_eq!(summary.n, 0);
        assert_eq!(summary.mean, 0.0);
    }

    #[test]
    fn test_statistics_single() {
        let summary = Statistics::from_values(&[0.8]);
        assert_eq!(summary.n, 1);
        assert_close(summary.mean, 0.8, 0.001);
        assert_eq!(summary.std_dev, 0.0);
    }

    #[test]
    fn test_statistics_multiple() {
        let summary = Statistics::from_values(&[0.7, 0.8, 0.9]);
        assert_eq!(summary.n, 3);
        assert_close(summary.mean, 0.8, 0.001);
        assert!(summary.std_dev > 0.0);
        // The interval must strictly bracket the mean.
        assert!(summary.ci_95_lower < summary.mean);
        assert!(summary.ci_95_upper > summary.mean);
    }

    #[test]
    fn test_pass_at_k() {
        // 24 successes out of 30 runs: pass@1 equals the 80% success rate.
        let pass_1 = Aggregator::calculate_pass_at_k(30, 24, 1);
        assert_close(pass_1, 0.8, 0.01);

        // Drawing more samples can only raise the chance of a success.
        let pass_5 = Aggregator::calculate_pass_at_k(30, 24, 5);
        assert!(pass_5 > pass_1);

        // All runs succeeded.
        assert_close(Aggregator::calculate_pass_at_k(30, 30, 5), 1.0, 0.01);

        // No run succeeded.
        assert_close(Aggregator::calculate_pass_at_k(30, 0, 5), 0.0, 0.01);
    }

    #[test]
    fn test_t_value_exact_matches() {
        // (degrees of freedom, tabulated critical value)
        let tabulated = [
            (1, 12.706),
            (10, 2.228),
            (15, 2.131),
            (20, 2.086),
            (30, 2.042),
            (100, 1.984),
        ];
        for &(df, expected) in tabulated.iter() {
            assert_close(t_value_95(df), expected, 0.001);
        }
    }

    #[test]
    fn test_t_value_interpolation() {
        // df=21 sits between the rows for df=20 (2.086) and df=25 (2.060).
        let t_21 = t_value_95(21);
        assert!(t_21 < 2.086);
        assert!(t_21 > 2.060);
        // Statistical tables list roughly 2.080 for df=21.
        assert_close(t_21, 2.080, 0.01);

        // df=35 sits between the rows for df=30 (2.042) and df=40 (2.021).
        let t_35 = t_value_95(35);
        assert!(t_35 < 2.042);
        assert!(t_35 > 2.021);

        // Interpolated values keep decreasing as df grows: 22 > 23 > 24.
        let interpolated = [t_value_95(22), t_value_95(23), t_value_95(24)];
        assert!(interpolated[0] > interpolated[1]);
        assert!(interpolated[1] > interpolated[2]);
    }

    #[test]
    fn test_t_value_large_df() {
        // Beyond df=120 the normal approximation (1.96) is used.
        for &df in [200, 1000].iter() {
            assert_close(t_value_95(df), 1.96, 0.001);
        }
    }
}