Skip to main content

trustformers_core/ab_testing/
mod.rs

1//! A/B Testing Framework for TrustformeRS
2//!
3//! This module provides a comprehensive A/B testing framework for comparing
4//! different model versions or configurations in production settings.
5
6mod analysis;
7mod deployment;
8mod experiment;
9mod metrics;
10mod routing;
11
12pub use analysis::{ConfidenceLevel, StatisticalAnalyzer, TestRecommendation, TestResult};
13pub use deployment::{
14    DeploymentStrategy, HealthCheck, HealthCheckType, RollbackCondition, RolloutController,
15    RolloutStatus,
16};
17pub use experiment::{Experiment, ExperimentConfig, Variant};
18pub use metrics::{MetricCollector, MetricDataPoint, MetricType, MetricValue};
19pub use routing::{RoutingStrategy, TrafficSplitter, UserSegment};
20
21use anyhow::Result;
22use parking_lot::RwLock;
23use std::sync::Arc;
24
25/// Main A/B testing manager
26pub struct ABTestManager {
27    experiments: Arc<RwLock<Vec<Experiment>>>,
28    traffic_splitter: Arc<TrafficSplitter>,
29    metric_collector: Arc<MetricCollector>,
30    analyzer: Arc<StatisticalAnalyzer>,
31    rollout_controller: Arc<RolloutController>,
32}
33
34impl Default for ABTestManager {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl ABTestManager {
41    /// Create a new A/B test manager
42    pub fn new() -> Self {
43        Self {
44            experiments: Arc::new(RwLock::new(Vec::new())),
45            traffic_splitter: Arc::new(TrafficSplitter::new()),
46            metric_collector: Arc::new(MetricCollector::new()),
47            analyzer: Arc::new(StatisticalAnalyzer::new()),
48            rollout_controller: Arc::new(RolloutController::new()),
49        }
50    }
51
52    /// Create a new experiment
53    pub fn create_experiment(&self, config: ExperimentConfig) -> Result<String> {
54        let experiment = Experiment::new(config)?;
55        let experiment_id = experiment.id().to_string();
56
57        self.experiments.write().push(experiment);
58        Ok(experiment_id)
59    }
60
61    /// Route a request to appropriate variant
62    pub fn route_request(&self, experiment_id: &str, user_id: &str) -> Result<Variant> {
63        let experiments = self.experiments.read();
64        let experiment_uuid = uuid::Uuid::parse_str(experiment_id)?;
65        let experiment = experiments
66            .iter()
67            .find(|e| *e.id() == experiment_uuid)
68            .ok_or_else(|| anyhow::anyhow!("Experiment not found"))?;
69
70        self.traffic_splitter.route(experiment, user_id)
71    }
72
73    /// Record a metric for an experiment
74    pub fn record_metric(
75        &self,
76        experiment_id: &str,
77        variant: &Variant,
78        metric_type: MetricType,
79        value: MetricValue,
80    ) -> Result<()> {
81        self.metric_collector.record(experiment_id, variant, metric_type, value)
82    }
83
84    /// Analyze experiment results
85    pub fn analyze_experiment(&self, experiment_id: &str) -> Result<TestResult> {
86        let metrics = self.metric_collector.get_metrics(experiment_id)?;
87        self.analyzer.analyze(metrics)
88    }
89
90    /// Get experiment status
91    pub fn get_experiment_status(&self, experiment_id: &str) -> Result<ExperimentStatus> {
92        let experiments = self.experiments.read();
93        let experiment_uuid = uuid::Uuid::parse_str(experiment_id)?;
94        let experiment = experiments
95            .iter()
96            .find(|e| *e.id() == experiment_uuid)
97            .ok_or_else(|| anyhow::anyhow!("Experiment not found"))?;
98
99        // Convert experiment::ExperimentStatus to ab_testing::ExperimentStatus
100        let status = match experiment.status() {
101            crate::ab_testing::experiment::ExperimentStatus::Draft => ExperimentStatus::Draft,
102            crate::ab_testing::experiment::ExperimentStatus::Running => ExperimentStatus::Running,
103            crate::ab_testing::experiment::ExperimentStatus::Paused => ExperimentStatus::Paused,
104            crate::ab_testing::experiment::ExperimentStatus::Concluded => {
105                ExperimentStatus::Concluded
106            },
107            crate::ab_testing::experiment::ExperimentStatus::Cancelled => {
108                ExperimentStatus::Cancelled
109            },
110        };
111        Ok(status)
112    }
113
114    /// Start an experiment
115    pub fn start_experiment(&self, experiment_id: &str) -> Result<()> {
116        let mut experiments = self.experiments.write();
117        let experiment_uuid = uuid::Uuid::parse_str(experiment_id)?;
118        let experiment = experiments
119            .iter_mut()
120            .find(|e| *e.id() == experiment_uuid)
121            .ok_or_else(|| anyhow::anyhow!("Experiment not found"))?;
122
123        experiment.start()
124    }
125
126    /// Promote winning variant
127    pub fn promote_variant(&self, experiment_id: &str, variant: &Variant) -> Result<()> {
128        self.rollout_controller.promote(experiment_id, variant)
129    }
130
131    /// Rollback to control variant
132    pub fn rollback(&self, experiment_id: &str) -> Result<()> {
133        self.rollout_controller.rollback(experiment_id)
134    }
135}
136
/// Public lifecycle status of an experiment, as returned by
/// [`ABTestManager::get_experiment_status`].
///
/// This mirrors the internal `experiment::ExperimentStatus` enum variant for
/// variant. It is a plain fieldless enum, so it is `Copy` and can be compared,
/// hashed and used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExperimentStatus {
    /// Experiment is being configured
    Draft,
    /// Experiment is running
    Running,
    /// Experiment is paused
    Paused,
    /// Experiment has concluded
    Concluded,
    /// Experiment was cancelled
    Cancelled,
}
151
152/// A/B test result summary
153#[derive(Debug, Clone)]
154pub struct ABTestSummary {
155    pub experiment_id: String,
156    pub control_metrics: MetricSummary,
157    pub treatment_metrics: MetricSummary,
158    pub statistical_significance: f64,
159    pub confidence_level: ConfidenceLevel,
160    pub recommendation: Recommendation,
161}
162
163/// Metric summary for a variant
164#[derive(Debug, Clone)]
165pub struct MetricSummary {
166    pub variant: Variant,
167    pub sample_size: usize,
168    pub mean: f64,
169    pub std_dev: f64,
170    pub confidence_interval: (f64, f64),
171}
172
/// Action suggested by the results of an A/B test.
///
/// A plain fieldless enum, so it is `Copy` and can be compared, hashed and
/// used as a map or set key (e.g. to tally recommendations across experiments).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Recommendation {
    /// Keep the control variant
    KeepControl,
    /// Switch to treatment variant
    AdoptTreatment,
    /// Continue testing for more data
    ContinueTesting,
    /// No significant difference
    NoPreference,
}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created experiment gets a non-empty ID and starts in `Draft`.
    #[test]
    fn test_create_experiment() {
        let manager = ABTestManager::new();

        let id = manager
            .create_experiment(ExperimentConfig {
                name: "Model v2 Test".to_string(),
                description: "Testing new model architecture".to_string(),
                control_variant: Variant::new("v1", "model-v1"),
                treatment_variants: vec![Variant::new("v2", "model-v2")],
                traffic_percentage: 50.0,
                min_sample_size: 1000,
                max_duration_hours: 168,
            })
            .expect("operation failed in test");
        assert!(!id.is_empty());

        let current = manager.get_experiment_status(&id).expect("operation failed in test");
        assert_eq!(current, ExperimentStatus::Draft);
    }

    /// Routing 1000 distinct users through a started two-arm experiment splits
    /// them roughly evenly between control and treatment.
    #[test]
    fn test_route_request() {
        let manager = ABTestManager::new();

        let id = manager
            .create_experiment(ExperimentConfig {
                name: "Routing Test".to_string(),
                description: "Test traffic routing".to_string(),
                control_variant: Variant::new("control", "model-v1"),
                treatment_variants: vec![Variant::new("treatment", "model-v2")],
                traffic_percentage: 50.0,
                min_sample_size: 100,
                max_duration_hours: 24,
            })
            .expect("operation failed in test");

        // Routing is only exercised on a running experiment.
        manager.start_experiment(&id).expect("operation failed in test");

        // Tally (control, treatment) hits; any other variant name is a bug.
        let (control_hits, treatment_hits) = (0..1000).fold((0usize, 0usize), |(c, t), n| {
            let user = format!("user-{}", n);
            let chosen = manager.route_request(&id, &user).expect("operation failed in test");
            match chosen.name() {
                "control" => (c + 1, t),
                "treatment" => (c, t + 1),
                name => panic!("Unexpected variant name: {}", name),
            }
        });

        // Expect a roughly 50/50 split, within 5 percentage points.
        let total = (control_hits + treatment_hits) as f64;
        let control_share = control_hits as f64 / total;
        assert!((control_share - 0.5).abs() < 0.05);
    }
}