1use crate::error::{AgentError, Result};
9use crate::runtime::{ContainerId, ContainerState, Runtime};
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::{mpsc, Notify, RwLock};
15use zlayer_spec::{PanicAction, ServiceSpec};
16
17pub type IsolateCallback = Arc<dyn Fn(&ContainerId) + Send + Sync>;
19
20const DEFAULT_MAX_RESTARTS: u32 = 5;
22const DEFAULT_RESTART_WINDOW: Duration = Duration::from_secs(300); const DEFAULT_POLL_INTERVAL: Duration = Duration::from_secs(5);
26const CRASH_LOOP_BACKOFF_DELAY: Duration = Duration::from_secs(30);
28
/// Lifecycle state of a container as tracked by the supervisor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SupervisedState {
    /// Container is believed to be alive and is actively monitored.
    Running,
    /// Container exited abnormally and a restart is in progress.
    Restarting,
    /// Restart budget exceeded within the window; restarts are being delayed.
    CrashLoopBackOff,
    /// Container was taken out of service for debugging (`Isolate` policy).
    Isolated,
    /// Service was shut down after a crash (`Shutdown` policy).
    Shutdown,
    /// Container exited cleanly with code 0; no further action taken.
    Completed,
}
45
/// Bookkeeping record for a single container under supervision.
#[derive(Debug, Clone)]
pub struct SupervisedContainer {
    /// Identity of the supervised container.
    pub id: ContainerId,
    /// Name of the service this container belongs to.
    pub service_name: String,
    /// Current supervision state.
    pub state: SupervisedState,
    /// Policy applied when the container exits with a non-zero code.
    pub panic_action: PanicAction,
    /// Timestamps of restarts inside the sliding window (pruned on record).
    pub restart_times: Vec<Instant>,
    /// Lifetime restart count (never pruned).
    pub total_restarts: u32,
    /// Exit code observed at the most recent exit, if any.
    pub last_exit_code: Option<i32>,
    /// When supervision of this container began.
    pub supervised_since: Instant,
}
66
67impl SupervisedContainer {
68 #[must_use]
70 pub fn new(id: ContainerId, service_name: String, panic_action: PanicAction) -> Self {
71 Self {
72 id,
73 service_name,
74 state: SupervisedState::Running,
75 panic_action,
76 restart_times: Vec::new(),
77 total_restarts: 0,
78 last_exit_code: None,
79 supervised_since: Instant::now(),
80 }
81 }
82
83 pub fn record_restart(&mut self, window: Duration, max_restarts: u32) -> bool {
85 let now = Instant::now();
86 self.restart_times.push(now);
87 self.total_restarts += 1;
88
89 self.restart_times
91 .retain(|&t| now.duration_since(t) < window);
92
93 #[allow(clippy::cast_possible_truncation)]
95 let count = self.restart_times.len() as u32;
96 count > max_restarts
97 }
98
99 #[must_use]
101 pub fn should_monitor(&self) -> bool {
102 matches!(
103 self.state,
104 SupervisedState::Running | SupervisedState::CrashLoopBackOff
105 )
106 }
107}
108
/// Events emitted by the supervisor so other components can react to
/// container lifecycle changes.
#[derive(Debug, Clone)]
pub enum SupervisorEvent {
    /// A crashed container was restarted (`Restart` policy, not in a loop).
    ContainerRestarted {
        id: ContainerId,
        service_name: String,
        /// Exit code that triggered the restart.
        exit_code: i32,
        /// Lifetime restart count at the time of the restart.
        restart_count: u32,
    },
    /// Restart budget exceeded; the next restart is being delayed.
    CrashLoopBackOff {
        id: ContainerId,
        service_name: String,
        restart_count: u32,
    },
    /// Container was isolated for debugging (`Isolate` policy).
    ContainerIsolated {
        id: ContainerId,
        service_name: String,
        exit_code: i32,
    },
    /// Service was shut down after a crash (`Shutdown` policy).
    ServiceShutdown {
        id: ContainerId,
        service_name: String,
        exit_code: i32,
    },
    /// Container exited cleanly with code 0.
    ContainerCompleted {
        id: ContainerId,
        service_name: String,
    },
}
143
/// Tunables for the supervision loop.
#[derive(Debug, Clone)]
pub struct SupervisorConfig {
    /// Restarts beyond this count inside `restart_window` trigger CrashLoopBackOff.
    pub max_restarts: u32,
    /// Sliding window over which restarts are counted.
    pub restart_window: Duration,
    /// Interval between container-state polls in `run_loop`.
    pub poll_interval: Duration,
}
154
155impl Default for SupervisorConfig {
156 fn default() -> Self {
157 Self {
158 max_restarts: DEFAULT_MAX_RESTARTS,
159 restart_window: DEFAULT_RESTART_WINDOW,
160 poll_interval: DEFAULT_POLL_INTERVAL,
161 }
162 }
163}
164
/// Watches registered containers and applies each service's panic policy
/// (restart, shutdown, or isolate) when a container exits abnormally.
pub struct ContainerSupervisor {
    /// Runtime used to query container state and start containers.
    runtime: Arc<dyn Runtime + Send + Sync>,
    /// All containers currently under supervision, keyed by id.
    containers: Arc<RwLock<HashMap<ContainerId, SupervisedContainer>>>,
    /// Loop tunables (poll interval, restart limits).
    config: SupervisorConfig,
    /// Sender half of the lifecycle-event channel.
    event_tx: mpsc::Sender<SupervisorEvent>,
    /// Receiver half; handed out via `take_event_receiver`.
    event_rx: Arc<RwLock<mpsc::Receiver<SupervisorEvent>>>,
    /// True while `run_loop` is executing.
    running: Arc<AtomicBool>,
    /// Signalled by `shutdown()` to stop the loop.
    shutdown: Arc<Notify>,
    /// Optional hook invoked when a container is isolated.
    on_isolate: Option<IsolateCallback>,
}
184
185impl ContainerSupervisor {
186 pub fn new(runtime: Arc<dyn Runtime + Send + Sync>) -> Self {
188 Self::with_config(runtime, SupervisorConfig::default())
189 }
190
191 pub fn with_config(runtime: Arc<dyn Runtime + Send + Sync>, config: SupervisorConfig) -> Self {
193 let (event_tx, event_rx) = mpsc::channel(100);
194
195 Self {
196 runtime,
197 containers: Arc::new(RwLock::new(HashMap::new())),
198 config,
199 event_tx,
200 event_rx: Arc::new(RwLock::new(event_rx)),
201 running: Arc::new(AtomicBool::new(false)),
202 shutdown: Arc::new(Notify::new()),
203 on_isolate: None,
204 }
205 }
206
207 pub fn set_isolate_callback<F>(&mut self, callback: F)
209 where
210 F: Fn(&ContainerId) + Send + Sync + 'static,
211 {
212 self.on_isolate = Some(Arc::new(callback));
213 }
214
215 pub async fn supervise(&self, container_id: &ContainerId, spec: &ServiceSpec) {
221 let supervised = SupervisedContainer::new(
222 container_id.clone(),
223 container_id.service.clone(),
224 spec.errors.on_panic.action,
225 );
226
227 let mut containers = self.containers.write().await;
228 containers.insert(container_id.clone(), supervised);
229
230 tracing::info!(
231 container = %container_id,
232 panic_action = ?spec.errors.on_panic.action,
233 "Container registered for supervision"
234 );
235 }
236
237 pub async fn unsupervise(&self, container_id: &ContainerId) {
239 let mut containers = self.containers.write().await;
240 if containers.remove(container_id).is_some() {
241 tracing::debug!(container = %container_id, "Container removed from supervision");
242 }
243 }
244
245 pub async fn get_state(&self, container_id: &ContainerId) -> Option<SupervisedState> {
247 let containers = self.containers.read().await;
248 containers.get(container_id).map(|c| c.state.clone())
249 }
250
251 pub async fn get_container_info(
253 &self,
254 container_id: &ContainerId,
255 ) -> Option<SupervisedContainer> {
256 let containers = self.containers.read().await;
257 containers.get(container_id).cloned()
258 }
259
260 pub async fn list_supervised(&self) -> Vec<SupervisedContainer> {
262 let containers = self.containers.read().await;
263 containers.values().cloned().collect()
264 }
265
266 pub async fn take_event_receiver(&self) -> Option<mpsc::Receiver<SupervisorEvent>> {
270 let mut rx_guard = self.event_rx.write().await;
271 let (_, dummy_rx) = mpsc::channel(1);
273 let old_rx = std::mem::replace(&mut *rx_guard, dummy_rx);
274 Some(old_rx)
275 }
276
277 pub async fn run_loop(&self) {
282 self.running.store(true, Ordering::SeqCst);
283
284 tracing::info!(
285 poll_interval_ms = self.config.poll_interval.as_millis(),
286 "Container supervisor started"
287 );
288
289 loop {
290 tokio::select! {
291 () = self.shutdown.notified() => {
292 tracing::info!("Container supervisor shutting down");
293 break;
294 }
295 () = tokio::time::sleep(self.config.poll_interval) => {
296 if let Err(e) = self.check_all_containers().await {
297 tracing::error!(error = %e, "Error during container health check");
298 }
299 }
300 }
301 }
302
303 self.running.store(false, Ordering::SeqCst);
304 }
305
306 pub fn shutdown(&self) {
308 self.shutdown.notify_one();
309 }
310
311 #[must_use]
313 pub fn is_running(&self) -> bool {
314 self.running.load(Ordering::SeqCst)
315 }
316
317 async fn check_all_containers(&self) -> Result<()> {
319 let containers_to_check: Vec<_> = {
320 let containers = self.containers.read().await;
321 containers
322 .iter()
323 .filter(|(_, c)| c.should_monitor())
324 .map(|(id, c)| (id.clone(), c.panic_action))
325 .collect()
326 };
327
328 for (container_id, panic_action) in containers_to_check {
329 self.check_container(&container_id, panic_action).await?;
330 }
331
332 Ok(())
333 }
334
335 async fn check_container(
337 &self,
338 container_id: &ContainerId,
339 panic_action: PanicAction,
340 ) -> Result<()> {
341 let state = self.runtime.container_state(container_id).await?;
342
343 match state {
344 ContainerState::Running
345 | ContainerState::Pending
346 | ContainerState::Initializing
347 | ContainerState::Stopping => {
348 }
350 ContainerState::Exited { code } => {
351 self.handle_container_exit(container_id, code, panic_action)
353 .await?;
354 }
355 ContainerState::Failed { reason } => {
356 tracing::warn!(
358 container = %container_id,
359 reason = %reason,
360 "Container reported as failed"
361 );
362 self.handle_container_exit(container_id, -1, panic_action)
363 .await?;
364 }
365 }
366
367 Ok(())
368 }
369
370 async fn handle_container_exit(
372 &self,
373 container_id: &ContainerId,
374 exit_code: i32,
375 panic_action: PanicAction,
376 ) -> Result<()> {
377 let (service_name, _should_restart, in_crash_loop) = {
379 let mut containers = self.containers.write().await;
380 let Some(container) = containers.get_mut(container_id) else {
381 return Ok(()); };
383
384 container.last_exit_code = Some(exit_code);
385
386 if exit_code == 0 {
388 container.state = SupervisedState::Completed;
389 let _ = self
390 .event_tx
391 .send(SupervisorEvent::ContainerCompleted {
392 id: container_id.clone(),
393 service_name: container.service_name.clone(),
394 })
395 .await;
396 return Ok(());
397 }
398
399 let service_name = container.service_name.clone();
400 let in_crash_loop =
401 container.record_restart(self.config.restart_window, self.config.max_restarts);
402
403 let should_restart = match panic_action {
404 PanicAction::Restart => !in_crash_loop,
405 PanicAction::Shutdown | PanicAction::Isolate => false,
406 };
407
408 if in_crash_loop && matches!(panic_action, PanicAction::Restart) {
409 container.state = SupervisedState::CrashLoopBackOff;
410 } else if should_restart {
411 container.state = SupervisedState::Restarting;
412 }
413
414 (service_name, should_restart, in_crash_loop)
415 };
416
417 match panic_action {
419 PanicAction::Restart => {
420 if in_crash_loop {
421 self.handle_crash_loop_backoff(container_id, &service_name)
422 .await?;
423 } else {
424 self.restart_container(container_id, &service_name, exit_code)
425 .await?;
426 }
427 }
428 PanicAction::Shutdown => {
429 self.shutdown_container(container_id, &service_name, exit_code)
430 .await?;
431 }
432 PanicAction::Isolate => {
433 self.isolate_container(container_id, &service_name, exit_code)
434 .await?;
435 }
436 }
437
438 Ok(())
439 }
440
441 async fn restart_container(
443 &self,
444 container_id: &ContainerId,
445 service_name: &str,
446 exit_code: i32,
447 ) -> Result<()> {
448 let restart_count = {
449 let containers = self.containers.read().await;
450 containers.get(container_id).map_or(0, |c| c.total_restarts)
451 };
452
453 tracing::info!(
454 container = %container_id,
455 service = %service_name,
456 exit_code = exit_code,
457 restart_count = restart_count,
458 "Restarting crashed container"
459 );
460
461 self.runtime
463 .start_container(container_id)
464 .await
465 .map_err(|e| AgentError::StartFailed {
466 id: container_id.to_string(),
467 reason: e.to_string(),
468 })?;
469
470 {
472 let mut containers = self.containers.write().await;
473 if let Some(container) = containers.get_mut(container_id) {
474 container.state = SupervisedState::Running;
475 }
476 }
477
478 let _ = self
480 .event_tx
481 .send(SupervisorEvent::ContainerRestarted {
482 id: container_id.clone(),
483 service_name: service_name.to_string(),
484 exit_code,
485 restart_count,
486 })
487 .await;
488
489 Ok(())
490 }
491
492 async fn handle_crash_loop_backoff(
494 &self,
495 container_id: &ContainerId,
496 service_name: &str,
497 ) -> Result<()> {
498 let restart_count = {
499 let containers = self.containers.read().await;
500 containers.get(container_id).map_or(0, |c| c.total_restarts)
501 };
502
503 tracing::warn!(
504 container = %container_id,
505 service = %service_name,
506 restart_count = restart_count,
507 backoff_delay_secs = CRASH_LOOP_BACKOFF_DELAY.as_secs(),
508 "Container in CrashLoopBackOff, delaying restart"
509 );
510
511 let _ = self
513 .event_tx
514 .send(SupervisorEvent::CrashLoopBackOff {
515 id: container_id.clone(),
516 service_name: service_name.to_string(),
517 restart_count,
518 })
519 .await;
520
521 let runtime = Arc::clone(&self.runtime);
523 let container_id = container_id.clone();
524 let containers = Arc::clone(&self.containers);
525
526 tokio::spawn(async move {
527 tokio::time::sleep(CRASH_LOOP_BACKOFF_DELAY).await;
528
529 if let Err(e) = runtime.start_container(&container_id).await {
531 tracing::error!(
532 container = %container_id,
533 error = %e,
534 "Failed to restart container after CrashLoopBackOff delay"
535 );
536 return;
537 }
538
539 let mut containers_guard = containers.write().await;
541 if let Some(container) = containers_guard.get_mut(&container_id) {
542 container.state = SupervisedState::Running;
543 }
544 });
545
546 Ok(())
547 }
548
549 async fn shutdown_container(
551 &self,
552 container_id: &ContainerId,
553 service_name: &str,
554 exit_code: i32,
555 ) -> Result<()> {
556 tracing::warn!(
557 container = %container_id,
558 service = %service_name,
559 exit_code = exit_code,
560 "Shutting down service due to panic policy"
561 );
562
563 {
565 let mut containers = self.containers.write().await;
566 if let Some(container) = containers.get_mut(container_id) {
567 container.state = SupervisedState::Shutdown;
568 }
569 }
570
571 let _ = self
573 .event_tx
574 .send(SupervisorEvent::ServiceShutdown {
575 id: container_id.clone(),
576 service_name: service_name.to_string(),
577 exit_code,
578 })
579 .await;
580
581 Ok(())
582 }
583
584 async fn isolate_container(
586 &self,
587 container_id: &ContainerId,
588 service_name: &str,
589 exit_code: i32,
590 ) -> Result<()> {
591 tracing::info!(
592 container = %container_id,
593 service = %service_name,
594 exit_code = exit_code,
595 "Isolating container (removed from load balancer for debugging)"
596 );
597
598 if let Some(ref callback) = self.on_isolate {
600 callback(container_id);
601 }
602
603 {
605 let mut containers = self.containers.write().await;
606 if let Some(container) = containers.get_mut(container_id) {
607 container.state = SupervisedState::Isolated;
608 }
609 }
610
611 let _ = self
613 .event_tx
614 .send(SupervisorEvent::ContainerIsolated {
615 id: container_id.clone(),
616 service_name: service_name.to_string(),
617 exit_code,
618 })
619 .await;
620
621 Ok(())
622 }
623
624 pub async fn supervised_count(&self) -> usize {
626 self.containers.read().await.len()
627 }
628
629 pub async fn count_by_state(&self, state: SupervisedState) -> usize {
631 self.containers
632 .read()
633 .await
634 .values()
635 .filter(|c| c.state == state)
636 .count()
637 }
638}
639
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::MockRuntime;

    // Build a ContainerId for the given service/replica pair.
    fn mock_container_id(service: &str, replica: u32) -> ContainerId {
        ContainerId {
            service: service.to_string(),
            replica,
        }
    }

    // Parse a minimal deployment spec and override its panic action so each
    // test can pick the policy under test.
    fn mock_service_spec(panic_action: PanicAction) -> ServiceSpec {
        let mut spec: ServiceSpec = serde_yaml::from_str::<zlayer_spec::DeploymentSpec>(
            r"
version: v1
deployment: test
services:
  test:
    rtype: service
    image:
      name: test:latest
",
        )
        .unwrap()
        .services
        .remove("test")
        .unwrap();

        spec.errors.on_panic.action = panic_action;
        spec
    }

    // A freshly built supervisor is idle and supervises nothing.
    #[tokio::test]
    async fn test_supervisor_creation() {
        let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
        let supervisor = ContainerSupervisor::new(runtime);

        assert!(!supervisor.is_running());
        assert_eq!(supervisor.supervised_count().await, 0);
    }

    // with_config stores the caller-provided tunables verbatim.
    #[tokio::test]
    async fn test_supervisor_with_config() {
        let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
        let config = SupervisorConfig {
            max_restarts: 10,
            restart_window: Duration::from_secs(600),
            poll_interval: Duration::from_secs(1),
        };

        let supervisor = ContainerSupervisor::with_config(runtime, config);
        assert_eq!(supervisor.config.max_restarts, 10);
        assert_eq!(supervisor.config.restart_window, Duration::from_secs(600));
    }

    // Registering a container adds it in the Running state.
    #[tokio::test]
    async fn test_supervise_container() {
        let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
        let supervisor = ContainerSupervisor::new(runtime);

        let container_id = mock_container_id("api", 1);
        let spec = mock_service_spec(PanicAction::Restart);

        supervisor.supervise(&container_id, &spec).await;

        assert_eq!(supervisor.supervised_count().await, 1);

        let state = supervisor.get_state(&container_id).await;
        assert_eq!(state, Some(SupervisedState::Running));
    }

    // Unregistering removes the record entirely.
    #[tokio::test]
    async fn test_unsupervise_container() {
        let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
        let supervisor = ContainerSupervisor::new(runtime);

        let container_id = mock_container_id("api", 1);
        let spec = mock_service_spec(PanicAction::Restart);

        supervisor.supervise(&container_id, &spec).await;
        assert_eq!(supervisor.supervised_count().await, 1);

        supervisor.unsupervise(&container_id).await;
        assert_eq!(supervisor.supervised_count().await, 0);
    }

    // list_supervised returns one record per registered container.
    #[tokio::test]
    async fn test_list_supervised() {
        let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
        let supervisor = ContainerSupervisor::new(runtime);

        let spec = mock_service_spec(PanicAction::Restart);

        supervisor
            .supervise(&mock_container_id("api", 1), &spec)
            .await;
        supervisor
            .supervise(&mock_container_id("api", 2), &spec)
            .await;
        supervisor
            .supervise(&mock_container_id("web", 1), &spec)
            .await;

        let containers = supervisor.list_supervised().await;
        assert_eq!(containers.len(), 3);
    }

    // The (max_restarts + 1)-th restart inside the window trips the crash-loop
    // detector: count > max_restarts.
    #[tokio::test]
    async fn test_supervised_container_record_restart() {
        let mut container = SupervisedContainer::new(
            mock_container_id("api", 1),
            "api".to_string(),
            PanicAction::Restart,
        );

        // First 5 restarts stay within the budget.
        for _ in 0..5 {
            let in_loop = container.record_restart(Duration::from_secs(300), 5);
            assert!(!in_loop);
        }

        // The 6th exceeds max_restarts = 5.
        let in_loop = container.record_restart(Duration::from_secs(300), 5);
        assert!(in_loop);
    }

    // Restarts older than the window are pruned, so waiting past the window
    // resets the effective count.
    #[tokio::test]
    async fn test_supervised_container_restart_window() {
        let mut container = SupervisedContainer::new(
            mock_container_id("api", 1),
            "api".to_string(),
            PanicAction::Restart,
        );

        for _ in 0..5 {
            container.record_restart(Duration::from_millis(100), 5);
        }

        // Let the earlier restarts age out of the 100ms window.
        tokio::time::sleep(Duration::from_millis(150)).await;

        let in_loop = container.record_restart(Duration::from_millis(100), 5);
        assert!(!in_loop);
    }

    // get_container_info returns a snapshot with the registered policy and
    // fresh counters.
    #[tokio::test]
    async fn test_get_container_info() {
        let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
        let supervisor = ContainerSupervisor::new(runtime);

        let container_id = mock_container_id("api", 1);
        let spec = mock_service_spec(PanicAction::Isolate);

        supervisor.supervise(&container_id, &spec).await;

        let info = supervisor.get_container_info(&container_id).await;
        assert!(info.is_some());

        let info = info.unwrap();
        assert_eq!(info.id, container_id);
        assert_eq!(info.service_name, "api");
        assert_eq!(info.panic_action, PanicAction::Isolate);
        assert_eq!(info.state, SupervisedState::Running);
        assert_eq!(info.total_restarts, 0);
    }

    // count_by_state filters the records by exact state match.
    #[tokio::test]
    async fn test_count_by_state() {
        let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
        let supervisor = ContainerSupervisor::new(runtime);

        let spec = mock_service_spec(PanicAction::Restart);

        supervisor
            .supervise(&mock_container_id("api", 1), &spec)
            .await;
        supervisor
            .supervise(&mock_container_id("api", 2), &spec)
            .await;

        assert_eq!(supervisor.count_by_state(SupervisedState::Running).await, 2);
        assert_eq!(
            supervisor
                .count_by_state(SupervisedState::CrashLoopBackOff)
                .await,
            0
        );
    }

    // Default config matches the module-level constants.
    #[test]
    fn test_supervisor_config_default() {
        let config = SupervisorConfig::default();

        assert_eq!(config.max_restarts, DEFAULT_MAX_RESTARTS);
        assert_eq!(config.restart_window, DEFAULT_RESTART_WINDOW);
        assert_eq!(config.poll_interval, DEFAULT_POLL_INTERVAL);
    }

    // Only Running and CrashLoopBackOff are monitored; terminal states are not.
    #[test]
    fn test_supervised_state_should_monitor() {
        let container = SupervisedContainer {
            state: SupervisedState::Running,
            ..SupervisedContainer::new(
                mock_container_id("api", 1),
                "api".to_string(),
                PanicAction::Restart,
            )
        };
        assert!(container.should_monitor());

        let container = SupervisedContainer {
            state: SupervisedState::CrashLoopBackOff,
            ..SupervisedContainer::new(
                mock_container_id("api", 1),
                "api".to_string(),
                PanicAction::Restart,
            )
        };
        assert!(container.should_monitor());

        let container = SupervisedContainer {
            state: SupervisedState::Shutdown,
            ..SupervisedContainer::new(
                mock_container_id("api", 1),
                "api".to_string(),
                PanicAction::Restart,
            )
        };
        assert!(!container.should_monitor());

        let container = SupervisedContainer {
            state: SupervisedState::Isolated,
            ..SupervisedContainer::new(
                mock_container_id("api", 1),
                "api".to_string(),
                PanicAction::Restart,
            )
        };
        assert!(!container.should_monitor());

        let container = SupervisedContainer {
            state: SupervisedState::Completed,
            ..SupervisedContainer::new(
                mock_container_id("api", 1),
                "api".to_string(),
                PanicAction::Restart,
            )
        };
        assert!(!container.should_monitor());
    }
}