Skip to main content

trustformers_core/versioning/
integration.rs

1//! Integration between versioning and A/B testing systems
2
3use anyhow::Result;
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use uuid::Uuid;
9
10use super::{ModelVersionManager, VersionedModel};
11use crate::ab_testing::{ABTestManager, ExperimentConfig, MetricType, MetricValue, Variant};
12
/// Enhanced A/B testing manager with versioning integration.
///
/// Wraps an [`ABTestManager`] and maps its variants (`"control"` and
/// `"treatment_<i>"`) onto model versions held by a [`ModelVersionManager`].
pub struct VersionedABTestManager {
    /// Version registry used to validate, resolve and promote model versions.
    version_manager: Arc<ModelVersionManager>,
    /// Underlying A/B test engine: routes users to variants, records metrics and
    /// analyses experiments.
    ab_test_manager: Arc<ABTestManager>,
    /// Versioned experiments keyed by the experiment ID returned from `ab_test_manager`.
    /// Nothing in this file removes entries; stopped/concluded experiments remain with
    /// an updated status.
    active_experiments: tokio::sync::RwLock<HashMap<String, VersionedExperiment>>,
}
19
20impl VersionedABTestManager {
21    /// Create a new versioned A/B test manager
22    pub fn new(version_manager: Arc<ModelVersionManager>) -> Self {
23        Self {
24            version_manager,
25            ab_test_manager: Arc::new(ABTestManager::new()),
26            active_experiments: tokio::sync::RwLock::new(HashMap::new()),
27        }
28    }
29
30    /// Create an A/B test between model versions
31    pub async fn create_version_experiment(
32        &self,
33        config: VersionExperimentConfig,
34    ) -> Result<String> {
35        // Validate that all versions exist
36        let control_version = self
37            .version_manager
38            .get_version(config.control_version_id)
39            .await?
40            .ok_or_else(|| anyhow::anyhow!("Control version not found"))?;
41
42        let mut treatment_versions = Vec::new();
43        for &version_id in &config.treatment_version_ids {
44            let version = self
45                .version_manager
46                .get_version(version_id)
47                .await?
48                .ok_or_else(|| anyhow::anyhow!("Treatment version {} not found", version_id))?;
49            treatment_versions.push(version);
50        }
51
52        // Create A/B test variants
53        let control_variant = Variant::new("control", &control_version.qualified_name());
54
55        let treatment_variants: Vec<Variant> = treatment_versions
56            .iter()
57            .enumerate()
58            .map(|(i, v)| Variant::new(&format!("treatment_{}", i), &v.qualified_name()))
59            .collect();
60
61        // Create experiment config
62        let experiment_config = ExperimentConfig {
63            name: config.name.clone(),
64            description: config.description.clone(),
65            control_variant,
66            treatment_variants,
67            traffic_percentage: config.traffic_percentage,
68            min_sample_size: config.min_sample_size,
69            max_duration_hours: config.max_duration_hours,
70        };
71
72        // Create the experiment
73        let experiment_id = self.ab_test_manager.create_experiment(experiment_config)?;
74
75        // Track versioned experiment
76        let versioned_experiment = VersionedExperiment {
77            experiment_id: experiment_id.clone(),
78            model_name: control_version.model_name().to_string(),
79            control_version_id: config.control_version_id,
80            treatment_version_ids: config.treatment_version_ids.clone(),
81            config: config.clone(),
82            started_at: Utc::now(),
83            status: VersionedExperimentStatus::Running,
84            metrics_collected: HashMap::new(),
85        };
86
87        {
88            let mut experiments = self.active_experiments.write().await;
89            experiments.insert(experiment_id.clone(), versioned_experiment);
90        }
91
92        tracing::info!(
93            "Created versioned A/B test: {} ({})",
94            config.name,
95            experiment_id
96        );
97        Ok(experiment_id)
98    }
99
100    /// Route a request to the appropriate model version
101    pub async fn route_request(
102        &self,
103        experiment_id: &str,
104        user_id: &str,
105    ) -> Result<ModelRoutingResult> {
106        // Get the variant from A/B test manager
107        let variant = self.ab_test_manager.route_request(experiment_id, user_id)?;
108
109        // Get the versioned experiment
110        let experiments = self.active_experiments.read().await;
111        let versioned_experiment = experiments
112            .get(experiment_id)
113            .ok_or_else(|| anyhow::anyhow!("Versioned experiment not found"))?;
114
115        // Map variant to version ID
116        let version_id = if variant.name() == "control" {
117            versioned_experiment.control_version_id
118        } else {
119            // Parse treatment index from variant name
120            let treatment_index = variant
121                .name()
122                .strip_prefix("treatment_")
123                .and_then(|s| s.parse::<usize>().ok())
124                .ok_or_else(|| anyhow::anyhow!("Invalid treatment variant name"))?;
125
126            *versioned_experiment
127                .treatment_version_ids
128                .get(treatment_index)
129                .ok_or_else(|| anyhow::anyhow!("Treatment index out of bounds"))?
130        };
131
132        // Get the actual model version
133        let model_version = self
134            .version_manager
135            .get_version(version_id)
136            .await?
137            .ok_or_else(|| anyhow::anyhow!("Model version not found"))?;
138
139        Ok(ModelRoutingResult {
140            experiment_id: experiment_id.to_string(),
141            variant: variant.clone(),
142            version_id,
143            model_version,
144            user_id: user_id.to_string(),
145        })
146    }
147
148    /// Record metrics for a versioned experiment
149    pub async fn record_version_metric(
150        &self,
151        experiment_id: &str,
152        user_id: &str,
153        metric_type: VersionMetricType,
154        value: f64,
155        metadata: Option<HashMap<String, String>>,
156    ) -> Result<()> {
157        // Get routing information to determine variant
158        let routing_result = self.route_request(experiment_id, user_id).await?;
159
160        // Convert to A/B test metric
161        let ab_metric_type = match &metric_type {
162            VersionMetricType::Latency => MetricType::Latency,
163            VersionMetricType::Accuracy => MetricType::Accuracy,
164            VersionMetricType::Throughput => MetricType::Throughput,
165            VersionMetricType::ErrorRate => MetricType::ErrorRate,
166            VersionMetricType::MemoryUsage => MetricType::MemoryUsage,
167            VersionMetricType::CustomMetric(name) => MetricType::Custom(name.clone()),
168        };
169
170        let ab_metric_value = match metric_type {
171            VersionMetricType::Latency => MetricValue::Duration(value as u64),
172            _ => MetricValue::Numeric(value),
173        };
174
175        // Record in A/B test manager
176        self.ab_test_manager.record_metric(
177            experiment_id,
178            &routing_result.variant,
179            ab_metric_type,
180            ab_metric_value,
181        )?;
182
183        // Update versioned experiment metrics
184        {
185            let mut experiments = self.active_experiments.write().await;
186            if let Some(experiment) = experiments.get_mut(experiment_id) {
187                let metric_key = format!("{}:{}", routing_result.variant.name(), metric_type);
188                experiment.metrics_collected.entry(metric_key).or_default().push(
189                    VersionMetricRecord {
190                        value,
191                        timestamp: Utc::now(),
192                        user_id: user_id.to_string(),
193                        metadata: metadata.unwrap_or_default(),
194                    },
195                );
196            }
197        }
198
199        Ok(())
200    }
201
202    /// Analyze experiment results with version context
203    pub async fn analyze_version_experiment(
204        &self,
205        experiment_id: &str,
206    ) -> Result<VersionExperimentResult> {
207        // Get A/B test results
208        let ab_result = self.ab_test_manager.analyze_experiment(experiment_id)?;
209
210        // Get versioned experiment
211        let experiments = self.active_experiments.read().await;
212        let versioned_experiment = experiments
213            .get(experiment_id)
214            .ok_or_else(|| anyhow::anyhow!("Versioned experiment not found"))?;
215
216        // Get version details
217        let control_version = self
218            .version_manager
219            .get_version(versioned_experiment.control_version_id)
220            .await?
221            .ok_or_else(|| anyhow::anyhow!("Control version not found"))?;
222
223        let mut treatment_versions = Vec::new();
224        for &version_id in &versioned_experiment.treatment_version_ids {
225            let version = self
226                .version_manager
227                .get_version(version_id)
228                .await?
229                .ok_or_else(|| anyhow::anyhow!("Treatment version not found"))?;
230            treatment_versions.push(version);
231        }
232
233        // Create version-enhanced result
234        Ok(VersionExperimentResult {
235            experiment_id: experiment_id.to_string(),
236            model_name: versioned_experiment.model_name.clone(),
237            control_version,
238            treatment_versions,
239            ab_test_result: ab_result,
240            experiment_duration: Utc::now() - versioned_experiment.started_at,
241            total_requests: versioned_experiment
242                .metrics_collected
243                .values()
244                .map(|records| records.len())
245                .sum(),
246            version_performance_comparison: self
247                .compare_version_performance(versioned_experiment)
248                .await?,
249        })
250    }
251
252    /// Promote winning version based on experiment results
253    pub async fn promote_winning_version(&self, experiment_id: &str) -> Result<PromotionResult> {
254        let result = self.analyze_version_experiment(experiment_id).await?;
255
256        // Determine winner based on A/B test recommendation
257        let winning_version_id = match &result.ab_test_result.recommendation {
258            crate::ab_testing::TestRecommendation::AdoptTreatment { variant, .. } => {
259                // Find the treatment version ID
260                let treatment_index = variant
261                    .strip_prefix("treatment_")
262                    .and_then(|s| s.parse::<usize>().ok())
263                    .ok_or_else(|| anyhow::anyhow!("Invalid treatment variant name"))?;
264
265                result
266                    .treatment_versions
267                    .get(treatment_index)
268                    .map(|v| v.id())
269                    .ok_or_else(|| anyhow::anyhow!("Treatment index out of bounds"))?
270            },
271            crate::ab_testing::TestRecommendation::KeepControl { .. } => {
272                result.control_version.id()
273            },
274            _ => {
275                return Ok(PromotionResult {
276                    promoted: false,
277                    version_id: None,
278                    reason: "No clear winner determined".to_string(),
279                });
280            },
281        };
282
283        // Promote the winning version to production
284        self.version_manager.promote_to_production(winning_version_id).await?;
285
286        // Mark experiment as concluded
287        {
288            let mut experiments = self.active_experiments.write().await;
289            if let Some(experiment) = experiments.get_mut(experiment_id) {
290                experiment.status = VersionedExperimentStatus::Concluded;
291            }
292        }
293
294        Ok(PromotionResult {
295            promoted: true,
296            version_id: Some(winning_version_id),
297            reason: "Version promoted based on A/B test results".to_string(),
298        })
299    }
300
301    /// List active versioned experiments
302    pub async fn list_experiments(&self) -> Result<Vec<VersionedExperiment>> {
303        let experiments = self.active_experiments.read().await;
304        Ok(experiments.values().cloned().collect())
305    }
306
307    /// Stop a versioned experiment
308    pub async fn stop_experiment(&self, experiment_id: &str) -> Result<()> {
309        let mut experiments = self.active_experiments.write().await;
310        if let Some(experiment) = experiments.get_mut(experiment_id) {
311            experiment.status = VersionedExperimentStatus::Stopped;
312        }
313        Ok(())
314    }
315
316    // Helper methods
317
318    async fn compare_version_performance(
319        &self,
320        experiment: &VersionedExperiment,
321    ) -> Result<HashMap<String, VersionPerformanceMetrics>> {
322        let mut comparison = HashMap::new();
323
324        for (metric_key, records) in &experiment.metrics_collected {
325            let parts: Vec<&str> = metric_key.split(':').collect();
326            if parts.len() == 2 {
327                let variant_name = parts[0];
328                let metric_type = parts[1];
329
330                let values: Vec<f64> = records.iter().map(|r| r.value).collect();
331
332                if !values.is_empty() {
333                    let mean = values.iter().sum::<f64>() / values.len() as f64;
334                    let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>()
335                        / values.len() as f64;
336                    let std_dev = variance.sqrt();
337                    let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
338                    let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
339
340                    let metrics = VersionPerformanceMetrics {
341                        metric_type: metric_type.to_string(),
342                        sample_count: values.len(),
343                        mean,
344                        std_dev,
345                        min,
346                        max,
347                        p95: calculate_percentile(&values, 0.95),
348                        p99: calculate_percentile(&values, 0.99),
349                    };
350
351                    comparison.insert(format!("{}:{}", variant_name, metric_type), metrics);
352                }
353            }
354        }
355
356        Ok(comparison)
357    }
358}
359
/// Calculate a percentile from a list of values.
///
/// The values are sorted ascending and the element at index
/// `floor(percentile * (len - 1))` is returned. There is no interpolation: `0.0` yields
/// the minimum and `1.0` the maximum. `percentile` is clamped to `[0.0, 1.0]`, so
/// out-of-range inputs cannot index past the end.
///
/// Returns `f64::NAN` for an empty slice instead of panicking. NaN inputs are ordered
/// with [`f64::total_cmp`] rather than aborting the sort.
fn calculate_percentile(values: &[f64], percentile: f64) -> f64 {
    if values.is_empty() {
        return f64::NAN;
    }
    let mut sorted = values.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);
    // `as usize` floors a non-negative rank (and maps a NaN percentile to 0).
    let rank = percentile.clamp(0.0, 1.0) * (sorted.len() - 1) as f64;
    sorted[rank as usize]
}
367
/// Configuration for version-based A/B test experiment.
///
/// Consumed by `VersionedABTestManager::create_version_experiment`. The scheduling
/// fields are forwarded unchanged to the underlying A/B test `ExperimentConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionExperimentConfig {
    /// Human-readable experiment name; also used in the creation log line.
    pub name: String,
    /// Free-form description of the experiment.
    pub description: String,
    /// Version served by the `"control"` variant.
    pub control_version_id: Uuid,
    /// Versions served by the `"treatment_<i>"` variants, where `i` is the index here.
    pub treatment_version_ids: Vec<Uuid>,
    /// Forwarded as-is to the A/B test manager. Presumably a percentage in 0–100 (the
    /// tests use 50.0 and 100.0); confirm against `ABTestManager`.
    pub traffic_percentage: f64,
    /// Forwarded as-is to the A/B test manager; its enforcement is not visible here.
    pub min_sample_size: usize,
    /// Forwarded as-is to the A/B test manager; its enforcement is not visible here.
    pub max_duration_hours: u64,
}
379
/// Status of a versioned experiment, as tracked by `VersionedABTestManager`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum VersionedExperimentStatus {
    /// Set when the experiment is created.
    Running,
    /// Set by `stop_experiment`.
    Stopped,
    /// Set after `promote_winning_version` successfully promotes a version.
    Concluded,
    /// Not assigned anywhere in this module.
    Failed,
}
388
/// A versioned A/B test experiment: the version-level view of an experiment
/// registered with the underlying A/B test manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionedExperiment {
    /// ID returned by the underlying A/B test manager.
    pub experiment_id: String,
    /// Model name of the control version.
    pub model_name: String,
    /// Version mapped to the `"control"` variant.
    pub control_version_id: Uuid,
    /// Versions mapped to `"treatment_<i>"`, indexed by position.
    pub treatment_version_ids: Vec<Uuid>,
    /// The configuration the experiment was created from.
    pub config: VersionExperimentConfig,
    /// Creation time; used to compute the experiment duration during analysis.
    pub started_at: DateTime<Utc>,
    /// Current lifecycle status.
    pub status: VersionedExperimentStatus,
    /// Recorded samples keyed by `"<variant name>:<metric display name>"`.
    pub metrics_collected: HashMap<String, Vec<VersionMetricRecord>>,
}
401
/// Metric types for version experiments.
///
/// Each variant maps onto the corresponding A/B test `MetricType`.
/// `Latency` values are forwarded as a `MetricValue::Duration`; every other kind is
/// forwarded as `MetricValue::Numeric`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VersionMetricType {
    Latency,
    Accuracy,
    Throughput,
    ErrorRate,
    MemoryUsage,
    /// User-defined metric; the string is used as its display name.
    CustomMetric(String),
}
412
413impl std::fmt::Display for VersionMetricType {
414    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415        match self {
416            VersionMetricType::Latency => write!(f, "latency"),
417            VersionMetricType::Accuracy => write!(f, "accuracy"),
418            VersionMetricType::Throughput => write!(f, "throughput"),
419            VersionMetricType::ErrorRate => write!(f, "error_rate"),
420            VersionMetricType::MemoryUsage => write!(f, "memory_usage"),
421            VersionMetricType::CustomMetric(name) => write!(f, "{}", name),
422        }
423    }
424}
425
/// Metric record for version experiments: one recorded sample.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionMetricRecord {
    /// The raw value passed to `record_version_metric`.
    pub value: f64,
    /// Time at which the sample was stored (`Utc::now()` at recording).
    pub timestamp: DateTime<Utc>,
    /// User the sample was recorded for.
    pub user_id: String,
    /// Caller-supplied metadata; empty when none was given.
    pub metadata: HashMap<String, String>,
}
434
/// Result of routing a request to a model version.
#[derive(Debug, Clone)]
pub struct ModelRoutingResult {
    pub experiment_id: String,
    /// Variant the A/B test manager assigned to the user.
    pub variant: Variant,
    /// Version ID the variant maps to.
    pub version_id: Uuid,
    /// The loaded model version for `version_id`.
    pub model_version: VersionedModel,
    pub user_id: String,
}
444
/// Result of analyzing a version experiment.
#[derive(Debug, Clone)]
pub struct VersionExperimentResult {
    pub experiment_id: String,
    /// Model name of the control version.
    pub model_name: String,
    pub control_version: VersionedModel,
    /// Treatment versions in `"treatment_<i>"` index order.
    pub treatment_versions: Vec<VersionedModel>,
    /// Raw analysis from the underlying A/B test manager.
    pub ab_test_result: crate::ab_testing::TestResult,
    /// Time elapsed between experiment creation and this analysis.
    pub experiment_duration: chrono::Duration,
    /// Total number of recorded metric samples across all keys. A request that
    /// recorded several metrics is counted once per metric.
    pub total_requests: usize,
    /// Summary statistics keyed by `"<variant>:<metric>"`.
    pub version_performance_comparison: HashMap<String, VersionPerformanceMetrics>,
}
457
/// Performance metrics for a version: summary statistics of one
/// `"<variant>:<metric>"` sample series.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionPerformanceMetrics {
    /// Display name of the metric (the part after the variant in the key).
    pub metric_type: String,
    pub sample_count: usize,
    pub mean: f64,
    /// Population standard deviation (divides by `sample_count`).
    pub std_dev: f64,
    pub min: f64,
    pub max: f64,
    /// 95th percentile computed by `calculate_percentile` (no interpolation).
    pub p95: f64,
    /// 99th percentile computed by `calculate_percentile` (no interpolation).
    pub p99: f64,
}
470
/// Result of promoting a version.
#[derive(Debug, Clone)]
pub struct PromotionResult {
    /// Whether a version was promoted to production.
    pub promoted: bool,
    /// The promoted version; `None` when nothing was promoted.
    pub version_id: Option<Uuid>,
    /// Human-readable explanation of the outcome.
    pub reason: String,
}
478
#[cfg(test)]
mod tests {
    use super::*;
    use crate::versioning::{
        metadata::{ModelMetadata, ModelTag},
        storage::InMemoryStorage,
    };

    /// Version manager backed by a fresh in-memory store.
    async fn create_test_version_manager() -> Arc<ModelVersionManager> {
        Arc::new(ModelVersionManager::new(Arc::new(InMemoryStorage::new())))
    }

    /// Register `name`@`version` with minimal test metadata and return its ID.
    async fn create_test_version(manager: &ModelVersionManager, name: &str, version: &str) -> Uuid {
        let description = format!("Test model {}", name);
        let metadata = ModelMetadata::builder()
            .description(description)
            .created_by("test_user".to_string())
            .model_type("transformer".to_string())
            .tag(ModelTag::new("test"))
            .build();

        let registered = manager.register_version(name, version, metadata, vec![]).await;
        registered.expect("async operation failed")
    }

    /// State shared by every test once an experiment is running.
    struct Fixture {
        manager: VersionedABTestManager,
        control_id: Uuid,
        treatment_id: Uuid,
        experiment_id: String,
    }

    /// Register `test_model` 1.0.0 (control) and 1.1.0 (treatment), then start an
    /// experiment between them with the given settings.
    async fn start_experiment(
        name: &str,
        description: &str,
        traffic_percentage: f64,
        min_sample_size: usize,
        max_duration_hours: u64,
    ) -> Fixture {
        let versions = create_test_version_manager().await;
        let manager = VersionedABTestManager::new(Arc::clone(&versions));

        let control_id = create_test_version(&versions, "test_model", "1.0.0").await;
        let treatment_id = create_test_version(&versions, "test_model", "1.1.0").await;

        let experiment_id = manager
            .create_version_experiment(VersionExperimentConfig {
                name: name.to_string(),
                description: description.to_string(),
                control_version_id: control_id,
                treatment_version_ids: vec![treatment_id],
                traffic_percentage,
                min_sample_size,
                max_duration_hours,
            })
            .await
            .expect("async operation failed");

        Fixture { manager, control_id, treatment_id, experiment_id }
    }

    #[tokio::test]
    async fn test_version_experiment_creation() {
        let Fixture { manager, experiment_id, .. } =
            start_experiment("Model v1.1 Test", "Testing improved model", 50.0, 100, 24).await;
        assert!(!experiment_id.is_empty());

        // The new experiment must be the only one tracked.
        let tracked = manager.list_experiments().await.expect("async operation failed");
        assert_eq!(tracked.len(), 1);
        assert_eq!(tracked[0].experiment_id, experiment_id);
    }

    #[tokio::test]
    async fn test_request_routing() {
        let fixture = start_experiment("Routing Test", "Test routing", 100.0, 10, 1).await;

        let routed = fixture
            .manager
            .route_request(&fixture.experiment_id, "test_user")
            .await
            .expect("async operation failed");

        assert_eq!(routed.experiment_id, fixture.experiment_id);
        assert_eq!(routed.user_id, "test_user");
        let served_known_version =
            [fixture.control_id, fixture.treatment_id].contains(&routed.version_id);
        assert!(served_known_version);
    }

    #[tokio::test]
    async fn test_metric_recording() {
        let Fixture { manager, experiment_id, .. } =
            start_experiment("Metrics Test", "Test metrics", 100.0, 10, 1).await;

        manager
            .record_version_metric(
                &experiment_id,
                "test_user",
                VersionMetricType::Latency,
                120.0,
                None,
            )
            .await
            .expect("operation failed in test");

        // The sample must show up in the tracked experiment.
        let tracked = manager.list_experiments().await.expect("async operation failed");
        assert!(!tracked[0].metrics_collected.is_empty());
    }
}
609        let experiments = ab_manager.list_experiments().await.expect("async operation failed");
610        let experiment = &experiments[0];
611        assert!(!experiment.metrics_collected.is_empty());
612    }
613}