1use super::Variant;
4use anyhow::Result;
5use chrono::{DateTime, Duration, Utc};
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub enum DeploymentStrategy {
14 Immediate,
16 Gradual {
18 initial_percentage: f64,
20 increment: f64,
22 increment_interval: Duration,
24 },
25 Canary {
27 canary_percentage: f64,
29 monitoring_duration: Duration,
31 },
32 BlueGreen {
34 warmup_duration: Duration,
36 },
37 FeatureFlag {
39 flag_name: String,
41 },
42}
43
44#[derive(Debug, Clone, PartialEq)]
46pub enum RolloutStatus {
47 NotStarted,
49 InProgress {
51 current_percentage: f64,
53 started_at: DateTime<Utc>,
55 },
56 Paused {
58 paused_at_percentage: f64,
60 },
61 Completed {
63 completed_at: DateTime<Utc>,
65 },
66 RolledBack {
68 rolled_back_at: DateTime<Utc>,
70 reason: String,
72 },
73}
74
75#[derive(Debug, Clone)]
77pub struct RolloutConfig {
78 pub experiment_id: String,
80 pub variant: Variant,
82 pub strategy: DeploymentStrategy,
84 pub health_checks: Vec<HealthCheck>,
86 pub rollback_conditions: Vec<RollbackCondition>,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct HealthCheck {
93 pub name: String,
95 pub check_type: HealthCheckType,
97 pub threshold: f64,
99 pub consecutive_failures: u32,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
105pub enum HealthCheckType {
106 ErrorRate,
108 LatencyP99,
110 CpuUsage,
112 MemoryUsage,
114 Custom(String),
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct RollbackCondition {
121 pub name: String,
123 pub metric: String,
125 pub operator: ComparisonOperator,
127 pub threshold: f64,
129 pub duration: Duration,
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub enum ComparisonOperator {
136 GreaterThan,
137 LessThan,
138 GreaterThanOrEqual,
139 LessThanOrEqual,
140}
141
142pub struct RolloutController {
144 rollouts: Arc<RwLock<HashMap<String, ActiveRollout>>>,
146 health_monitor: Arc<HealthMonitor>,
148}
149
150struct ActiveRollout {
152 config: RolloutConfig,
154 status: RolloutStatus,
156 health_failures: HashMap<String, u32>,
158 rollback_states: HashMap<String, RollbackState>,
160}
161
162struct RollbackState {
164 first_triggered: Option<DateTime<Utc>>,
166 current_value: f64,
168}
169
170struct HealthMonitor {
172 metrics: Arc<RwLock<HashMap<String, f64>>>,
174}
175
176impl Default for RolloutController {
177 fn default() -> Self {
178 Self::new()
179 }
180}
181
182impl RolloutController {
183 pub fn new() -> Self {
185 Self {
186 rollouts: Arc::new(RwLock::new(HashMap::new())),
187 health_monitor: Arc::new(HealthMonitor::new()),
188 }
189 }
190
191 pub fn start_rollout(&self, config: RolloutConfig) -> Result<()> {
193 let experiment_id = config.experiment_id.clone();
194
195 let status = match &config.strategy {
196 DeploymentStrategy::Immediate => RolloutStatus::InProgress {
197 current_percentage: 100.0,
198 started_at: Utc::now(),
199 },
200 DeploymentStrategy::Gradual {
201 initial_percentage, ..
202 } => RolloutStatus::InProgress {
203 current_percentage: *initial_percentage,
204 started_at: Utc::now(),
205 },
206 DeploymentStrategy::Canary {
207 canary_percentage, ..
208 } => RolloutStatus::InProgress {
209 current_percentage: *canary_percentage,
210 started_at: Utc::now(),
211 },
212 DeploymentStrategy::BlueGreen { .. } => RolloutStatus::InProgress {
213 current_percentage: 0.0,
214 started_at: Utc::now(),
215 },
216 DeploymentStrategy::FeatureFlag { .. } => RolloutStatus::InProgress {
217 current_percentage: 0.0,
218 started_at: Utc::now(),
219 },
220 };
221
222 let rollout = ActiveRollout {
223 config,
224 status,
225 health_failures: HashMap::new(),
226 rollback_states: HashMap::new(),
227 };
228
229 self.rollouts.write().insert(experiment_id, rollout);
230 Ok(())
231 }
232
233 pub fn update_rollout(&self, experiment_id: &str) -> Result<()> {
235 let mut rollouts = self.rollouts.write();
236 let rollout = rollouts
237 .get_mut(experiment_id)
238 .ok_or_else(|| anyhow::anyhow!("Rollout not found"))?;
239
240 if self.should_rollback(rollout)? {
242 rollout.status = RolloutStatus::RolledBack {
243 rolled_back_at: Utc::now(),
244 reason: "Health check or rollback condition triggered".to_string(),
245 };
246 return Ok(());
247 }
248
249 match &rollout.config.strategy {
251 DeploymentStrategy::Gradual {
252 increment,
253 increment_interval,
254 ..
255 } => {
256 if let RolloutStatus::InProgress {
257 current_percentage,
258 started_at,
259 } = &rollout.status
260 {
261 let elapsed = Utc::now() - *started_at;
262 let steps = (elapsed.num_seconds() / increment_interval.num_seconds()) as f64;
263 let new_percentage = (current_percentage + steps * increment).min(100.0);
264
265 if new_percentage >= 100.0 {
266 rollout.status = RolloutStatus::Completed {
267 completed_at: Utc::now(),
268 };
269 } else {
270 rollout.status = RolloutStatus::InProgress {
271 current_percentage: new_percentage,
272 started_at: *started_at,
273 };
274 }
275 }
276 },
277 DeploymentStrategy::Canary {
278 monitoring_duration,
279 ..
280 } => {
281 if let RolloutStatus::InProgress { started_at, .. } = &rollout.status {
282 if Utc::now() - *started_at > *monitoring_duration {
283 rollout.status = RolloutStatus::Completed {
284 completed_at: Utc::now(),
285 };
286 }
287 }
288 },
289 DeploymentStrategy::BlueGreen { warmup_duration } => {
290 if let RolloutStatus::InProgress { started_at, .. } = &rollout.status {
291 if Utc::now() - *started_at > *warmup_duration {
292 rollout.status = RolloutStatus::InProgress {
293 current_percentage: 100.0,
294 started_at: *started_at,
295 };
296 }
297 }
298 },
299 _ => {},
300 }
301
302 Ok(())
303 }
304
305 fn should_rollback(&self, rollout: &mut ActiveRollout) -> Result<bool> {
307 for health_check in &rollout.config.health_checks {
309 let metric_value = self.health_monitor.get_metric(&health_check.name)?;
310
311 let failed = match health_check.check_type {
312 HealthCheckType::ErrorRate => metric_value > health_check.threshold,
313 HealthCheckType::LatencyP99 => metric_value > health_check.threshold,
314 HealthCheckType::CpuUsage => metric_value > health_check.threshold,
315 HealthCheckType::MemoryUsage => metric_value > health_check.threshold,
316 HealthCheckType::Custom(_) => false, };
318
319 if failed {
320 let failures =
321 rollout.health_failures.entry(health_check.name.clone()).or_insert(0);
322 *failures += 1;
323
324 if *failures >= health_check.consecutive_failures {
325 return Ok(true);
326 }
327 } else {
328 rollout.health_failures.remove(&health_check.name);
329 }
330 }
331
332 for condition in &rollout.config.rollback_conditions {
334 let metric_value = self.health_monitor.get_metric(&condition.metric)?;
335
336 let triggered = match condition.operator {
337 ComparisonOperator::GreaterThan => metric_value > condition.threshold,
338 ComparisonOperator::LessThan => metric_value < condition.threshold,
339 ComparisonOperator::GreaterThanOrEqual => metric_value >= condition.threshold,
340 ComparisonOperator::LessThanOrEqual => metric_value <= condition.threshold,
341 };
342
343 let state =
344 rollout.rollback_states.entry(condition.name.clone()).or_insert(RollbackState {
345 first_triggered: None,
346 current_value: metric_value,
347 });
348
349 state.current_value = metric_value;
350
351 if !triggered {
352 state.first_triggered = None;
353 continue;
354 }
355
356 if state.first_triggered.is_none() {
357 state.first_triggered = Some(Utc::now());
358 continue;
359 }
360
361 if let Some(first_triggered) = state.first_triggered {
362 if Utc::now() - first_triggered >= condition.duration {
363 return Ok(true);
364 }
365 }
366 }
367
368 Ok(false)
369 }
370
371 pub fn promote(&self, experiment_id: &str, variant: &Variant) -> Result<()> {
373 let config = RolloutConfig {
374 experiment_id: experiment_id.to_string(),
375 variant: variant.clone(),
376 strategy: DeploymentStrategy::Gradual {
377 initial_percentage: 10.0,
378 increment: 10.0,
379 increment_interval: Duration::hours(1),
380 },
381 health_checks: vec![
382 HealthCheck {
383 name: "error_rate".to_string(),
384 check_type: HealthCheckType::ErrorRate,
385 threshold: 0.05,
386 consecutive_failures: 3,
387 },
388 HealthCheck {
389 name: "latency_p99".to_string(),
390 check_type: HealthCheckType::LatencyP99,
391 threshold: 1000.0,
392 consecutive_failures: 3,
393 },
394 ],
395 rollback_conditions: vec![RollbackCondition {
396 name: "sustained_errors".to_string(),
397 metric: "error_rate".to_string(),
398 operator: ComparisonOperator::GreaterThan,
399 threshold: 0.1,
400 duration: Duration::minutes(5),
401 }],
402 };
403
404 self.start_rollout(config)
405 }
406
407 pub fn rollback(&self, experiment_id: &str) -> Result<()> {
409 let mut rollouts = self.rollouts.write();
410 if let Some(rollout) = rollouts.get_mut(experiment_id) {
411 rollout.status = RolloutStatus::RolledBack {
412 rolled_back_at: Utc::now(),
413 reason: "Manual rollback".to_string(),
414 };
415 }
416 Ok(())
417 }
418
419 pub fn get_status(&self, experiment_id: &str) -> Result<RolloutStatus> {
421 let rollouts = self.rollouts.read();
422 let rollout = rollouts
423 .get(experiment_id)
424 .ok_or_else(|| anyhow::anyhow!("Rollout not found"))?;
425 Ok(rollout.status.clone())
426 }
427
428 pub fn pause(&self, experiment_id: &str) -> Result<()> {
430 let mut rollouts = self.rollouts.write();
431 if let Some(rollout) = rollouts.get_mut(experiment_id) {
432 if let RolloutStatus::InProgress {
433 current_percentage, ..
434 } = rollout.status
435 {
436 rollout.status = RolloutStatus::Paused {
437 paused_at_percentage: current_percentage,
438 };
439 }
440 }
441 Ok(())
442 }
443
444 pub fn resume(&self, experiment_id: &str) -> Result<()> {
446 let mut rollouts = self.rollouts.write();
447 if let Some(rollout) = rollouts.get_mut(experiment_id) {
448 if let RolloutStatus::Paused {
449 paused_at_percentage,
450 } = rollout.status
451 {
452 rollout.status = RolloutStatus::InProgress {
453 current_percentage: paused_at_percentage,
454 started_at: Utc::now(),
455 };
456 }
457 }
458 Ok(())
459 }
460}
461
462impl HealthMonitor {
463 fn new() -> Self {
464 Self {
465 metrics: Arc::new(RwLock::new(HashMap::new())),
466 }
467 }
468
469 fn get_metric(&self, name: &str) -> Result<f64> {
470 let metrics = self.metrics.read();
471 Ok(metrics.get(name).copied().unwrap_or(0.0))
472 }
473
474 #[allow(dead_code)]
475 pub fn update_metric(&self, name: &str, value: f64) {
476 self.metrics.write().insert(name.to_string(), value);
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a rollout config for the "winner" / "model-v2" variant with the
    /// given strategy and health checks and no rollback conditions.
    fn config_for(
        experiment_id: &str,
        strategy: DeploymentStrategy,
        health_checks: Vec<HealthCheck>,
    ) -> RolloutConfig {
        RolloutConfig {
            experiment_id: experiment_id.to_string(),
            variant: Variant::new("winner", "model-v2"),
            strategy,
            health_checks,
            rollback_conditions: vec![],
        }
    }

    /// Returns the exposure of an in-progress rollout; panics for any other status.
    fn in_progress_percentage(controller: &RolloutController, experiment_id: &str) -> f64 {
        match controller.get_status(experiment_id).expect("operation failed in test") {
            RolloutStatus::InProgress {
                current_percentage, ..
            } => current_percentage,
            _ => panic!("Expected InProgress status"),
        }
    }

    #[test]
    fn test_immediate_deployment() {
        let controller = RolloutController::new();
        let config = config_for("exp1", DeploymentStrategy::Immediate, vec![]);

        controller.start_rollout(config).expect("operation failed in test");

        // An immediate rollout is fully exposed from the start.
        assert_eq!(in_progress_percentage(&controller, "exp1"), 100.0);
    }

    #[test]
    fn test_gradual_rollout() {
        let controller = RolloutController::new();
        let strategy = DeploymentStrategy::Gradual {
            initial_percentage: 10.0,
            increment: 20.0,
            increment_interval: Duration::seconds(1),
        };

        controller
            .start_rollout(config_for("exp2", strategy, vec![]))
            .expect("operation failed in test");

        // The rollout starts at the configured initial exposure.
        assert_eq!(in_progress_percentage(&controller, "exp2"), 10.0);

        // Let at least one increment interval pass, then re-evaluate.
        std::thread::sleep(std::time::Duration::from_secs(2));
        controller.update_rollout("exp2").expect("operation failed in test");

        // Exposure must have grown while the rollout is still in progress.
        assert!(in_progress_percentage(&controller, "exp2") > 10.0);
    }

    #[test]
    fn test_health_check_rollback() {
        let controller = RolloutController::new();
        let strategy = DeploymentStrategy::Canary {
            canary_percentage: 5.0,
            monitoring_duration: Duration::hours(1),
        };
        let error_rate_check = HealthCheck {
            name: "error_rate".to_string(),
            check_type: HealthCheckType::ErrorRate,
            threshold: 0.05,
            consecutive_failures: 1,
        };

        controller
            .start_rollout(config_for("exp3", strategy, vec![error_rate_check]))
            .expect("operation failed in test");

        // Push the error rate above the threshold; one failure is enough here.
        controller.health_monitor.update_metric("error_rate", 0.1);
        controller.update_rollout("exp3").expect("operation failed in test");

        match controller.get_status("exp3").expect("operation failed in test") {
            RolloutStatus::RolledBack { reason, .. } => {
                assert!(reason.contains("Health check"));
            },
            _ => panic!("Expected RolledBack status"),
        }
    }
}