trustformers_core/ab_testing/
mod.rs1mod analysis;
7mod deployment;
8mod experiment;
9mod metrics;
10mod routing;
11
12pub use analysis::{ConfidenceLevel, StatisticalAnalyzer, TestRecommendation, TestResult};
13pub use deployment::{
14 DeploymentStrategy, HealthCheck, HealthCheckType, RollbackCondition, RolloutController,
15 RolloutStatus,
16};
17pub use experiment::{Experiment, ExperimentConfig, Variant};
18pub use metrics::{MetricCollector, MetricDataPoint, MetricType, MetricValue};
19pub use routing::{RoutingStrategy, TrafficSplitter, UserSegment};
20
21use anyhow::Result;
22use parking_lot::RwLock;
23use std::sync::Arc;
24
25pub struct ABTestManager {
27 experiments: Arc<RwLock<Vec<Experiment>>>,
28 traffic_splitter: Arc<TrafficSplitter>,
29 metric_collector: Arc<MetricCollector>,
30 analyzer: Arc<StatisticalAnalyzer>,
31 rollout_controller: Arc<RolloutController>,
32}
33
34impl Default for ABTestManager {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl ABTestManager {
41 pub fn new() -> Self {
43 Self {
44 experiments: Arc::new(RwLock::new(Vec::new())),
45 traffic_splitter: Arc::new(TrafficSplitter::new()),
46 metric_collector: Arc::new(MetricCollector::new()),
47 analyzer: Arc::new(StatisticalAnalyzer::new()),
48 rollout_controller: Arc::new(RolloutController::new()),
49 }
50 }
51
52 pub fn create_experiment(&self, config: ExperimentConfig) -> Result<String> {
54 let experiment = Experiment::new(config)?;
55 let experiment_id = experiment.id().to_string();
56
57 self.experiments.write().push(experiment);
58 Ok(experiment_id)
59 }
60
61 pub fn route_request(&self, experiment_id: &str, user_id: &str) -> Result<Variant> {
63 let experiments = self.experiments.read();
64 let experiment_uuid = uuid::Uuid::parse_str(experiment_id)?;
65 let experiment = experiments
66 .iter()
67 .find(|e| *e.id() == experiment_uuid)
68 .ok_or_else(|| anyhow::anyhow!("Experiment not found"))?;
69
70 self.traffic_splitter.route(experiment, user_id)
71 }
72
73 pub fn record_metric(
75 &self,
76 experiment_id: &str,
77 variant: &Variant,
78 metric_type: MetricType,
79 value: MetricValue,
80 ) -> Result<()> {
81 self.metric_collector.record(experiment_id, variant, metric_type, value)
82 }
83
84 pub fn analyze_experiment(&self, experiment_id: &str) -> Result<TestResult> {
86 let metrics = self.metric_collector.get_metrics(experiment_id)?;
87 self.analyzer.analyze(metrics)
88 }
89
90 pub fn get_experiment_status(&self, experiment_id: &str) -> Result<ExperimentStatus> {
92 let experiments = self.experiments.read();
93 let experiment_uuid = uuid::Uuid::parse_str(experiment_id)?;
94 let experiment = experiments
95 .iter()
96 .find(|e| *e.id() == experiment_uuid)
97 .ok_or_else(|| anyhow::anyhow!("Experiment not found"))?;
98
99 let status = match experiment.status() {
101 crate::ab_testing::experiment::ExperimentStatus::Draft => ExperimentStatus::Draft,
102 crate::ab_testing::experiment::ExperimentStatus::Running => ExperimentStatus::Running,
103 crate::ab_testing::experiment::ExperimentStatus::Paused => ExperimentStatus::Paused,
104 crate::ab_testing::experiment::ExperimentStatus::Concluded => {
105 ExperimentStatus::Concluded
106 },
107 crate::ab_testing::experiment::ExperimentStatus::Cancelled => {
108 ExperimentStatus::Cancelled
109 },
110 };
111 Ok(status)
112 }
113
114 pub fn start_experiment(&self, experiment_id: &str) -> Result<()> {
116 let mut experiments = self.experiments.write();
117 let experiment_uuid = uuid::Uuid::parse_str(experiment_id)?;
118 let experiment = experiments
119 .iter_mut()
120 .find(|e| *e.id() == experiment_uuid)
121 .ok_or_else(|| anyhow::anyhow!("Experiment not found"))?;
122
123 experiment.start()
124 }
125
126 pub fn promote_variant(&self, experiment_id: &str, variant: &Variant) -> Result<()> {
128 self.rollout_controller.promote(experiment_id, variant)
129 }
130
131 pub fn rollback(&self, experiment_id: &str) -> Result<()> {
133 self.rollout_controller.rollback(experiment_id)
134 }
135}
136
/// Status of an experiment as reported by
/// `ABTestManager::get_experiment_status`.
///
/// The variants correspond one-to-one to the status enum of the `experiment`
/// submodule. The enum is fieldless, so equality is total and `Eq` is derived
/// alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExperimentStatus {
    /// Status reported for an experiment that has been created but not yet
    /// started.
    Draft,
    /// Presumably: the experiment has been started and is serving traffic.
    /// Confirm against the `experiment` submodule.
    Running,
    /// Presumably: the experiment is temporarily halted. Confirm against the
    /// `experiment` submodule.
    Paused,
    /// Presumably: the experiment ran to completion. Confirm against the
    /// `experiment` submodule.
    Concluded,
    /// Presumably: the experiment was stopped before completion. Confirm
    /// against the `experiment` submodule.
    Cancelled,
}
151
152#[derive(Debug, Clone)]
154pub struct ABTestSummary {
155 pub experiment_id: String,
156 pub control_metrics: MetricSummary,
157 pub treatment_metrics: MetricSummary,
158 pub statistical_significance: f64,
159 pub confidence_level: ConfidenceLevel,
160 pub recommendation: Recommendation,
161}
162
163#[derive(Debug, Clone)]
165pub struct MetricSummary {
166 pub variant: Variant,
167 pub sample_size: usize,
168 pub mean: f64,
169 pub std_dev: f64,
170 pub confidence_interval: (f64, f64),
171}
172
/// Suggested action carried in [`ABTestSummary::recommendation`].
///
/// The enum is fieldless, so equality is total and `Eq` is derived alongside
/// `PartialEq`.
///
/// NOTE(review): the rules that select a variant are not part of this file;
/// the per-variant descriptions below are inferred from the names and should
/// be confirmed against the code that produces an `ABTestSummary`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Recommendation {
    /// Presumably: stay on the control variant.
    KeepControl,
    /// Presumably: switch to the treatment variant.
    AdoptTreatment,
    /// Presumably: not enough evidence yet; keep the experiment running.
    ContinueTesting,
    /// Presumably: neither variant is preferable.
    NoPreference,
}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// A newly created experiment gets a non-empty ID and reports `Draft`.
    #[test]
    fn test_create_experiment() {
        let config = ExperimentConfig {
            name: "Model v2 Test".to_string(),
            description: "Testing new model architecture".to_string(),
            control_variant: Variant::new("v1", "model-v1"),
            treatment_variants: vec![Variant::new("v2", "model-v2")],
            traffic_percentage: 50.0,
            min_sample_size: 1000,
            max_duration_hours: 168,
        };
        let manager = ABTestManager::new();

        let id = manager.create_experiment(config).expect("operation failed in test");
        assert!(!id.is_empty());

        assert_eq!(
            manager.get_experiment_status(&id).expect("operation failed in test"),
            ExperimentStatus::Draft
        );
    }

    /// Routing 1000 distinct users through a started experiment yields only
    /// the two configured variants, with the control share near one half.
    #[test]
    fn test_route_request() {
        let config = ExperimentConfig {
            name: "Routing Test".to_string(),
            description: "Test traffic routing".to_string(),
            control_variant: Variant::new("control", "model-v1"),
            treatment_variants: vec![Variant::new("treatment", "model-v2")],
            traffic_percentage: 50.0,
            min_sample_size: 100,
            max_duration_hours: 24,
        };
        let manager = ABTestManager::new();

        let id = manager.create_experiment(config).expect("operation failed in test");
        manager.start_experiment(&id).expect("operation failed in test");

        let mut control_hits: u32 = 0;
        let mut treatment_hits: u32 = 0;

        for idx in 0..1000 {
            let routed = manager
                .route_request(&id, &format!("user-{}", idx))
                .expect("operation failed in test");

            let name = routed.name();
            if name == "control" {
                control_hits += 1;
            } else if name == "treatment" {
                treatment_hits += 1;
            } else {
                panic!("Unexpected variant name: {}", name);
            }
        }

        // The control share must be within 0.05 of an even split.
        let control_share = f64::from(control_hits) / f64::from(control_hits + treatment_hits);
        let deviation = (control_share - 0.5).abs();
        assert!(deviation < 0.05);
    }
}