1use anyhow::Result;
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use uuid::Uuid;
9
10use super::{ModelVersionManager, VersionedModel};
11use crate::ab_testing::{ABTestManager, ExperimentConfig, MetricType, MetricValue, Variant};
12
13pub struct VersionedABTestManager {
15 version_manager: Arc<ModelVersionManager>,
16 ab_test_manager: Arc<ABTestManager>,
17 active_experiments: tokio::sync::RwLock<HashMap<String, VersionedExperiment>>,
18}
19
20impl VersionedABTestManager {
21 pub fn new(version_manager: Arc<ModelVersionManager>) -> Self {
23 Self {
24 version_manager,
25 ab_test_manager: Arc::new(ABTestManager::new()),
26 active_experiments: tokio::sync::RwLock::new(HashMap::new()),
27 }
28 }
29
30 pub async fn create_version_experiment(
32 &self,
33 config: VersionExperimentConfig,
34 ) -> Result<String> {
35 let control_version = self
37 .version_manager
38 .get_version(config.control_version_id)
39 .await?
40 .ok_or_else(|| anyhow::anyhow!("Control version not found"))?;
41
42 let mut treatment_versions = Vec::new();
43 for &version_id in &config.treatment_version_ids {
44 let version = self
45 .version_manager
46 .get_version(version_id)
47 .await?
48 .ok_or_else(|| anyhow::anyhow!("Treatment version {} not found", version_id))?;
49 treatment_versions.push(version);
50 }
51
52 let control_variant = Variant::new("control", &control_version.qualified_name());
54
55 let treatment_variants: Vec<Variant> = treatment_versions
56 .iter()
57 .enumerate()
58 .map(|(i, v)| Variant::new(&format!("treatment_{}", i), &v.qualified_name()))
59 .collect();
60
61 let experiment_config = ExperimentConfig {
63 name: config.name.clone(),
64 description: config.description.clone(),
65 control_variant,
66 treatment_variants,
67 traffic_percentage: config.traffic_percentage,
68 min_sample_size: config.min_sample_size,
69 max_duration_hours: config.max_duration_hours,
70 };
71
72 let experiment_id = self.ab_test_manager.create_experiment(experiment_config)?;
74
75 let versioned_experiment = VersionedExperiment {
77 experiment_id: experiment_id.clone(),
78 model_name: control_version.model_name().to_string(),
79 control_version_id: config.control_version_id,
80 treatment_version_ids: config.treatment_version_ids.clone(),
81 config: config.clone(),
82 started_at: Utc::now(),
83 status: VersionedExperimentStatus::Running,
84 metrics_collected: HashMap::new(),
85 };
86
87 {
88 let mut experiments = self.active_experiments.write().await;
89 experiments.insert(experiment_id.clone(), versioned_experiment);
90 }
91
92 tracing::info!(
93 "Created versioned A/B test: {} ({})",
94 config.name,
95 experiment_id
96 );
97 Ok(experiment_id)
98 }
99
100 pub async fn route_request(
102 &self,
103 experiment_id: &str,
104 user_id: &str,
105 ) -> Result<ModelRoutingResult> {
106 let variant = self.ab_test_manager.route_request(experiment_id, user_id)?;
108
109 let experiments = self.active_experiments.read().await;
111 let versioned_experiment = experiments
112 .get(experiment_id)
113 .ok_or_else(|| anyhow::anyhow!("Versioned experiment not found"))?;
114
115 let version_id = if variant.name() == "control" {
117 versioned_experiment.control_version_id
118 } else {
119 let treatment_index = variant
121 .name()
122 .strip_prefix("treatment_")
123 .and_then(|s| s.parse::<usize>().ok())
124 .ok_or_else(|| anyhow::anyhow!("Invalid treatment variant name"))?;
125
126 *versioned_experiment
127 .treatment_version_ids
128 .get(treatment_index)
129 .ok_or_else(|| anyhow::anyhow!("Treatment index out of bounds"))?
130 };
131
132 let model_version = self
134 .version_manager
135 .get_version(version_id)
136 .await?
137 .ok_or_else(|| anyhow::anyhow!("Model version not found"))?;
138
139 Ok(ModelRoutingResult {
140 experiment_id: experiment_id.to_string(),
141 variant: variant.clone(),
142 version_id,
143 model_version,
144 user_id: user_id.to_string(),
145 })
146 }
147
148 pub async fn record_version_metric(
150 &self,
151 experiment_id: &str,
152 user_id: &str,
153 metric_type: VersionMetricType,
154 value: f64,
155 metadata: Option<HashMap<String, String>>,
156 ) -> Result<()> {
157 let routing_result = self.route_request(experiment_id, user_id).await?;
159
160 let ab_metric_type = match &metric_type {
162 VersionMetricType::Latency => MetricType::Latency,
163 VersionMetricType::Accuracy => MetricType::Accuracy,
164 VersionMetricType::Throughput => MetricType::Throughput,
165 VersionMetricType::ErrorRate => MetricType::ErrorRate,
166 VersionMetricType::MemoryUsage => MetricType::MemoryUsage,
167 VersionMetricType::CustomMetric(name) => MetricType::Custom(name.clone()),
168 };
169
170 let ab_metric_value = match metric_type {
171 VersionMetricType::Latency => MetricValue::Duration(value as u64),
172 _ => MetricValue::Numeric(value),
173 };
174
175 self.ab_test_manager.record_metric(
177 experiment_id,
178 &routing_result.variant,
179 ab_metric_type,
180 ab_metric_value,
181 )?;
182
183 {
185 let mut experiments = self.active_experiments.write().await;
186 if let Some(experiment) = experiments.get_mut(experiment_id) {
187 let metric_key = format!("{}:{}", routing_result.variant.name(), metric_type);
188 experiment.metrics_collected.entry(metric_key).or_default().push(
189 VersionMetricRecord {
190 value,
191 timestamp: Utc::now(),
192 user_id: user_id.to_string(),
193 metadata: metadata.unwrap_or_default(),
194 },
195 );
196 }
197 }
198
199 Ok(())
200 }
201
202 pub async fn analyze_version_experiment(
204 &self,
205 experiment_id: &str,
206 ) -> Result<VersionExperimentResult> {
207 let ab_result = self.ab_test_manager.analyze_experiment(experiment_id)?;
209
210 let experiments = self.active_experiments.read().await;
212 let versioned_experiment = experiments
213 .get(experiment_id)
214 .ok_or_else(|| anyhow::anyhow!("Versioned experiment not found"))?;
215
216 let control_version = self
218 .version_manager
219 .get_version(versioned_experiment.control_version_id)
220 .await?
221 .ok_or_else(|| anyhow::anyhow!("Control version not found"))?;
222
223 let mut treatment_versions = Vec::new();
224 for &version_id in &versioned_experiment.treatment_version_ids {
225 let version = self
226 .version_manager
227 .get_version(version_id)
228 .await?
229 .ok_or_else(|| anyhow::anyhow!("Treatment version not found"))?;
230 treatment_versions.push(version);
231 }
232
233 Ok(VersionExperimentResult {
235 experiment_id: experiment_id.to_string(),
236 model_name: versioned_experiment.model_name.clone(),
237 control_version,
238 treatment_versions,
239 ab_test_result: ab_result,
240 experiment_duration: Utc::now() - versioned_experiment.started_at,
241 total_requests: versioned_experiment
242 .metrics_collected
243 .values()
244 .map(|records| records.len())
245 .sum(),
246 version_performance_comparison: self
247 .compare_version_performance(versioned_experiment)
248 .await?,
249 })
250 }
251
252 pub async fn promote_winning_version(&self, experiment_id: &str) -> Result<PromotionResult> {
254 let result = self.analyze_version_experiment(experiment_id).await?;
255
256 let winning_version_id = match &result.ab_test_result.recommendation {
258 crate::ab_testing::TestRecommendation::AdoptTreatment { variant, .. } => {
259 let treatment_index = variant
261 .strip_prefix("treatment_")
262 .and_then(|s| s.parse::<usize>().ok())
263 .ok_or_else(|| anyhow::anyhow!("Invalid treatment variant name"))?;
264
265 result
266 .treatment_versions
267 .get(treatment_index)
268 .map(|v| v.id())
269 .ok_or_else(|| anyhow::anyhow!("Treatment index out of bounds"))?
270 },
271 crate::ab_testing::TestRecommendation::KeepControl { .. } => {
272 result.control_version.id()
273 },
274 _ => {
275 return Ok(PromotionResult {
276 promoted: false,
277 version_id: None,
278 reason: "No clear winner determined".to_string(),
279 });
280 },
281 };
282
283 self.version_manager.promote_to_production(winning_version_id).await?;
285
286 {
288 let mut experiments = self.active_experiments.write().await;
289 if let Some(experiment) = experiments.get_mut(experiment_id) {
290 experiment.status = VersionedExperimentStatus::Concluded;
291 }
292 }
293
294 Ok(PromotionResult {
295 promoted: true,
296 version_id: Some(winning_version_id),
297 reason: "Version promoted based on A/B test results".to_string(),
298 })
299 }
300
301 pub async fn list_experiments(&self) -> Result<Vec<VersionedExperiment>> {
303 let experiments = self.active_experiments.read().await;
304 Ok(experiments.values().cloned().collect())
305 }
306
307 pub async fn stop_experiment(&self, experiment_id: &str) -> Result<()> {
309 let mut experiments = self.active_experiments.write().await;
310 if let Some(experiment) = experiments.get_mut(experiment_id) {
311 experiment.status = VersionedExperimentStatus::Stopped;
312 }
313 Ok(())
314 }
315
316 async fn compare_version_performance(
319 &self,
320 experiment: &VersionedExperiment,
321 ) -> Result<HashMap<String, VersionPerformanceMetrics>> {
322 let mut comparison = HashMap::new();
323
324 for (metric_key, records) in &experiment.metrics_collected {
325 let parts: Vec<&str> = metric_key.split(':').collect();
326 if parts.len() == 2 {
327 let variant_name = parts[0];
328 let metric_type = parts[1];
329
330 let values: Vec<f64> = records.iter().map(|r| r.value).collect();
331
332 if !values.is_empty() {
333 let mean = values.iter().sum::<f64>() / values.len() as f64;
334 let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>()
335 / values.len() as f64;
336 let std_dev = variance.sqrt();
337 let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
338 let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
339
340 let metrics = VersionPerformanceMetrics {
341 metric_type: metric_type.to_string(),
342 sample_count: values.len(),
343 mean,
344 std_dev,
345 min,
346 max,
347 p95: calculate_percentile(&values, 0.95),
348 p99: calculate_percentile(&values, 0.99),
349 };
350
351 comparison.insert(format!("{}:{}", variant_name, metric_type), metrics);
352 }
353 }
354 }
355
356 Ok(comparison)
357 }
358}
359
/// Returns the `percentile`-th value of `values`, where `percentile` is a fraction in
/// `[0.0, 1.0]` (e.g. `0.95` for p95).
///
/// Uses the lower nearest-rank index `floor(percentile * (len - 1))` over a sorted copy of
/// `values`.
///
/// Edge cases (all of which previously panicked):
/// - an empty slice yields `f64::NAN`;
/// - `percentile` is clamped to `[0.0, 1.0]`, so out-of-range inputs cannot index past the end;
/// - NaN samples are ordered with [`f64::total_cmp`] (positive NaN sorts last).
fn calculate_percentile(values: &[f64], percentile: f64) -> f64 {
    if values.is_empty() {
        return f64::NAN;
    }

    let mut sorted = values.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);

    let fraction = percentile.clamp(0.0, 1.0);
    // `as usize` floors non-negative values; a NaN `fraction` casts to 0.
    let index = (fraction * (sorted.len() - 1) as f64) as usize;
    sorted[index]
}
367
368#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct VersionExperimentConfig {
371 pub name: String,
372 pub description: String,
373 pub control_version_id: Uuid,
374 pub treatment_version_ids: Vec<Uuid>,
375 pub traffic_percentage: f64,
376 pub min_sample_size: usize,
377 pub max_duration_hours: u64,
378}
379
380#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
382pub enum VersionedExperimentStatus {
383 Running,
384 Stopped,
385 Concluded,
386 Failed,
387}
388
389#[derive(Debug, Clone, Serialize, Deserialize)]
391pub struct VersionedExperiment {
392 pub experiment_id: String,
393 pub model_name: String,
394 pub control_version_id: Uuid,
395 pub treatment_version_ids: Vec<Uuid>,
396 pub config: VersionExperimentConfig,
397 pub started_at: DateTime<Utc>,
398 pub status: VersionedExperimentStatus,
399 pub metrics_collected: HashMap<String, Vec<VersionMetricRecord>>,
400}
401
402#[derive(Debug, Clone, Serialize, Deserialize)]
404pub enum VersionMetricType {
405 Latency,
406 Accuracy,
407 Throughput,
408 ErrorRate,
409 MemoryUsage,
410 CustomMetric(String),
411}
412
413impl std::fmt::Display for VersionMetricType {
414 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415 match self {
416 VersionMetricType::Latency => write!(f, "latency"),
417 VersionMetricType::Accuracy => write!(f, "accuracy"),
418 VersionMetricType::Throughput => write!(f, "throughput"),
419 VersionMetricType::ErrorRate => write!(f, "error_rate"),
420 VersionMetricType::MemoryUsage => write!(f, "memory_usage"),
421 VersionMetricType::CustomMetric(name) => write!(f, "{}", name),
422 }
423 }
424}
425
426#[derive(Debug, Clone, Serialize, Deserialize)]
428pub struct VersionMetricRecord {
429 pub value: f64,
430 pub timestamp: DateTime<Utc>,
431 pub user_id: String,
432 pub metadata: HashMap<String, String>,
433}
434
435#[derive(Debug, Clone)]
437pub struct ModelRoutingResult {
438 pub experiment_id: String,
439 pub variant: Variant,
440 pub version_id: Uuid,
441 pub model_version: VersionedModel,
442 pub user_id: String,
443}
444
445#[derive(Debug, Clone)]
447pub struct VersionExperimentResult {
448 pub experiment_id: String,
449 pub model_name: String,
450 pub control_version: VersionedModel,
451 pub treatment_versions: Vec<VersionedModel>,
452 pub ab_test_result: crate::ab_testing::TestResult,
453 pub experiment_duration: chrono::Duration,
454 pub total_requests: usize,
455 pub version_performance_comparison: HashMap<String, VersionPerformanceMetrics>,
456}
457
458#[derive(Debug, Clone, Serialize, Deserialize)]
460pub struct VersionPerformanceMetrics {
461 pub metric_type: String,
462 pub sample_count: usize,
463 pub mean: f64,
464 pub std_dev: f64,
465 pub min: f64,
466 pub max: f64,
467 pub p95: f64,
468 pub p99: f64,
469}
470
471#[derive(Debug, Clone)]
473pub struct PromotionResult {
474 pub promoted: bool,
475 pub version_id: Option<Uuid>,
476 pub reason: String,
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use crate::versioning::{
        metadata::{ModelMetadata, ModelTag},
        storage::InMemoryStorage,
    };

    /// Model name shared by every version registered in these tests.
    const MODEL_NAME: &str = "test_model";

    /// Returns a version manager backed by fresh in-memory storage.
    async fn create_test_version_manager() -> Arc<ModelVersionManager> {
        let storage = Arc::new(InMemoryStorage::new());
        Arc::new(ModelVersionManager::new(storage))
    }

    /// Registers `name`@`version` with minimal test metadata and returns the new version id.
    async fn create_test_version(manager: &ModelVersionManager, name: &str, version: &str) -> Uuid {
        let metadata = ModelMetadata::builder()
            .description(format!("Test model {}", name))
            .created_by("test_user".to_string())
            .model_type("transformer".to_string())
            .tag(ModelTag::new("test"))
            .build();

        manager
            .register_version(name, version, metadata, vec![])
            .await
            .expect("async operation failed")
    }

    /// State produced by [`setup_experiment`].
    struct Fixture {
        manager: VersionedABTestManager,
        control_id: Uuid,
        treatment_id: Uuid,
        experiment_id: String,
    }

    /// Registers 1.0.0 (control) and 1.1.0 (treatment) of the test model and creates a two-arm
    /// experiment over them with the given settings.
    async fn setup_experiment(
        name: &str,
        description: &str,
        traffic_percentage: f64,
        min_sample_size: usize,
        max_duration_hours: u64,
    ) -> Fixture {
        let version_manager = create_test_version_manager().await;
        let manager = VersionedABTestManager::new(version_manager.clone());

        let control_id = create_test_version(&version_manager, MODEL_NAME, "1.0.0").await;
        let treatment_id = create_test_version(&version_manager, MODEL_NAME, "1.1.0").await;

        let config = VersionExperimentConfig {
            name: name.to_string(),
            description: description.to_string(),
            control_version_id: control_id,
            treatment_version_ids: vec![treatment_id],
            traffic_percentage,
            min_sample_size,
            max_duration_hours,
        };

        let experiment_id = manager
            .create_version_experiment(config)
            .await
            .expect("async operation failed");

        Fixture {
            manager,
            control_id,
            treatment_id,
            experiment_id,
        }
    }

    #[tokio::test]
    async fn test_version_experiment_creation() {
        let fixture =
            setup_experiment("Model v1.1 Test", "Testing improved model", 50.0, 100, 24).await;
        assert!(!fixture.experiment_id.is_empty());

        let experiments = fixture
            .manager
            .list_experiments()
            .await
            .expect("async operation failed");
        assert_eq!(experiments.len(), 1);
        assert_eq!(experiments[0].experiment_id, fixture.experiment_id);
    }

    #[tokio::test]
    async fn test_request_routing() {
        let fixture = setup_experiment("Routing Test", "Test routing", 100.0, 10, 1).await;

        let routing_result = fixture
            .manager
            .route_request(&fixture.experiment_id, "test_user")
            .await
            .expect("async operation failed");

        assert_eq!(routing_result.experiment_id, fixture.experiment_id);
        assert_eq!(routing_result.user_id, "test_user");
        let served = routing_result.version_id;
        assert!(served == fixture.control_id || served == fixture.treatment_id);
    }

    #[tokio::test]
    async fn test_metric_recording() {
        let fixture = setup_experiment("Metrics Test", "Test metrics", 100.0, 10, 1).await;

        fixture
            .manager
            .record_version_metric(
                &fixture.experiment_id,
                "test_user",
                VersionMetricType::Latency,
                120.0,
                None,
            )
            .await
            .expect("operation failed in test");

        let experiments = fixture
            .manager
            .list_experiments()
            .await
            .expect("async operation failed");
        assert!(!experiments[0].metrics_collected.is_empty());
    }
}