// trustformers_core/ab_testing/analysis.rs

//! Statistical analysis for A/B test results

#![allow(unused_variables)] // A/B testing analysis

use super::{MetricDataPoint, MetricType, Variant};
use anyhow::Result;
use std::collections::HashMap;

/// Statistical test results
///
/// Aggregate outcome of analyzing one experiment: summary statistics for
/// every variant, the hypothesis-test numbers for the strongest treatment,
/// and the resulting recommendation.
#[derive(Debug, Clone)]
pub struct TestResult {
    /// Control variant metrics (baseline summary statistics)
    pub control_stats: VariantStatistics,
    /// Treatment variant metrics, one entry per treatment variant
    pub treatment_stats: Vec<VariantStatistics>,
    /// Statistical test results for the lowest-p-value treatment vs. control
    pub test_stats: TestStatistics,
    /// Overall recommendation derived from the test
    pub recommendation: TestRecommendation,
}

/// Statistics for a single variant
#[derive(Debug, Clone)]
pub struct VariantStatistics {
    /// The variant these statistics describe
    pub variant: Variant,
    /// Sample size (number of metric data points)
    pub sample_size: usize,
    /// Mean of the sampled values
    pub mean: f64,
    /// Sample standard deviation (n - 1 denominator)
    pub std_dev: f64,
    /// Standard error of the mean (std_dev / sqrt(n))
    pub std_error: f64,
    /// Confidence interval for the mean as (lower, upper), computed at the
    /// analyzer's default confidence level using a normal approximation
    pub confidence_interval: (f64, f64),
}

/// Results of statistical tests
#[derive(Debug, Clone)]
pub struct TestStatistics {
    /// Two-sided p-value from the hypothesis test
    pub p_value: f64,
    /// Test statistic (t-stat or z-stat)
    pub test_statistic: f64,
    /// Effect size (Cohen's d: |mean difference| / pooled std dev)
    pub effect_size: f64,
    /// Statistical power (approximate probability of detecting the
    /// observed effect at the configured confidence level)
    pub power: f64,
    /// Minimum detectable effect at the configured confidence and 80% power
    pub min_detectable_effect: f64,
    /// Confidence level used for the test
    pub confidence_level: ConfidenceLevel,
}

/// Confidence levels for testing
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConfidenceLevel {
    /// 90% confidence
    Low,
    /// 95% confidence
    Medium,
    /// 99% confidence
    High,
}

impl ConfidenceLevel {
    /// Significance threshold for the level (alpha = 1 - confidence).
    pub fn alpha(&self) -> f64 {
        match self {
            Self::Low => 0.10,
            Self::Medium => 0.05,
            Self::High => 0.01,
        }
    }

    /// Two-sided critical z-score used for confidence intervals at
    /// this level.
    pub fn z_score(&self) -> f64 {
        match self {
            Self::Low => 1.645,
            Self::Medium => 1.96,
            Self::High => 2.576,
        }
    }
}

/// Test recommendation
///
/// The `improvement` / `degradation` values are percentage changes of the
/// treatment mean relative to the control mean.
#[derive(Debug, Clone, PartialEq)]
pub enum TestRecommendation {
    /// Treatment is significantly better; `improvement` keeps its sign, so
    /// it is negative for lower-is-better metrics such as latency
    AdoptTreatment { variant: String, improvement: f64 },
    /// Control is significantly better; `degradation` is the absolute
    /// percentage change
    KeepControl { degradation: f64 },
    /// No statistically significant difference between variants
    NoSignificantDifference,
    /// Need more data before a decision can be made
    InsufficientData { required_sample_size: usize },
}

/// Statistical analyzer for A/B tests
///
/// Compares treatment variants against a control using a two-sample t-test
/// (normal approximation) and produces a [`TestResult`].
pub struct StatisticalAnalyzer {
    /// Default confidence level for tests and confidence intervals
    default_confidence: ConfidenceLevel,
    /// Minimum sample size per variant before results are trusted
    min_sample_size: usize,
}

108impl Default for StatisticalAnalyzer {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl StatisticalAnalyzer {
115    /// Create a new analyzer
116    pub fn new() -> Self {
117        Self {
118            default_confidence: ConfidenceLevel::Medium,
119            min_sample_size: 30,
120        }
121    }
122
123    /// Create with custom settings
124    pub fn with_settings(confidence: ConfidenceLevel, min_sample_size: usize) -> Self {
125        Self {
126            default_confidence: confidence,
127            min_sample_size,
128        }
129    }
130
131    /// Analyze experiment results
132    pub fn analyze(
133        &self,
134        metrics: HashMap<(Variant, MetricType), Vec<MetricDataPoint>>,
135    ) -> Result<TestResult> {
136        // Separate control and treatment metrics, preserving metric type
137        let mut control_data = None;
138        let mut treatment_data = Vec::new();
139        let mut metric_type = None;
140
141        for ((variant, m_type), data_points) in metrics {
142            let values: Vec<f64> = data_points.iter().map(|dp| dp.value.as_f64()).collect();
143
144            // Store the metric type (assume all entries have the same metric type)
145            if metric_type.is_none() {
146                metric_type = Some(m_type.clone());
147            }
148
149            if variant.name() == "control" {
150                control_data = Some((variant, values));
151            } else {
152                treatment_data.push((variant, values));
153            }
154        }
155
156        let primary_metric_type =
157            metric_type.ok_or_else(|| anyhow::anyhow!("No metric type found"))?;
158
159        let (control_variant, control_values) =
160            control_data.ok_or_else(|| anyhow::anyhow!("No control variant data found"))?;
161
162        if treatment_data.is_empty() {
163            anyhow::bail!("No treatment variant data found");
164        }
165
166        // Calculate statistics for control
167        let control_stats = self.calculate_variant_stats(control_variant, &control_values)?;
168
169        // Calculate statistics for treatments
170        let mut treatment_stats = Vec::new();
171        let mut best_treatment = None;
172        let mut best_p_value = f64::INFINITY; // Start with infinity so any p-value will be better
173
174        for (variant, values) in treatment_data {
175            let stats = self.calculate_variant_stats(variant.clone(), &values)?;
176
177            // Perform hypothesis test
178            let test_result = self.perform_test(&control_values, &values)?;
179
180            if test_result.p_value < best_p_value {
181                best_p_value = test_result.p_value;
182                best_treatment = Some((variant, stats.clone(), test_result));
183            }
184
185            treatment_stats.push(stats);
186        }
187
188        // Generate recommendation
189        let recommendation = if let Some((variant, stats, test_result)) = &best_treatment {
190            self.generate_recommendation(
191                &control_stats,
192                stats,
193                test_result,
194                variant,
195                &primary_metric_type,
196            )
197        } else {
198            TestRecommendation::NoSignificantDifference
199        };
200
201        // Use the best treatment's test statistics
202        let test_stats = if let Some((_, _, test_result)) = best_treatment {
203            test_result
204        } else {
205            TestStatistics {
206                p_value: 1.0,
207                test_statistic: 0.0,
208                effect_size: 0.0,
209                power: 0.0,
210                min_detectable_effect: 0.0,
211                confidence_level: self.default_confidence,
212            }
213        };
214
215        Ok(TestResult {
216            control_stats,
217            treatment_stats,
218            test_stats,
219            recommendation,
220        })
221    }
222
223    /// Calculate statistics for a variant
224    fn calculate_variant_stats(
225        &self,
226        variant: Variant,
227        values: &[f64],
228    ) -> Result<VariantStatistics> {
229        let sample_size = values.len();
230        if sample_size == 0 {
231            anyhow::bail!("No data points for variant");
232        }
233
234        let mean = values.iter().sum::<f64>() / sample_size as f64;
235        let variance =
236            values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (sample_size - 1) as f64;
237        let std_dev = variance.sqrt();
238        let std_error = std_dev / (sample_size as f64).sqrt();
239
240        let z_score = self.default_confidence.z_score();
241        let margin_of_error = z_score * std_error;
242        let confidence_interval = (mean - margin_of_error, mean + margin_of_error);
243
244        Ok(VariantStatistics {
245            variant,
246            sample_size,
247            mean,
248            std_dev,
249            std_error,
250            confidence_interval,
251        })
252    }
253
254    /// Perform two-sample t-test
255    fn perform_test(&self, control: &[f64], treatment: &[f64]) -> Result<TestStatistics> {
256        let n1 = control.len() as f64;
257        let n2 = treatment.len() as f64;
258
259        if n1 < self.min_sample_size as f64 || n2 < self.min_sample_size as f64 {
260            return Ok(TestStatistics {
261                p_value: 1.0,
262                test_statistic: 0.0,
263                effect_size: 0.0,
264                power: 0.0,
265                min_detectable_effect: 0.0,
266                confidence_level: self.default_confidence,
267            });
268        }
269
270        // Calculate means
271        let mean1 = control.iter().sum::<f64>() / n1;
272        let mean2 = treatment.iter().sum::<f64>() / n2;
273
274        // Calculate variances
275        let var1 = control.iter().map(|v| (v - mean1).powi(2)).sum::<f64>() / (n1 - 1.0);
276        let var2 = treatment.iter().map(|v| (v - mean2).powi(2)).sum::<f64>() / (n2 - 1.0);
277
278        // Pooled standard deviation
279        let pooled_std = (((n1 - 1.0) * var1 + (n2 - 1.0) * var2) / (n1 + n2 - 2.0)).sqrt();
280
281        // Test statistic
282        let test_statistic = (mean2 - mean1) / (pooled_std * (1.0 / n1 + 1.0 / n2).sqrt());
283
284        // Degrees of freedom
285        let df = n1 + n2 - 2.0;
286
287        // P-value (simplified - in practice use a proper t-distribution)
288        let p_value = self.calculate_p_value(test_statistic.abs(), df);
289
290        // Effect size (Cohen's d)
291        let effect_size = (mean2 - mean1).abs() / pooled_std;
292
293        // Statistical power (simplified calculation)
294        let power = self.calculate_power(effect_size, n1, n2);
295
296        // Minimum detectable effect
297        let min_detectable_effect = self.calculate_mde(n1, n2, pooled_std);
298
299        Ok(TestStatistics {
300            p_value,
301            test_statistic,
302            effect_size,
303            power,
304            min_detectable_effect,
305            confidence_level: self.default_confidence,
306        })
307    }
308
309    /// Calculate p-value (simplified)
310    fn calculate_p_value(&self, t_stat: f64, _df: f64) -> f64 {
311        // Simplified normal approximation
312        // In practice, use proper t-distribution
313        let z = t_stat;
314        2.0 * (1.0 - self.normal_cdf(z))
315    }
316
317    /// Normal CDF approximation
318    fn normal_cdf(&self, x: f64) -> f64 {
319        0.5 * (1.0 + self.erf(x / std::f64::consts::SQRT_2))
320    }
321
322    /// Error function approximation
323    fn erf(&self, x: f64) -> f64 {
324        // Abramowitz and Stegun approximation
325        let a1 = 0.254829592;
326        let a2 = -0.284496736;
327        let a3 = 1.421413741;
328        let a4 = -1.453152027;
329        let a5 = 1.061405429;
330        let p = 0.3275911;
331
332        let sign = if x < 0.0 { -1.0 } else { 1.0 };
333        let x = x.abs();
334
335        let t = 1.0 / (1.0 + p * x);
336        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
337
338        sign * y
339    }
340
341    /// Calculate statistical power
342    fn calculate_power(&self, effect_size: f64, n1: f64, n2: f64) -> f64 {
343        // Simplified power calculation
344        let n_harmonic = (n1 * n2) / (n1 + n2);
345        let noncentrality = effect_size * (n_harmonic / 2.0).sqrt();
346        let critical_value = self.default_confidence.z_score();
347
348        // Power = P(Z > critical_value - noncentrality)
349        1.0 - self.normal_cdf(critical_value - noncentrality)
350    }
351
352    /// Calculate minimum detectable effect
353    fn calculate_mde(&self, n1: f64, n2: f64, pooled_std: f64) -> f64 {
354        let alpha = self.default_confidence.alpha();
355        let beta = 0.2; // 80% power
356        let z_alpha = self.default_confidence.z_score();
357        let z_beta = 0.84; // z-score for 80% power
358
359        let n_harmonic = (n1 * n2) / (n1 + n2);
360        (z_alpha + z_beta) * pooled_std * (2.0 / n_harmonic).sqrt()
361    }
362
363    /// Generate recommendation
364    fn generate_recommendation(
365        &self,
366        control: &VariantStatistics,
367        treatment: &VariantStatistics,
368        test_stats: &TestStatistics,
369        variant: &Variant,
370        metric_type: &MetricType,
371    ) -> TestRecommendation {
372        // Check sample size
373        if control.sample_size < self.min_sample_size
374            || treatment.sample_size < self.min_sample_size
375        {
376            let required =
377                self.min_sample_size.max(control.sample_size).max(treatment.sample_size) * 2;
378            return TestRecommendation::InsufficientData {
379                required_sample_size: required,
380            };
381        }
382
383        // Check statistical significance
384        if test_stats.p_value >= self.default_confidence.alpha() {
385            return TestRecommendation::NoSignificantDifference;
386        }
387
388        // Calculate improvement based on metric type directionality
389        let improvement = ((treatment.mean - control.mean) / control.mean) * 100.0;
390
391        // Determine if treatment is better based on metric type
392        let treatment_is_better = if metric_type.lower_is_better() {
393            treatment.mean < control.mean
394        } else {
395            treatment.mean > control.mean
396        };
397
398        if treatment_is_better {
399            TestRecommendation::AdoptTreatment {
400                variant: variant.name().to_string(),
401                improvement, // Keep the raw improvement (negative for latency improvement)
402            }
403        } else {
404            TestRecommendation::KeepControl {
405                degradation: improvement.abs(),
406            }
407        }
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ab_testing::MetricValue;

    /// Draw `size` normally-distributed metric samples with the given
    /// mean and standard deviation.
    fn create_test_data(mean: f64, std_dev: f64, size: usize) -> Vec<MetricDataPoint> {
        use scirs2_core::random::*;
        let normal = Normal::new(mean, std_dev).expect("operation failed in test");
        let mut rng = thread_rng();

        std::iter::repeat_with(|| MetricDataPoint {
            timestamp: chrono::Utc::now(),
            value: MetricValue::Numeric(normal.sample(&mut rng)),
            metadata: None,
        })
        .take(size)
        .collect()
    }

    #[test]
    fn test_significant_difference() {
        let analyzer = StatisticalAnalyzer::new();
        let mut metrics = HashMap::new();

        // Control arm: mean=100, std=10
        metrics.insert(
            (Variant::new("control", "v1"), MetricType::Latency),
            create_test_data(100.0, 10.0, 100),
        );

        // Treatment arm: mean=90, std=10 — a 10% latency reduction
        metrics.insert(
            (Variant::new("treatment", "v2"), MetricType::Latency),
            create_test_data(90.0, 10.0, 100),
        );

        let result = analyzer.analyze(metrics).expect("operation failed in test");

        // A clear improvement must produce an adoption recommendation.
        match result.recommendation {
            TestRecommendation::AdoptTreatment {
                variant,
                improvement,
            } => {
                assert_eq!(variant, "treatment");
                // Negative because lower latency is better
                assert!(improvement < 0.0);
            },
            other => panic!("Expected to recommend treatment adoption, got: {:?}", other),
        }
    }

    #[test]
    fn test_no_significant_difference() {
        let analyzer = StatisticalAnalyzer::new();

        // Deterministic, constant-valued samples so both arms are
        // statistically indistinguishable.
        fn create_identical_data(mean: f64, size: usize) -> Vec<MetricDataPoint> {
            std::iter::repeat_with(|| MetricDataPoint {
                timestamp: chrono::Utc::now(),
                value: MetricValue::Numeric(mean),
                metadata: None,
            })
            .take(size)
            .collect()
        }

        let mut metrics = HashMap::new();
        metrics.insert(
            (Variant::new("control", "v1"), MetricType::Accuracy),
            create_identical_data(0.95, 100),
        );
        metrics.insert(
            (Variant::new("treatment", "v2"), MetricType::Accuracy),
            create_identical_data(0.95, 100),
        );

        let result = analyzer.analyze(metrics).expect("operation failed in test");

        assert_eq!(
            result.recommendation,
            TestRecommendation::NoSignificantDifference
        );
    }

    #[test]
    fn test_insufficient_data() {
        let analyzer = StatisticalAnalyzer::new();
        let mut metrics = HashMap::new();

        // Only 10 samples per arm — below the 30-sample minimum.
        metrics.insert(
            (Variant::new("control", "v1"), MetricType::Throughput),
            create_test_data(1000.0, 50.0, 10),
        );
        metrics.insert(
            (Variant::new("treatment", "v2"), MetricType::Throughput),
            create_test_data(1100.0, 50.0, 10),
        );

        let result = analyzer.analyze(metrics).expect("operation failed in test");

        match result.recommendation {
            TestRecommendation::InsufficientData {
                required_sample_size,
            } => {
                assert!(required_sample_size > 20);
            },
            ref other => panic!(
                "Expected insufficient data recommendation, got: {:?}",
                other
            ),
        }
    }
}
539}