1use crate::error::{AgentError, Result};
9use crate::runtime::{ContainerId, ContainerState, Runtime};
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::{mpsc, Notify, RwLock};
15use zlayer_spec::{PanicAction, ServiceSpec};
16
/// Shared callback that receives the id of a container being isolated.
///
/// The supervisor invokes it synchronously from its isolation path, so it
/// must be `Send + Sync` and is stored behind an `Arc`.
pub type IsolateCallback = Arc<dyn Fn(&ContainerId) + Send + Sync>;
19
/// Default number of restarts tolerated inside the restart window before a
/// container is considered to be crash-looping.
const DEFAULT_MAX_RESTARTS: u32 = 5;

/// Default length of the sliding window over which restarts are counted
/// (five minutes).
const DEFAULT_RESTART_WINDOW: Duration = Duration::from_secs(5 * 60);

/// Default pause between two polling rounds of the supervisor loop.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_secs(5);

/// Delay applied before a crash-looping container is started again.
const CRASH_LOOP_BACKOFF_DELAY: Duration = Duration::from_secs(30);
28
/// Lifecycle state the supervisor tracks for each registered container.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SupervisedState {
    /// Initial state of a newly supervised container; also set again after a
    /// successful (re)start.
    Running,
    /// The container failed, its policy is `Restart`, and the restart limit
    /// has not been exceeded; set just before the restart is attempted.
    Restarting,
    /// The container failed, its policy is `Restart`, and it exceeded the
    /// restart limit inside the window; the next restart is delayed.
    CrashLoopBackOff,
    /// The container failed and was isolated per its `Isolate` policy.
    Isolated,
    /// The container failed and its `Shutdown` policy was applied.
    Shutdown,
    /// The container exited with code `0`.
    Completed,
}
45
/// Bookkeeping record the supervisor keeps for one container.
#[derive(Debug, Clone)]
pub struct SupervisedContainer {
    /// Identifier of the supervised container.
    pub id: ContainerId,
    /// Name of the service the container belongs to.
    pub service_name: String,
    /// Current supervision state.
    pub state: SupervisedState,
    /// Policy applied when the container exits with a non-zero code.
    pub panic_action: PanicAction,
    /// Instants recorded by `record_restart` that still fall inside the
    /// restart window; older entries are pruned on every call.
    pub restart_times: Vec<Instant>,
    /// Number of times `record_restart` has been called for this container.
    pub total_restarts: u32,
    /// Exit code observed most recently, if any (`-1` is recorded when the
    /// runtime reports the container as failed).
    pub last_exit_code: Option<i32>,
    /// Instant at which this record was created.
    pub supervised_since: Instant,
}
66
67impl SupervisedContainer {
68 #[must_use]
70 pub fn new(id: ContainerId, service_name: String, panic_action: PanicAction) -> Self {
71 Self {
72 id,
73 service_name,
74 state: SupervisedState::Running,
75 panic_action,
76 restart_times: Vec::new(),
77 total_restarts: 0,
78 last_exit_code: None,
79 supervised_since: Instant::now(),
80 }
81 }
82
83 pub fn record_restart(&mut self, window: Duration, max_restarts: u32) -> bool {
85 let now = Instant::now();
86 self.restart_times.push(now);
87 self.total_restarts += 1;
88
89 self.restart_times
91 .retain(|&t| now.duration_since(t) < window);
92
93 #[allow(clippy::cast_possible_truncation)]
95 let count = self.restart_times.len() as u32;
96 count > max_restarts
97 }
98
99 #[must_use]
101 pub fn should_monitor(&self) -> bool {
102 matches!(
103 self.state,
104 SupervisedState::Running | SupervisedState::CrashLoopBackOff
105 )
106 }
107}
108
/// Notifications the supervisor publishes on its event channel.
#[derive(Debug, Clone)]
pub enum SupervisorEvent {
    /// A failed container was started again successfully.
    ContainerRestarted {
        /// Container that was restarted.
        id: ContainerId,
        /// Service the container belongs to.
        service_name: String,
        /// Exit code that triggered the restart.
        exit_code: i32,
        /// Value of the container's `total_restarts` counter at restart time.
        restart_count: u32,
    },
    /// A container exceeded the restart limit and its restart is being delayed.
    CrashLoopBackOff {
        /// Container that is crash-looping.
        id: ContainerId,
        /// Service the container belongs to.
        service_name: String,
        /// Value of the container's `total_restarts` counter when detected.
        restart_count: u32,
    },
    /// A failed container was isolated instead of restarted.
    ContainerIsolated {
        /// Container that was isolated.
        id: ContainerId,
        /// Service the container belongs to.
        service_name: String,
        /// Exit code that triggered the isolation.
        exit_code: i32,
    },
    /// A failed container triggered the `Shutdown` policy.
    ServiceShutdown {
        /// Container whose failure triggered the shutdown.
        id: ContainerId,
        /// Service the container belongs to.
        service_name: String,
        /// Exit code that triggered the shutdown.
        exit_code: i32,
    },
    /// A container exited with code `0`.
    ContainerCompleted {
        /// Container that completed.
        id: ContainerId,
        /// Service the container belongs to.
        service_name: String,
    },
}
143
/// Tunables for the container supervisor.
#[derive(Debug, Clone)]
pub struct SupervisorConfig {
    /// Restarts tolerated inside `restart_window` before a container is
    /// treated as crash-looping.
    pub max_restarts: u32,
    /// Length of the sliding window over which restarts are counted.
    pub restart_window: Duration,
    /// Pause between two polling rounds of the supervisor loop.
    pub poll_interval: Duration,
}
154
155impl Default for SupervisorConfig {
156 fn default() -> Self {
157 Self {
158 max_restarts: DEFAULT_MAX_RESTARTS,
159 restart_window: DEFAULT_RESTART_WINDOW,
160 poll_interval: DEFAULT_POLL_INTERVAL,
161 }
162 }
163}
164
/// Polls registered containers through a [`Runtime`] and applies each
/// container's panic policy when it exits.
pub struct ContainerSupervisor {
    /// Runtime used to query container state and to start containers.
    runtime: Arc<dyn Runtime + Send + Sync>,
    /// Registry of supervised containers, keyed by container id.
    containers: Arc<RwLock<HashMap<ContainerId, SupervisedContainer>>>,
    /// Restart limits and polling interval.
    config: SupervisorConfig,
    /// Sending half of the event channel.
    event_tx: mpsc::Sender<SupervisorEvent>,
    /// Receiving half of the event channel until it is taken by a consumer.
    event_rx: Arc<RwLock<mpsc::Receiver<SupervisorEvent>>>,
    /// `true` while `run_loop` is executing.
    running: Arc<AtomicBool>,
    /// Signal used to ask `run_loop` to stop.
    shutdown: Arc<Notify>,
    /// Optional callback invoked when a container is isolated.
    on_isolate: Option<IsolateCallback>,
}
184
185impl ContainerSupervisor {
186 pub fn new(runtime: Arc<dyn Runtime + Send + Sync>) -> Self {
188 Self::with_config(runtime, SupervisorConfig::default())
189 }
190
191 pub fn with_config(runtime: Arc<dyn Runtime + Send + Sync>, config: SupervisorConfig) -> Self {
193 let (event_tx, event_rx) = mpsc::channel(100);
194
195 Self {
196 runtime,
197 containers: Arc::new(RwLock::new(HashMap::new())),
198 config,
199 event_tx,
200 event_rx: Arc::new(RwLock::new(event_rx)),
201 running: Arc::new(AtomicBool::new(false)),
202 shutdown: Arc::new(Notify::new()),
203 on_isolate: None,
204 }
205 }
206
207 pub fn set_isolate_callback<F>(&mut self, callback: F)
209 where
210 F: Fn(&ContainerId) + Send + Sync + 'static,
211 {
212 self.on_isolate = Some(Arc::new(callback));
213 }
214
215 pub async fn supervise(&self, container_id: &ContainerId, spec: &ServiceSpec) {
221 let supervised = SupervisedContainer::new(
222 container_id.clone(),
223 container_id.service.clone(),
224 spec.errors.on_panic.action,
225 );
226
227 let mut containers = self.containers.write().await;
228 containers.insert(container_id.clone(), supervised);
229
230 tracing::info!(
231 container = %container_id,
232 panic_action = ?spec.errors.on_panic.action,
233 "Container registered for supervision"
234 );
235 }
236
237 pub async fn unsupervise(&self, container_id: &ContainerId) {
239 let mut containers = self.containers.write().await;
240 if containers.remove(container_id).is_some() {
241 tracing::debug!(container = %container_id, "Container removed from supervision");
242 }
243 }
244
245 pub async fn get_state(&self, container_id: &ContainerId) -> Option<SupervisedState> {
247 let containers = self.containers.read().await;
248 containers.get(container_id).map(|c| c.state.clone())
249 }
250
251 pub async fn get_container_info(
253 &self,
254 container_id: &ContainerId,
255 ) -> Option<SupervisedContainer> {
256 let containers = self.containers.read().await;
257 containers.get(container_id).cloned()
258 }
259
260 pub async fn list_supervised(&self) -> Vec<SupervisedContainer> {
262 let containers = self.containers.read().await;
263 containers.values().cloned().collect()
264 }
265
266 pub async fn take_event_receiver(&self) -> Option<mpsc::Receiver<SupervisorEvent>> {
270 let mut rx_guard = self.event_rx.write().await;
271 let (_, dummy_rx) = mpsc::channel(1);
273 let old_rx = std::mem::replace(&mut *rx_guard, dummy_rx);
274 Some(old_rx)
275 }
276
277 pub async fn run_loop(&self) {
282 self.running.store(true, Ordering::SeqCst);
283
284 tracing::info!(
285 poll_interval_ms = self.config.poll_interval.as_millis(),
286 "Container supervisor started"
287 );
288
289 loop {
290 tokio::select! {
291 () = self.shutdown.notified() => {
292 tracing::info!("Container supervisor shutting down");
293 break;
294 }
295 () = tokio::time::sleep(self.config.poll_interval) => {
296 if let Err(e) = self.check_all_containers().await {
297 tracing::error!(error = %e, "Error during container health check");
298 }
299 }
300 }
301 }
302
303 self.running.store(false, Ordering::SeqCst);
304 }
305
    /// Asks the supervision loop to stop by signalling the shutdown `Notify`.
    pub fn shutdown(&self) {
        // `run_loop` waits on `self.shutdown.notified()` in its select loop.
        self.shutdown.notify_one();
    }
310
    /// Returns `true` while `run_loop` is executing (the flag is set on entry
    /// and cleared when the loop exits).
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.running.load(Ordering::SeqCst)
    }
316
317 async fn check_all_containers(&self) -> Result<()> {
319 let containers_to_check: Vec<_> = {
320 let containers = self.containers.read().await;
321 containers
322 .iter()
323 .filter(|(_, c)| c.should_monitor())
324 .map(|(id, c)| (id.clone(), c.panic_action))
325 .collect()
326 };
327
328 for (container_id, panic_action) in containers_to_check {
329 self.check_container(&container_id, panic_action).await?;
330 }
331
332 Ok(())
333 }
334
335 async fn check_container(
337 &self,
338 container_id: &ContainerId,
339 panic_action: PanicAction,
340 ) -> Result<()> {
341 let state = match self.runtime.container_state(container_id).await {
352 Ok(state) => state,
353 Err(AgentError::NotFound { .. }) => {
354 tracing::debug!(
355 container = %container_id,
356 "supervised container not yet known to runtime (registration race); \
357 skipping this health-check round"
358 );
359 return Ok(());
360 }
361 Err(e) => return Err(e),
362 };
363
364 match state {
365 ContainerState::Running
366 | ContainerState::Pending
367 | ContainerState::Initializing
368 | ContainerState::Stopping => {
369 }
371 ContainerState::Exited { code } => {
372 self.handle_container_exit(container_id, code, panic_action)
374 .await?;
375 }
376 ContainerState::Failed { reason } => {
377 tracing::warn!(
379 container = %container_id,
380 reason = %reason,
381 "Container reported as failed"
382 );
383 self.handle_container_exit(container_id, -1, panic_action)
384 .await?;
385 }
386 }
387
388 Ok(())
389 }
390
391 async fn handle_container_exit(
393 &self,
394 container_id: &ContainerId,
395 exit_code: i32,
396 panic_action: PanicAction,
397 ) -> Result<()> {
398 let (service_name, _should_restart, in_crash_loop) = {
400 let mut containers = self.containers.write().await;
401 let Some(container) = containers.get_mut(container_id) else {
402 return Ok(()); };
404
405 container.last_exit_code = Some(exit_code);
406
407 if exit_code == 0 {
409 container.state = SupervisedState::Completed;
410 let _ = self
411 .event_tx
412 .send(SupervisorEvent::ContainerCompleted {
413 id: container_id.clone(),
414 service_name: container.service_name.clone(),
415 })
416 .await;
417 return Ok(());
418 }
419
420 let service_name = container.service_name.clone();
421 let in_crash_loop =
422 container.record_restart(self.config.restart_window, self.config.max_restarts);
423
424 let should_restart = match panic_action {
425 PanicAction::Restart => !in_crash_loop,
426 PanicAction::Shutdown | PanicAction::Isolate => false,
427 };
428
429 if in_crash_loop && matches!(panic_action, PanicAction::Restart) {
430 container.state = SupervisedState::CrashLoopBackOff;
431 } else if should_restart {
432 container.state = SupervisedState::Restarting;
433 }
434
435 (service_name, should_restart, in_crash_loop)
436 };
437
438 match panic_action {
440 PanicAction::Restart => {
441 if in_crash_loop {
442 self.handle_crash_loop_backoff(container_id, &service_name)
443 .await?;
444 } else {
445 self.restart_container(container_id, &service_name, exit_code)
446 .await?;
447 }
448 }
449 PanicAction::Shutdown => {
450 self.shutdown_container(container_id, &service_name, exit_code)
451 .await?;
452 }
453 PanicAction::Isolate => {
454 self.isolate_container(container_id, &service_name, exit_code)
455 .await?;
456 }
457 }
458
459 Ok(())
460 }
461
462 async fn restart_container(
464 &self,
465 container_id: &ContainerId,
466 service_name: &str,
467 exit_code: i32,
468 ) -> Result<()> {
469 let restart_count = {
470 let containers = self.containers.read().await;
471 containers.get(container_id).map_or(0, |c| c.total_restarts)
472 };
473
474 tracing::info!(
475 container = %container_id,
476 service = %service_name,
477 exit_code = exit_code,
478 restart_count = restart_count,
479 "Restarting crashed container"
480 );
481
482 self.runtime
484 .start_container(container_id)
485 .await
486 .map_err(|e| AgentError::StartFailed {
487 id: container_id.to_string(),
488 reason: e.to_string(),
489 })?;
490
491 {
493 let mut containers = self.containers.write().await;
494 if let Some(container) = containers.get_mut(container_id) {
495 container.state = SupervisedState::Running;
496 }
497 }
498
499 let _ = self
501 .event_tx
502 .send(SupervisorEvent::ContainerRestarted {
503 id: container_id.clone(),
504 service_name: service_name.to_string(),
505 exit_code,
506 restart_count,
507 })
508 .await;
509
510 Ok(())
511 }
512
513 async fn handle_crash_loop_backoff(
515 &self,
516 container_id: &ContainerId,
517 service_name: &str,
518 ) -> Result<()> {
519 let restart_count = {
520 let containers = self.containers.read().await;
521 containers.get(container_id).map_or(0, |c| c.total_restarts)
522 };
523
524 tracing::warn!(
525 container = %container_id,
526 service = %service_name,
527 restart_count = restart_count,
528 backoff_delay_secs = CRASH_LOOP_BACKOFF_DELAY.as_secs(),
529 "Container in CrashLoopBackOff, delaying restart"
530 );
531
532 let _ = self
534 .event_tx
535 .send(SupervisorEvent::CrashLoopBackOff {
536 id: container_id.clone(),
537 service_name: service_name.to_string(),
538 restart_count,
539 })
540 .await;
541
542 let runtime = Arc::clone(&self.runtime);
544 let container_id = container_id.clone();
545 let containers = Arc::clone(&self.containers);
546
547 tokio::spawn(async move {
548 tokio::time::sleep(CRASH_LOOP_BACKOFF_DELAY).await;
549
550 if let Err(e) = runtime.start_container(&container_id).await {
552 tracing::error!(
553 container = %container_id,
554 error = %e,
555 "Failed to restart container after CrashLoopBackOff delay"
556 );
557 return;
558 }
559
560 let mut containers_guard = containers.write().await;
562 if let Some(container) = containers_guard.get_mut(&container_id) {
563 container.state = SupervisedState::Running;
564 }
565 });
566
567 Ok(())
568 }
569
570 async fn shutdown_container(
572 &self,
573 container_id: &ContainerId,
574 service_name: &str,
575 exit_code: i32,
576 ) -> Result<()> {
577 tracing::warn!(
578 container = %container_id,
579 service = %service_name,
580 exit_code = exit_code,
581 "Shutting down service due to panic policy"
582 );
583
584 {
586 let mut containers = self.containers.write().await;
587 if let Some(container) = containers.get_mut(container_id) {
588 container.state = SupervisedState::Shutdown;
589 }
590 }
591
592 let _ = self
594 .event_tx
595 .send(SupervisorEvent::ServiceShutdown {
596 id: container_id.clone(),
597 service_name: service_name.to_string(),
598 exit_code,
599 })
600 .await;
601
602 Ok(())
603 }
604
605 async fn isolate_container(
607 &self,
608 container_id: &ContainerId,
609 service_name: &str,
610 exit_code: i32,
611 ) -> Result<()> {
612 tracing::info!(
613 container = %container_id,
614 service = %service_name,
615 exit_code = exit_code,
616 "Isolating container (removed from load balancer for debugging)"
617 );
618
619 if let Some(ref callback) = self.on_isolate {
621 callback(container_id);
622 }
623
624 {
626 let mut containers = self.containers.write().await;
627 if let Some(container) = containers.get_mut(container_id) {
628 container.state = SupervisedState::Isolated;
629 }
630 }
631
632 let _ = self
634 .event_tx
635 .send(SupervisorEvent::ContainerIsolated {
636 id: container_id.clone(),
637 service_name: service_name.to_string(),
638 exit_code,
639 })
640 .await;
641
642 Ok(())
643 }
644
645 pub async fn supervised_count(&self) -> usize {
647 self.containers.read().await.len()
648 }
649
650 pub async fn count_by_state(&self, state: SupervisedState) -> usize {
652 self.containers
653 .read()
654 .await
655 .values()
656 .filter(|c| c.state == state)
657 .count()
658 }
659}
660
661#[cfg(test)]
662mod tests {
663 use super::*;
664 use crate::runtime::MockRuntime;
665
666 fn mock_container_id(service: &str, replica: u32) -> ContainerId {
667 ContainerId::new(service.to_string(), replica)
668 }
669
    /// Builds a minimal `ServiceSpec` for tests.
    ///
    /// Parses a one-service deployment from inline YAML, takes out the
    /// service named `test`, and overrides its panic action with
    /// `panic_action`. Panics if the YAML does not parse or the service is
    /// missing.
    fn mock_service_spec(panic_action: PanicAction) -> ServiceSpec {
        let mut spec: ServiceSpec = serde_yaml::from_str::<zlayer_spec::DeploymentSpec>(
            r"
version: v1
deployment: test
services:
  test:
    rtype: service
    image:
      name: test:latest
",
        )
        .unwrap()
        .services
        .remove("test")
        .unwrap();

        // Only the panic action differs between the tests that use this.
        spec.errors.on_panic.action = panic_action;
        spec
    }
690
691 #[tokio::test]
692 async fn test_supervisor_creation() {
693 let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
694 let supervisor = ContainerSupervisor::new(runtime);
695
696 assert!(!supervisor.is_running());
697 assert_eq!(supervisor.supervised_count().await, 0);
698 }
699
700 #[tokio::test]
701 async fn test_supervisor_with_config() {
702 let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
703 let config = SupervisorConfig {
704 max_restarts: 10,
705 restart_window: Duration::from_secs(600),
706 poll_interval: Duration::from_secs(1),
707 };
708
709 let supervisor = ContainerSupervisor::with_config(runtime, config);
710 assert_eq!(supervisor.config.max_restarts, 10);
711 assert_eq!(supervisor.config.restart_window, Duration::from_secs(600));
712 }
713
714 #[tokio::test]
715 async fn test_supervise_container() {
716 let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
717 let supervisor = ContainerSupervisor::new(runtime);
718
719 let container_id = mock_container_id("api", 1);
720 let spec = mock_service_spec(PanicAction::Restart);
721
722 supervisor.supervise(&container_id, &spec).await;
723
724 assert_eq!(supervisor.supervised_count().await, 1);
725
726 let state = supervisor.get_state(&container_id).await;
727 assert_eq!(state, Some(SupervisedState::Running));
728 }
729
730 #[tokio::test]
731 async fn test_unsupervise_container() {
732 let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
733 let supervisor = ContainerSupervisor::new(runtime);
734
735 let container_id = mock_container_id("api", 1);
736 let spec = mock_service_spec(PanicAction::Restart);
737
738 supervisor.supervise(&container_id, &spec).await;
739 assert_eq!(supervisor.supervised_count().await, 1);
740
741 supervisor.unsupervise(&container_id).await;
742 assert_eq!(supervisor.supervised_count().await, 0);
743 }
744
745 #[tokio::test]
746 async fn test_list_supervised() {
747 let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
748 let supervisor = ContainerSupervisor::new(runtime);
749
750 let spec = mock_service_spec(PanicAction::Restart);
751
752 supervisor
753 .supervise(&mock_container_id("api", 1), &spec)
754 .await;
755 supervisor
756 .supervise(&mock_container_id("api", 2), &spec)
757 .await;
758 supervisor
759 .supervise(&mock_container_id("web", 1), &spec)
760 .await;
761
762 let containers = supervisor.list_supervised().await;
763 assert_eq!(containers.len(), 3);
764 }
765
766 #[tokio::test]
767 async fn test_supervised_container_record_restart() {
768 let mut container = SupervisedContainer::new(
769 mock_container_id("api", 1),
770 "api".to_string(),
771 PanicAction::Restart,
772 );
773
774 for _ in 0..5 {
776 let in_loop = container.record_restart(Duration::from_secs(300), 5);
777 assert!(!in_loop);
778 }
779
780 let in_loop = container.record_restart(Duration::from_secs(300), 5);
782 assert!(in_loop);
783 }
784
785 #[tokio::test]
786 async fn test_supervised_container_restart_window() {
787 let mut container = SupervisedContainer::new(
788 mock_container_id("api", 1),
789 "api".to_string(),
790 PanicAction::Restart,
791 );
792
793 for _ in 0..5 {
795 container.record_restart(Duration::from_millis(100), 5);
796 }
797
798 tokio::time::sleep(Duration::from_millis(150)).await;
800
801 let in_loop = container.record_restart(Duration::from_millis(100), 5);
803 assert!(!in_loop);
804 }
805
806 #[tokio::test]
807 async fn test_get_container_info() {
808 let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
809 let supervisor = ContainerSupervisor::new(runtime);
810
811 let container_id = mock_container_id("api", 1);
812 let spec = mock_service_spec(PanicAction::Isolate);
813
814 supervisor.supervise(&container_id, &spec).await;
815
816 let info = supervisor.get_container_info(&container_id).await;
817 assert!(info.is_some());
818
819 let info = info.unwrap();
820 assert_eq!(info.id, container_id);
821 assert_eq!(info.service_name, "api");
822 assert_eq!(info.panic_action, PanicAction::Isolate);
823 assert_eq!(info.state, SupervisedState::Running);
824 assert_eq!(info.total_restarts, 0);
825 }
826
827 #[tokio::test]
828 async fn test_count_by_state() {
829 let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
830 let supervisor = ContainerSupervisor::new(runtime);
831
832 let spec = mock_service_spec(PanicAction::Restart);
833
834 supervisor
835 .supervise(&mock_container_id("api", 1), &spec)
836 .await;
837 supervisor
838 .supervise(&mock_container_id("api", 2), &spec)
839 .await;
840
841 assert_eq!(supervisor.count_by_state(SupervisedState::Running).await, 2);
842 assert_eq!(
843 supervisor
844 .count_by_state(SupervisedState::CrashLoopBackOff)
845 .await,
846 0
847 );
848 }
849
850 #[test]
851 fn test_supervisor_config_default() {
852 let config = SupervisorConfig::default();
853
854 assert_eq!(config.max_restarts, DEFAULT_MAX_RESTARTS);
855 assert_eq!(config.restart_window, DEFAULT_RESTART_WINDOW);
856 assert_eq!(config.poll_interval, DEFAULT_POLL_INTERVAL);
857 }
858
859 #[test]
860 fn test_supervised_state_should_monitor() {
861 let container = SupervisedContainer {
863 state: SupervisedState::Running,
864 ..SupervisedContainer::new(
865 mock_container_id("api", 1),
866 "api".to_string(),
867 PanicAction::Restart,
868 )
869 };
870 assert!(container.should_monitor());
871
872 let container = SupervisedContainer {
873 state: SupervisedState::CrashLoopBackOff,
874 ..SupervisedContainer::new(
875 mock_container_id("api", 1),
876 "api".to_string(),
877 PanicAction::Restart,
878 )
879 };
880 assert!(container.should_monitor());
881
882 let container = SupervisedContainer {
884 state: SupervisedState::Shutdown,
885 ..SupervisedContainer::new(
886 mock_container_id("api", 1),
887 "api".to_string(),
888 PanicAction::Restart,
889 )
890 };
891 assert!(!container.should_monitor());
892
893 let container = SupervisedContainer {
894 state: SupervisedState::Isolated,
895 ..SupervisedContainer::new(
896 mock_container_id("api", 1),
897 "api".to_string(),
898 PanicAction::Restart,
899 )
900 };
901 assert!(!container.should_monitor());
902
903 let container = SupervisedContainer {
904 state: SupervisedState::Completed,
905 ..SupervisedContainer::new(
906 mock_container_id("api", 1),
907 "api".to_string(),
908 PanicAction::Restart,
909 )
910 };
911 assert!(!container.should_monitor());
912 }
913}