Skip to main content

trustformers_core/ab_testing/
deployment.rs

1//! Deployment strategies and rollout control
2
3use super::Variant;
4use anyhow::Result;
5use chrono::{DateTime, Duration, Utc};
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11/// Deployment strategy for rolling out winning variants
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub enum DeploymentStrategy {
14    /// Immediate full deployment
15    Immediate,
16    /// Gradual percentage-based rollout
17    Gradual {
18        /// Initial percentage
19        initial_percentage: f64,
20        /// Increment per step
21        increment: f64,
22        /// Time between increments
23        increment_interval: Duration,
24    },
25    /// Canary deployment
26    Canary {
27        /// Canary percentage
28        canary_percentage: f64,
29        /// Monitoring duration
30        monitoring_duration: Duration,
31    },
32    /// Blue-green deployment
33    BlueGreen {
34        /// Warm-up duration
35        warmup_duration: Duration,
36    },
37    /// Feature flag based
38    FeatureFlag {
39        /// Flag name
40        flag_name: String,
41    },
42}
43
44/// Rollout status
45#[derive(Debug, Clone, PartialEq)]
46pub enum RolloutStatus {
47    /// Not started
48    NotStarted,
49    /// In progress
50    InProgress {
51        /// Current percentage
52        current_percentage: f64,
53        /// Start time
54        started_at: DateTime<Utc>,
55    },
56    /// Paused
57    Paused {
58        /// Percentage when paused
59        paused_at_percentage: f64,
60    },
61    /// Completed
62    Completed {
63        /// Completion time
64        completed_at: DateTime<Utc>,
65    },
66    /// Rolled back
67    RolledBack {
68        /// Rollback time
69        rolled_back_at: DateTime<Utc>,
70        /// Reason for rollback
71        reason: String,
72    },
73}
74
75/// Rollout configuration
76#[derive(Debug, Clone)]
77pub struct RolloutConfig {
78    /// Experiment ID
79    pub experiment_id: String,
80    /// Winning variant
81    pub variant: Variant,
82    /// Deployment strategy
83    pub strategy: DeploymentStrategy,
84    /// Health checks
85    pub health_checks: Vec<HealthCheck>,
86    /// Rollback conditions
87    pub rollback_conditions: Vec<RollbackCondition>,
88}
89
90/// Health check configuration
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct HealthCheck {
93    /// Check name
94    pub name: String,
95    /// Check type
96    pub check_type: HealthCheckType,
97    /// Threshold for failure
98    pub threshold: f64,
99    /// Number of consecutive failures to trigger
100    pub consecutive_failures: u32,
101}
102
103/// Types of health checks
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub enum HealthCheckType {
106    /// Error rate check
107    ErrorRate,
108    /// Latency check (p99)
109    LatencyP99,
110    /// CPU usage
111    CpuUsage,
112    /// Memory usage
113    MemoryUsage,
114    /// Custom metric
115    Custom(String),
116}
117
118/// Conditions that trigger automatic rollback
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct RollbackCondition {
121    /// Condition name
122    pub name: String,
123    /// Metric to monitor
124    pub metric: String,
125    /// Comparison operator
126    pub operator: ComparisonOperator,
127    /// Threshold value
128    pub threshold: f64,
129    /// Duration condition must be true
130    pub duration: Duration,
131}
132
133/// Comparison operators for rollback conditions
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub enum ComparisonOperator {
136    GreaterThan,
137    LessThan,
138    GreaterThanOrEqual,
139    LessThanOrEqual,
140}
141
142/// Rollout controller
143pub struct RolloutController {
144    /// Active rollouts
145    rollouts: Arc<RwLock<HashMap<String, ActiveRollout>>>,
146    /// Health monitor
147    health_monitor: Arc<HealthMonitor>,
148}
149
150/// Active rollout state
151struct ActiveRollout {
152    /// Configuration
153    config: RolloutConfig,
154    /// Current status
155    status: RolloutStatus,
156    /// Health check failures
157    health_failures: HashMap<String, u32>,
158    /// Rollback condition states
159    rollback_states: HashMap<String, RollbackState>,
160}
161
162/// State for tracking rollback conditions
163struct RollbackState {
164    /// When condition was first met
165    first_triggered: Option<DateTime<Utc>>,
166    /// Current value
167    current_value: f64,
168}
169
170/// Health monitoring service
171struct HealthMonitor {
172    /// Metrics storage
173    metrics: Arc<RwLock<HashMap<String, f64>>>,
174}
175
176impl Default for RolloutController {
177    fn default() -> Self {
178        Self::new()
179    }
180}
181
182impl RolloutController {
183    /// Create a new rollout controller
184    pub fn new() -> Self {
185        Self {
186            rollouts: Arc::new(RwLock::new(HashMap::new())),
187            health_monitor: Arc::new(HealthMonitor::new()),
188        }
189    }
190
191    /// Start a new rollout
192    pub fn start_rollout(&self, config: RolloutConfig) -> Result<()> {
193        let experiment_id = config.experiment_id.clone();
194
195        let status = match &config.strategy {
196            DeploymentStrategy::Immediate => RolloutStatus::InProgress {
197                current_percentage: 100.0,
198                started_at: Utc::now(),
199            },
200            DeploymentStrategy::Gradual {
201                initial_percentage, ..
202            } => RolloutStatus::InProgress {
203                current_percentage: *initial_percentage,
204                started_at: Utc::now(),
205            },
206            DeploymentStrategy::Canary {
207                canary_percentage, ..
208            } => RolloutStatus::InProgress {
209                current_percentage: *canary_percentage,
210                started_at: Utc::now(),
211            },
212            DeploymentStrategy::BlueGreen { .. } => RolloutStatus::InProgress {
213                current_percentage: 0.0,
214                started_at: Utc::now(),
215            },
216            DeploymentStrategy::FeatureFlag { .. } => RolloutStatus::InProgress {
217                current_percentage: 0.0,
218                started_at: Utc::now(),
219            },
220        };
221
222        let rollout = ActiveRollout {
223            config,
224            status,
225            health_failures: HashMap::new(),
226            rollback_states: HashMap::new(),
227        };
228
229        self.rollouts.write().insert(experiment_id, rollout);
230        Ok(())
231    }
232
233    /// Update rollout progress
234    pub fn update_rollout(&self, experiment_id: &str) -> Result<()> {
235        let mut rollouts = self.rollouts.write();
236        let rollout = rollouts
237            .get_mut(experiment_id)
238            .ok_or_else(|| anyhow::anyhow!("Rollout not found"))?;
239
240        // Check health and rollback conditions
241        if self.should_rollback(rollout)? {
242            rollout.status = RolloutStatus::RolledBack {
243                rolled_back_at: Utc::now(),
244                reason: "Health check or rollback condition triggered".to_string(),
245            };
246            return Ok(());
247        }
248
249        // Update based on strategy
250        match &rollout.config.strategy {
251            DeploymentStrategy::Gradual {
252                increment,
253                increment_interval,
254                ..
255            } => {
256                if let RolloutStatus::InProgress {
257                    current_percentage,
258                    started_at,
259                } = &rollout.status
260                {
261                    let elapsed = Utc::now() - *started_at;
262                    let steps = (elapsed.num_seconds() / increment_interval.num_seconds()) as f64;
263                    let new_percentage = (current_percentage + steps * increment).min(100.0);
264
265                    if new_percentage >= 100.0 {
266                        rollout.status = RolloutStatus::Completed {
267                            completed_at: Utc::now(),
268                        };
269                    } else {
270                        rollout.status = RolloutStatus::InProgress {
271                            current_percentage: new_percentage,
272                            started_at: *started_at,
273                        };
274                    }
275                }
276            },
277            DeploymentStrategy::Canary {
278                monitoring_duration,
279                ..
280            } => {
281                if let RolloutStatus::InProgress { started_at, .. } = &rollout.status {
282                    if Utc::now() - *started_at > *monitoring_duration {
283                        rollout.status = RolloutStatus::Completed {
284                            completed_at: Utc::now(),
285                        };
286                    }
287                }
288            },
289            DeploymentStrategy::BlueGreen { warmup_duration } => {
290                if let RolloutStatus::InProgress { started_at, .. } = &rollout.status {
291                    if Utc::now() - *started_at > *warmup_duration {
292                        rollout.status = RolloutStatus::InProgress {
293                            current_percentage: 100.0,
294                            started_at: *started_at,
295                        };
296                    }
297                }
298            },
299            _ => {},
300        }
301
302        Ok(())
303    }
304
305    /// Check if rollback is needed
306    fn should_rollback(&self, rollout: &mut ActiveRollout) -> Result<bool> {
307        // Check health checks
308        for health_check in &rollout.config.health_checks {
309            let metric_value = self.health_monitor.get_metric(&health_check.name)?;
310
311            let failed = match health_check.check_type {
312                HealthCheckType::ErrorRate => metric_value > health_check.threshold,
313                HealthCheckType::LatencyP99 => metric_value > health_check.threshold,
314                HealthCheckType::CpuUsage => metric_value > health_check.threshold,
315                HealthCheckType::MemoryUsage => metric_value > health_check.threshold,
316                HealthCheckType::Custom(_) => false, // Would need custom logic
317            };
318
319            if failed {
320                let failures =
321                    rollout.health_failures.entry(health_check.name.clone()).or_insert(0);
322                *failures += 1;
323
324                if *failures >= health_check.consecutive_failures {
325                    return Ok(true);
326                }
327            } else {
328                rollout.health_failures.remove(&health_check.name);
329            }
330        }
331
332        // Check rollback conditions
333        for condition in &rollout.config.rollback_conditions {
334            let metric_value = self.health_monitor.get_metric(&condition.metric)?;
335
336            let triggered = match condition.operator {
337                ComparisonOperator::GreaterThan => metric_value > condition.threshold,
338                ComparisonOperator::LessThan => metric_value < condition.threshold,
339                ComparisonOperator::GreaterThanOrEqual => metric_value >= condition.threshold,
340                ComparisonOperator::LessThanOrEqual => metric_value <= condition.threshold,
341            };
342
343            let state =
344                rollout.rollback_states.entry(condition.name.clone()).or_insert(RollbackState {
345                    first_triggered: None,
346                    current_value: metric_value,
347                });
348
349            state.current_value = metric_value;
350
351            if !triggered {
352                state.first_triggered = None;
353                continue;
354            }
355
356            if state.first_triggered.is_none() {
357                state.first_triggered = Some(Utc::now());
358                continue;
359            }
360
361            if let Some(first_triggered) = state.first_triggered {
362                if Utc::now() - first_triggered >= condition.duration {
363                    return Ok(true);
364                }
365            }
366        }
367
368        Ok(false)
369    }
370
371    /// Promote a variant (start rollout)
372    pub fn promote(&self, experiment_id: &str, variant: &Variant) -> Result<()> {
373        let config = RolloutConfig {
374            experiment_id: experiment_id.to_string(),
375            variant: variant.clone(),
376            strategy: DeploymentStrategy::Gradual {
377                initial_percentage: 10.0,
378                increment: 10.0,
379                increment_interval: Duration::hours(1),
380            },
381            health_checks: vec![
382                HealthCheck {
383                    name: "error_rate".to_string(),
384                    check_type: HealthCheckType::ErrorRate,
385                    threshold: 0.05,
386                    consecutive_failures: 3,
387                },
388                HealthCheck {
389                    name: "latency_p99".to_string(),
390                    check_type: HealthCheckType::LatencyP99,
391                    threshold: 1000.0,
392                    consecutive_failures: 3,
393                },
394            ],
395            rollback_conditions: vec![RollbackCondition {
396                name: "sustained_errors".to_string(),
397                metric: "error_rate".to_string(),
398                operator: ComparisonOperator::GreaterThan,
399                threshold: 0.1,
400                duration: Duration::minutes(5),
401            }],
402        };
403
404        self.start_rollout(config)
405    }
406
407    /// Rollback to control
408    pub fn rollback(&self, experiment_id: &str) -> Result<()> {
409        let mut rollouts = self.rollouts.write();
410        if let Some(rollout) = rollouts.get_mut(experiment_id) {
411            rollout.status = RolloutStatus::RolledBack {
412                rolled_back_at: Utc::now(),
413                reason: "Manual rollback".to_string(),
414            };
415        }
416        Ok(())
417    }
418
419    /// Get rollout status
420    pub fn get_status(&self, experiment_id: &str) -> Result<RolloutStatus> {
421        let rollouts = self.rollouts.read();
422        let rollout = rollouts
423            .get(experiment_id)
424            .ok_or_else(|| anyhow::anyhow!("Rollout not found"))?;
425        Ok(rollout.status.clone())
426    }
427
428    /// Pause rollout
429    pub fn pause(&self, experiment_id: &str) -> Result<()> {
430        let mut rollouts = self.rollouts.write();
431        if let Some(rollout) = rollouts.get_mut(experiment_id) {
432            if let RolloutStatus::InProgress {
433                current_percentage, ..
434            } = rollout.status
435            {
436                rollout.status = RolloutStatus::Paused {
437                    paused_at_percentage: current_percentage,
438                };
439            }
440        }
441        Ok(())
442    }
443
444    /// Resume rollout
445    pub fn resume(&self, experiment_id: &str) -> Result<()> {
446        let mut rollouts = self.rollouts.write();
447        if let Some(rollout) = rollouts.get_mut(experiment_id) {
448            if let RolloutStatus::Paused {
449                paused_at_percentage,
450            } = rollout.status
451            {
452                rollout.status = RolloutStatus::InProgress {
453                    current_percentage: paused_at_percentage,
454                    started_at: Utc::now(),
455                };
456            }
457        }
458        Ok(())
459    }
460}
461
462impl HealthMonitor {
463    fn new() -> Self {
464        Self {
465            metrics: Arc::new(RwLock::new(HashMap::new())),
466        }
467    }
468
469    fn get_metric(&self, name: &str) -> Result<f64> {
470        let metrics = self.metrics.read();
471        Ok(metrics.get(name).copied().unwrap_or(0.0))
472    }
473
474    #[allow(dead_code)]
475    pub fn update_metric(&self, name: &str, value: f64) {
476        self.metrics.write().insert(name.to_string(), value);
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a config for the test variant with no health checks and no
    /// rollback conditions.
    fn config_for(experiment_id: &str, strategy: DeploymentStrategy) -> RolloutConfig {
        RolloutConfig {
            experiment_id: experiment_id.to_string(),
            variant: Variant::new("winner", "model-v2"),
            strategy,
            health_checks: vec![],
            rollback_conditions: vec![],
        }
    }

    /// Return the current traffic share, panicking unless the rollout is in
    /// progress.
    fn in_progress_percentage(controller: &RolloutController, experiment_id: &str) -> f64 {
        let status = controller.get_status(experiment_id).expect("operation failed in test");
        if let RolloutStatus::InProgress {
            current_percentage, ..
        } = status
        {
            current_percentage
        } else {
            panic!("Expected InProgress status")
        }
    }

    /// An immediate rollout starts at 100% and is reported as in progress.
    #[test]
    fn test_immediate_deployment() {
        let controller = RolloutController::new();

        controller
            .start_rollout(config_for("exp1", DeploymentStrategy::Immediate))
            .expect("operation failed in test");

        assert_eq!(in_progress_percentage(&controller, "exp1"), 100.0);
    }

    /// A gradual rollout starts at its initial percentage and grows once
    /// increment intervals have elapsed.
    #[test]
    fn test_gradual_rollout() {
        let controller = RolloutController::new();
        let strategy = DeploymentStrategy::Gradual {
            initial_percentage: 10.0,
            increment: 20.0,
            increment_interval: Duration::seconds(1),
        };

        controller
            .start_rollout(config_for("exp2", strategy))
            .expect("operation failed in test");

        // Initial percentage
        assert_eq!(in_progress_percentage(&controller, "exp2"), 10.0);

        // Wait and update
        std::thread::sleep(std::time::Duration::from_secs(2));
        controller.update_rollout("exp2").expect("operation failed in test");

        // Should have increased
        assert!(in_progress_percentage(&controller, "exp2") > 10.0);
    }

    /// A failing health check with `consecutive_failures: 1` rolls the
    /// rollout back on the first update.
    #[test]
    fn test_health_check_rollback() {
        let controller = RolloutController::new();
        let strategy = DeploymentStrategy::Canary {
            canary_percentage: 5.0,
            monitoring_duration: Duration::hours(1),
        };
        let config = RolloutConfig {
            health_checks: vec![HealthCheck {
                name: "error_rate".to_string(),
                check_type: HealthCheckType::ErrorRate,
                threshold: 0.05,
                consecutive_failures: 1,
            }],
            ..config_for("exp3", strategy)
        };

        controller.start_rollout(config).expect("operation failed in test");

        // Simulate high error rate
        controller.health_monitor.update_metric("error_rate", 0.1);
        controller.update_rollout("exp3").expect("operation failed in test");

        // Should be rolled back
        let status = controller.get_status("exp3").expect("operation failed in test");
        if let RolloutStatus::RolledBack { reason, .. } = status {
            assert!(reason.contains("Health check"));
        } else {
            panic!("Expected RolledBack status");
        }
    }
}