1use std::collections::HashMap;
41use std::sync::Arc;
42use std::sync::atomic::{AtomicU64, Ordering};
43use std::time::Duration;
44
45use tokio::sync::RwLock;
46use tokio_util::sync::CancellationToken;
47use tokio_util::task::TaskTracker;
48
49use crate::backoff::{BackoffAction, BackoffConfig, BackoffState};
50use crate::lifecycle::{
51 ServiceLifecycle, ServiceLifecycleSnapshot, ServicePhase, TerminationReason,
52};
53use crate::service::{RestartPolicy, TradingService};
54
55#[derive(Debug, Clone)]
61pub struct SupervisorConfig {
62 pub default_backoff: BackoffConfig,
64 pub shutdown_timeout: Duration,
66 pub install_signal_handler: bool,
68}
69
70impl Default for SupervisorConfig {
71 fn default() -> Self {
72 Self {
73 default_backoff: BackoffConfig::default(),
74 shutdown_timeout: Duration::from_secs(30),
75 install_signal_handler: true,
76 }
77 }
78}
79
80impl SupervisorConfig {
81 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
83 self.shutdown_timeout = timeout;
84 self
85 }
86
87 pub fn with_default_backoff(mut self, backoff: BackoffConfig) -> Self {
90 self.default_backoff = backoff;
91 self
92 }
93
94 pub fn without_signal_handler(mut self) -> Self {
97 self.install_signal_handler = false;
98 self
99 }
100}
101
102#[derive(Debug, Default)]
112pub struct SupervisorMetrics {
113 pub restarts_total: AtomicU64,
115 pub active_services: AtomicU64,
117 pub spawned_total: AtomicU64,
119 pub terminated_total: AtomicU64,
121 pub circuit_breaker_trips: AtomicU64,
123}
124
125impl SupervisorMetrics {
126 fn new() -> Self {
127 Self::default()
128 }
129
130 fn record_spawn(&self) {
131 self.spawned_total.fetch_add(1, Ordering::Relaxed);
132 let new_active = self.active_services.fetch_add(1, Ordering::Relaxed) + 1;
133
134 #[cfg(feature = "prometheus")]
135 {
136 let p = crate::prometheus::collectors();
137 p.spawned_total.inc();
138 p.active_services.set(new_active as f64);
139 }
140 #[cfg(not(feature = "prometheus"))]
141 let _ = new_active;
142 }
143
144 fn record_restart(&self) {
145 self.restarts_total.fetch_add(1, Ordering::Relaxed);
146 #[cfg(feature = "prometheus")]
147 crate::prometheus::collectors().restarts_total.inc();
148 }
149
150 fn record_termination(&self) {
151 self.terminated_total.fetch_add(1, Ordering::Relaxed);
152 let prev = self
153 .active_services
154 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
155 Some(v.saturating_sub(1))
156 })
157 .unwrap_or(0);
158 let new_active = prev.saturating_sub(1);
159
160 #[cfg(feature = "prometheus")]
161 {
162 let p = crate::prometheus::collectors();
163 p.terminated_total.inc();
164 p.active_services.set(new_active as f64);
165 }
166 #[cfg(not(feature = "prometheus"))]
167 let _ = new_active;
168 }
169
170 fn record_termination_with_uptime(&self, _service_name: &str, _uptime_secs: f64) {
171 self.record_termination();
172 #[cfg(feature = "prometheus")]
173 crate::prometheus::collectors()
174 .uptime_seconds
175 .with_label_values(&[_service_name])
176 .observe(_uptime_secs);
177 }
178
179 fn record_circuit_breaker_trip(&self) {
180 self.circuit_breaker_trips.fetch_add(1, Ordering::Relaxed);
181 #[cfg(feature = "prometheus")]
182 crate::prometheus::collectors().circuit_breaker_trips.inc();
183 }
184
185 pub fn snapshot(&self) -> MetricsSnapshot {
187 MetricsSnapshot {
188 restarts_total: self.restarts_total.load(Ordering::Relaxed),
189 active_services: self.active_services.load(Ordering::Relaxed),
190 spawned_total: self.spawned_total.load(Ordering::Relaxed),
191 terminated_total: self.terminated_total.load(Ordering::Relaxed),
192 circuit_breaker_trips: self.circuit_breaker_trips.load(Ordering::Relaxed),
193 }
194 }
195}
196
197#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
199pub struct MetricsSnapshot {
200 pub restarts_total: u64,
202 pub active_services: u64,
204 pub spawned_total: u64,
206 pub terminated_total: u64,
208 pub circuit_breaker_trips: u64,
210}
211
212#[derive(Debug, Clone, Default)]
218pub struct SpawnOptions {
219 pub backoff: Option<BackoffConfig>,
221}
222
223impl SpawnOptions {
224 pub fn with_backoff(backoff: BackoffConfig) -> Self {
226 Self {
227 backoff: Some(backoff),
228 }
229 }
230}
231
232pub struct Supervisor {
270 config: SupervisorConfig,
271 tracker: TaskTracker,
272 cancel_token: CancellationToken,
273 metrics: Arc<SupervisorMetrics>,
274 lifecycles: Arc<RwLock<HashMap<String, ServiceLifecycle>>>,
275}
276
277impl Supervisor {
278 pub fn new(config: SupervisorConfig) -> Self {
280 Self {
281 config,
282 tracker: TaskTracker::new(),
283 cancel_token: CancellationToken::new(),
284 metrics: Arc::new(SupervisorMetrics::new()),
285 lifecycles: Arc::new(RwLock::new(HashMap::new())),
286 }
287 }
288
289 pub fn with_defaults() -> Self {
291 Self::new(SupervisorConfig::default())
292 }
293
294 pub fn cancel_token(&self) -> &CancellationToken {
298 &self.cancel_token
299 }
300
301 pub fn metrics(&self) -> &Arc<SupervisorMetrics> {
303 &self.metrics
304 }
305
306 pub async fn lifecycle_snapshots(&self) -> Vec<ServiceLifecycleSnapshot> {
308 self.lifecycles
309 .read()
310 .await
311 .values()
312 .map(ServiceLifecycleSnapshot::from)
313 .collect()
314 }
315
316 pub async fn service_lifecycle(&self, name: &str) -> Option<ServiceLifecycleSnapshot> {
318 self.lifecycles
319 .read()
320 .await
321 .get(name)
322 .map(ServiceLifecycleSnapshot::from)
323 }
324
325 pub async fn service_count(&self) -> usize {
327 self.lifecycles.read().await.len()
328 }
329
330 #[tracing::instrument(skip(self))]
332 pub fn trigger_shutdown(&self) {
333 tracing::info!("supervisor: shutdown triggered");
334 self.cancel_token.cancel();
335 }
336
337 pub fn is_shutting_down(&self) -> bool {
339 self.cancel_token.is_cancelled()
340 }
341
342 pub fn spawn_service(&self, service: Box<dyn TradingService>) {
346 self.spawn_service_with_options(service, SpawnOptions::default());
347 }
348
349 #[tracing::instrument(
351 skip(self, service, options),
352 fields(service = %service.name(), policy = %service.restart_policy())
353 )]
354 pub fn spawn_service_with_options(
355 &self,
356 service: Box<dyn TradingService>,
357 options: SpawnOptions,
358 ) {
359 let service_name = service.name().to_string();
360 let restart_policy = service.restart_policy();
361 let backoff_config = options
362 .backoff
363 .unwrap_or_else(|| self.config.default_backoff.clone());
364
365 let cancel = self.cancel_token.child_token();
366 let metrics = self.metrics.clone();
367 let lifecycles = self.lifecycles.clone();
368
369 metrics.record_spawn();
370
371 self.tracker.spawn(Self::service_loop(
372 service,
373 service_name,
374 restart_policy,
375 backoff_config,
376 cancel,
377 metrics,
378 lifecycles,
379 ));
380 }
381
382 #[tracing::instrument(
389 skip_all,
390 fields(service = %service_name, policy = %restart_policy)
391 )]
392 async fn service_loop(
393 service: Box<dyn TradingService>,
394 service_name: String,
395 restart_policy: RestartPolicy,
396 backoff_config: BackoffConfig,
397 cancel: CancellationToken,
398 metrics: Arc<SupervisorMetrics>,
399 lifecycles: Arc<RwLock<HashMap<String, ServiceLifecycle>>>,
400 ) {
401 let mut backoff = BackoffState::new(backoff_config);
402 let mut lifecycle = ServiceLifecycle::new(&service_name);
403
404 {
407 let mut lc_map = lifecycles.write().await;
408 lc_map.insert(service_name.clone(), lifecycle.clone());
409 }
410
411 loop {
412 if cancel.is_cancelled() {
414 tracing::info!(
415 service = %service_name,
416 "cancellation detected, not starting service"
417 );
418 let _ = lifecycle.transition_to_stopping();
419 let _ = lifecycle.transition_to_terminated(TerminationReason::Cancelled);
420 Self::update_lifecycle(&lifecycles, &service_name, &lifecycle).await;
421 let uptime = lifecycle.cumulative_running_time().as_secs_f64();
422 metrics.record_termination_with_uptime(&service_name, uptime);
423 return;
424 }
425
426 if lifecycle.phase() == ServicePhase::Starting {
428 let _ = lifecycle.transition_to_running();
429 } else if lifecycle.phase() == ServicePhase::BackingOff {
430 let _ = lifecycle.transition_to_restarting();
431 let _ = lifecycle.transition_to_running();
432 metrics.record_restart();
433 }
434
435 backoff.record_start();
436 Self::update_lifecycle(&lifecycles, &service_name, &lifecycle).await;
437
438 tracing::info!(
439 service = %service_name,
440 attempt = lifecycle.start_count(),
441 "running service"
442 );
443
444 let result = service.run(cancel.clone()).await;
450
451 if cancel.is_cancelled() {
453 tracing::info!(service = %service_name, "service exited after cancellation");
454 let _ = lifecycle.transition_to_stopping();
455 let _ = lifecycle.transition_to_terminated(TerminationReason::Cancelled);
456 Self::update_lifecycle(&lifecycles, &service_name, &lifecycle).await;
457 let uptime = lifecycle.cumulative_running_time().as_secs_f64();
458 metrics.record_termination_with_uptime(&service_name, uptime);
459 return;
460 }
461
462 match result {
463 Ok(()) => {
464 tracing::info!(service = %service_name, "service exited cleanly");
465 backoff.maybe_reset_on_cooldown();
466
467 match restart_policy {
468 RestartPolicy::Always => {
469 backoff.reset();
474
475 tracing::info!(
476 service = %service_name,
477 "restart_policy=always, will restart after short delay"
478 );
479 let delay = Duration::from_millis(100);
482 let _ = lifecycle
483 .transition_to_backing_off("clean exit, policy=always", delay);
484 Self::update_lifecycle(&lifecycles, &service_name, &lifecycle).await;
485
486 tokio::select! {
487 _ = cancel.cancelled() => {
488 let _ = lifecycle.transition_to_stopping();
489 let _ = lifecycle.transition_to_terminated(
490 TerminationReason::Cancelled,
491 );
492 Self::update_lifecycle(&lifecycles, &service_name, &lifecycle).await;
493 let uptime = lifecycle.cumulative_running_time().as_secs_f64();
494 metrics.record_termination_with_uptime(&service_name, uptime);
495 return;
496 }
497 _ = tokio::time::sleep(delay) => {}
498 }
499 continue;
500 }
501 RestartPolicy::OnFailure | RestartPolicy::Never => {
502 let _ =
503 lifecycle.transition_to_terminated(TerminationReason::Completed);
504 Self::update_lifecycle(&lifecycles, &service_name, &lifecycle).await;
505 let uptime = lifecycle.cumulative_running_time().as_secs_f64();
506 metrics.record_termination_with_uptime(&service_name, uptime);
507 return;
508 }
509 }
510 }
511
512 Err(err) => {
513 let error_msg = format!("{err:#}");
514 tracing::error!(
515 service = %service_name,
516 error = %error_msg,
517 "service failed"
518 );
519
520 backoff.maybe_reset_on_cooldown();
521
522 match restart_policy {
523 RestartPolicy::Never => {
524 tracing::warn!(
525 service = %service_name,
526 "restart_policy=never, service will not be restarted"
527 );
528 let _ = lifecycle.transition_to_terminated(
529 TerminationReason::Unrecoverable(error_msg),
530 );
531 Self::update_lifecycle(&lifecycles, &service_name, &lifecycle).await;
532 let uptime = lifecycle.cumulative_running_time().as_secs_f64();
533 metrics.record_termination_with_uptime(&service_name, uptime);
534 return;
535 }
536
537 RestartPolicy::OnFailure | RestartPolicy::Always => {
538 match backoff.next_backoff() {
539 BackoffAction::Retry(delay) => {
540 tracing::info!(
541 service = %service_name,
542 delay_ms = delay.as_millis() as u64,
543 attempt = backoff.attempt(),
544 "scheduling restart after backoff"
545 );
546
547 let _ = lifecycle.transition_to_backing_off(&error_msg, delay);
548 Self::update_lifecycle(&lifecycles, &service_name, &lifecycle)
549 .await;
550
551 tokio::select! {
552 _ = cancel.cancelled() => {
553 tracing::info!(
554 service = %service_name,
555 "cancellation during backoff"
556 );
557 let _ = lifecycle.transition_to_stopping();
558 let _ = lifecycle.transition_to_terminated(
559 TerminationReason::Cancelled,
560 );
561 Self::update_lifecycle(&lifecycles, &service_name, &lifecycle).await;
562 let uptime = lifecycle.cumulative_running_time().as_secs_f64();
563 metrics.record_termination_with_uptime(&service_name, uptime);
564 return;
565 }
566 _ = tokio::time::sleep(delay) => {}
567 }
568 }
569
570 BackoffAction::CircuitOpen {
571 failures,
572 max_retries,
573 } => {
574 tracing::error!(
575 service = %service_name,
576 failures = failures,
577 max_retries = max_retries,
578 "CIRCUIT BREAKER OPEN — too many failures, giving up"
579 );
580 metrics.record_circuit_breaker_trip();
581
582 let _ = lifecycle.transition_to_terminated(
583 TerminationReason::CircuitBreakerOpen {
584 failures,
585 max_retries,
586 },
587 );
588 Self::update_lifecycle(&lifecycles, &service_name, &lifecycle)
589 .await;
590 let uptime = lifecycle.cumulative_running_time().as_secs_f64();
591 metrics.record_termination_with_uptime(&service_name, uptime);
592 return;
593 }
594 }
595 }
596 }
597 }
598 }
599 }
600 }
601
602 async fn update_lifecycle(
604 lifecycles: &Arc<RwLock<HashMap<String, ServiceLifecycle>>>,
605 name: &str,
606 lifecycle: &ServiceLifecycle,
607 ) {
608 let mut lc_map = lifecycles.write().await;
609 lc_map.insert(name.to_string(), lifecycle.clone());
610 }
611
612 #[tracing::instrument(skip(self), fields(timeout_secs = self.config.shutdown_timeout.as_secs()))]
619 pub async fn wait_for_drain(&self) {
620 self.tracker.close();
621 tracing::info!(
622 timeout_secs = self.config.shutdown_timeout.as_secs(),
623 "waiting for all services to drain"
624 );
625 match tokio::time::timeout(self.config.shutdown_timeout, self.tracker.wait()).await {
626 Ok(()) => tracing::info!("all services drained successfully"),
627 Err(_) => tracing::warn!(
628 timeout_secs = self.config.shutdown_timeout.as_secs(),
629 "shutdown timeout exceeded, some services may not have exited cleanly"
630 ),
631 }
632 }
633
634 #[tracing::instrument(skip(self), fields(signal_handler = self.config.install_signal_handler))]
640 pub async fn run_until_shutdown(&self) -> anyhow::Result<()> {
641 if self.config.install_signal_handler {
642 self.wait_for_signal_and_shutdown().await?;
643 } else {
644 self.cancel_token.cancelled().await;
645 tracing::info!("external shutdown signal received");
646 }
647
648 self.wait_for_drain().await;
649
650 let snap = self.metrics.snapshot();
651 tracing::info!(
652 restarts = snap.restarts_total,
653 spawned = snap.spawned_total,
654 terminated = snap.terminated_total,
655 circuit_trips = snap.circuit_breaker_trips,
656 "supervisor shutdown complete"
657 );
658
659 Ok(())
660 }
661
662 async fn wait_for_signal_and_shutdown(&self) -> anyhow::Result<()> {
663 #[cfg(unix)]
664 {
665 use tokio::signal::unix::{SignalKind, signal};
666 let mut sigterm = signal(SignalKind::terminate())?;
667 let mut sigint = signal(SignalKind::interrupt())?;
668 tokio::select! {
669 _ = sigterm.recv() => tracing::info!("received SIGTERM"),
670 _ = sigint.recv() => tracing::info!("received SIGINT"),
671 _ = self.cancel_token.cancelled() => {
672 tracing::info!("shutdown triggered programmatically");
673 return Ok(());
674 }
675 }
676 }
677
678 #[cfg(not(unix))]
679 {
680 tokio::select! {
681 result = tokio::signal::ctrl_c() => {
682 result?;
683 tracing::info!("received Ctrl-C");
684 }
685 _ = self.cancel_token.cancelled() => {
686 tracing::info!("shutdown triggered programmatically");
687 return Ok(());
688 }
689 }
690 }
691
692 self.cancel_token.cancel();
693 Ok(())
694 }
695}
696
697#[cfg(test)]
702mod tests {
703 use super::*;
704 use async_trait::async_trait;
705 use std::sync::atomic::AtomicU32;
706
707 struct CountingService {
712 name: String,
713 policy: RestartPolicy,
714 run_count: Arc<AtomicU64>,
715 }
716
717 impl CountingService {
718 fn new(name: &str, policy: RestartPolicy) -> (Self, Arc<AtomicU64>) {
719 let count = Arc::new(AtomicU64::new(0));
720 (
721 Self {
722 name: name.to_string(),
723 policy,
724 run_count: count.clone(),
725 },
726 count,
727 )
728 }
729 }
730
731 #[async_trait]
732 impl TradingService for CountingService {
733 fn name(&self) -> &str {
734 &self.name
735 }
736 fn restart_policy(&self) -> RestartPolicy {
737 self.policy
738 }
739 async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
740 self.run_count.fetch_add(1, Ordering::SeqCst);
741 cancel.cancelled().await;
742 Ok(())
743 }
744 }
745
746 struct FailNTimes {
748 name: String,
749 fail_count: u32,
750 current: Arc<AtomicU64>,
751 }
752
753 impl FailNTimes {
754 fn new(name: &str, fail_count: u32) -> (Self, Arc<AtomicU64>) {
755 let current = Arc::new(AtomicU64::new(0));
756 (
757 Self {
758 name: name.to_string(),
759 fail_count,
760 current: current.clone(),
761 },
762 current,
763 )
764 }
765 }
766
767 #[async_trait]
768 impl TradingService for FailNTimes {
769 fn name(&self) -> &str {
770 &self.name
771 }
772 fn restart_policy(&self) -> RestartPolicy {
773 RestartPolicy::OnFailure
774 }
775 async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
776 let attempt = self.current.fetch_add(1, Ordering::SeqCst) as u32;
777 if attempt < self.fail_count {
778 tokio::time::sleep(Duration::from_millis(1)).await;
779 anyhow::bail!("simulated failure #{}", attempt + 1);
780 }
781 cancel.cancelled().await;
782 Ok(())
783 }
784 }
785
786 struct OneShotService {
788 name: String,
789 ran: Arc<AtomicU64>,
790 }
791
792 impl OneShotService {
793 fn new(name: &str) -> (Self, Arc<AtomicU64>) {
794 let ran = Arc::new(AtomicU64::new(0));
795 (
796 Self {
797 name: name.to_string(),
798 ran: ran.clone(),
799 },
800 ran,
801 )
802 }
803 }
804
805 #[async_trait]
806 impl TradingService for OneShotService {
807 fn name(&self) -> &str {
808 &self.name
809 }
810 fn restart_policy(&self) -> RestartPolicy {
811 RestartPolicy::Never
812 }
813 async fn run(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
814 self.ran.fetch_add(1, Ordering::SeqCst);
815 Ok(())
816 }
817 }
818
819 struct AlwaysFailService {
821 name: String,
822 attempts: Arc<AtomicU64>,
823 }
824
825 impl AlwaysFailService {
826 fn new(name: &str) -> (Self, Arc<AtomicU64>) {
827 let attempts = Arc::new(AtomicU64::new(0));
828 (
829 Self {
830 name: name.to_string(),
831 attempts: attempts.clone(),
832 },
833 attempts,
834 )
835 }
836 }
837
838 #[async_trait]
839 impl TradingService for AlwaysFailService {
840 fn name(&self) -> &str {
841 &self.name
842 }
843 fn restart_policy(&self) -> RestartPolicy {
844 RestartPolicy::OnFailure
845 }
846 async fn run(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
847 self.attempts.fetch_add(1, Ordering::SeqCst);
848 tokio::time::sleep(Duration::from_millis(1)).await;
849 anyhow::bail!("permanent failure");
850 }
851 }
852
853 #[tokio::test]
856 async fn test_supervisor_creation() {
857 let sup = Supervisor::with_defaults();
858 assert!(!sup.is_shutting_down());
859 assert_eq!(sup.service_count().await, 0);
860 }
861
862 #[tokio::test]
863 async fn test_spawn_and_cancel_single_service() {
864 let config = SupervisorConfig::default().without_signal_handler();
865 let sup = Supervisor::new(config);
866
867 let (svc, count) = CountingService::new("test-svc", RestartPolicy::OnFailure);
868 sup.spawn_service(Box::new(svc));
869
870 tokio::time::sleep(Duration::from_millis(50)).await;
871
872 assert_eq!(count.load(Ordering::SeqCst), 1);
873 assert_eq!(sup.metrics().active_services.load(Ordering::Relaxed), 1);
874
875 sup.trigger_shutdown();
876 sup.wait_for_drain().await;
877
878 let snap = sup.metrics().snapshot();
879 assert_eq!(snap.spawned_total, 1);
880 assert_eq!(snap.terminated_total, 1);
881 assert_eq!(snap.active_services, 0);
882 }
883
884 #[tokio::test]
885 async fn test_spawn_multiple_services() {
886 let config = SupervisorConfig::default().without_signal_handler();
887 let sup = Supervisor::new(config);
888
889 let (svc1, count1) = CountingService::new("svc-1", RestartPolicy::OnFailure);
890 let (svc2, count2) = CountingService::new("svc-2", RestartPolicy::OnFailure);
891 let (svc3, count3) = CountingService::new("svc-3", RestartPolicy::OnFailure);
892
893 sup.spawn_service(Box::new(svc1));
894 sup.spawn_service(Box::new(svc2));
895 sup.spawn_service(Box::new(svc3));
896
897 tokio::time::sleep(Duration::from_millis(50)).await;
898
899 assert_eq!(count1.load(Ordering::SeqCst), 1);
900 assert_eq!(count2.load(Ordering::SeqCst), 1);
901 assert_eq!(count3.load(Ordering::SeqCst), 1);
902
903 sup.trigger_shutdown();
904 sup.wait_for_drain().await;
905
906 let snap = sup.metrics().snapshot();
907 assert_eq!(snap.spawned_total, 3);
908 assert_eq!(snap.terminated_total, 3);
909 }
910
911 #[tokio::test]
912 async fn test_service_restart_on_failure() {
913 let config = SupervisorConfig::default()
914 .without_signal_handler()
915 .with_default_backoff(
916 BackoffConfig::new(Duration::from_millis(10), Duration::from_millis(50))
917 .without_circuit_breaker(),
918 );
919 let sup = Supervisor::new(config);
920
921 let (svc, attempts) = FailNTimes::new("fail-3", 3);
922 sup.spawn_service(Box::new(svc));
923
924 tokio::time::sleep(Duration::from_millis(500)).await;
925
926 assert!(
927 attempts.load(Ordering::SeqCst) >= 4,
928 "expected >= 4 attempts, got {}",
929 attempts.load(Ordering::SeqCst)
930 );
931
932 sup.trigger_shutdown();
933 sup.wait_for_drain().await;
934
935 let snap = sup.metrics().snapshot();
936 assert!(snap.restarts_total >= 3);
937 }
938
939 #[tokio::test]
940 async fn test_one_shot_service_no_restart() {
941 let config = SupervisorConfig::default().without_signal_handler();
942 let sup = Supervisor::new(config);
943
944 let (svc, ran) = OneShotService::new("one-shot");
945 sup.spawn_service(Box::new(svc));
946
947 tokio::time::sleep(Duration::from_millis(100)).await;
948
949 assert_eq!(ran.load(Ordering::SeqCst), 1);
950
951 let snap = sup.metrics().snapshot();
952 assert_eq!(snap.terminated_total, 1);
953 assert_eq!(snap.restarts_total, 0);
954
955 sup.trigger_shutdown();
956 sup.wait_for_drain().await;
957 }
958
959 #[tokio::test]
960 async fn test_restart_policy_never_on_failure() {
961 let config = SupervisorConfig::default().without_signal_handler();
962 let sup = Supervisor::new(config);
963
964 struct FailOnce {
965 ran: Arc<AtomicU64>,
966 }
967
968 #[async_trait]
969 impl TradingService for FailOnce {
970 fn name(&self) -> &str {
971 "fail-once-never"
972 }
973 fn restart_policy(&self) -> RestartPolicy {
974 RestartPolicy::Never
975 }
976 async fn run(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
977 self.ran.fetch_add(1, Ordering::SeqCst);
978 anyhow::bail!("intentional failure");
979 }
980 }
981
982 let ran = Arc::new(AtomicU64::new(0));
983 let svc = FailOnce { ran: ran.clone() };
984 sup.spawn_service(Box::new(svc));
985
986 tokio::time::sleep(Duration::from_millis(100)).await;
987 assert_eq!(ran.load(Ordering::SeqCst), 1);
988
989 let snap = sup.metrics().snapshot();
990 assert_eq!(snap.terminated_total, 1);
991 assert_eq!(snap.restarts_total, 0);
992
993 sup.trigger_shutdown();
994 sup.wait_for_drain().await;
995 }
996
997 #[tokio::test]
998 async fn test_circuit_breaker_trips() {
999 let config = SupervisorConfig::default()
1000 .without_signal_handler()
1001 .with_default_backoff(
1002 BackoffConfig::new(Duration::from_millis(5), Duration::from_millis(20))
1003 .with_circuit_breaker(3, Duration::from_secs(60)),
1004 );
1005 let sup = Supervisor::new(config);
1006
1007 let (svc, attempts) = AlwaysFailService::new("always-fail");
1008 sup.spawn_service(Box::new(svc));
1009
1010 tokio::time::sleep(Duration::from_millis(500)).await;
1011
1012 let att = attempts.load(Ordering::SeqCst);
1013 assert!(att >= 3, "expected >= 3 attempts, got {att}");
1014
1015 let snap = sup.metrics().snapshot();
1016 assert_eq!(snap.circuit_breaker_trips, 1);
1017 assert_eq!(snap.terminated_total, 1);
1018
1019 sup.trigger_shutdown();
1020 sup.wait_for_drain().await;
1021 }
1022
1023 #[tokio::test]
1024 async fn test_lifecycle_snapshots() {
1025 let config = SupervisorConfig::default().without_signal_handler();
1026 let sup = Supervisor::new(config);
1027
1028 let (svc, _) = CountingService::new("lifecycle-test", RestartPolicy::OnFailure);
1029 sup.spawn_service(Box::new(svc));
1030
1031 tokio::time::sleep(Duration::from_millis(50)).await;
1032
1033 let snapshots = sup.lifecycle_snapshots().await;
1034 assert_eq!(snapshots.len(), 1);
1035 let snap = &snapshots[0];
1036 assert_eq!(snap.service_name, "lifecycle-test");
1037 assert_eq!(snap.phase, ServicePhase::Running);
1038 assert_eq!(snap.start_count, 1);
1039 assert_eq!(snap.total_failures, 0);
1040
1041 sup.trigger_shutdown();
1042 sup.wait_for_drain().await;
1043
1044 let snapshots = sup.lifecycle_snapshots().await;
1045 assert_eq!(snapshots[0].phase, ServicePhase::Terminated);
1046 }
1047
1048 #[tokio::test]
1049 async fn test_service_lifecycle_by_name() {
1050 let config = SupervisorConfig::default().without_signal_handler();
1051 let sup = Supervisor::new(config);
1052
1053 let (svc, _) = CountingService::new("named-svc", RestartPolicy::OnFailure);
1054 sup.spawn_service(Box::new(svc));
1055
1056 tokio::time::sleep(Duration::from_millis(50)).await;
1057
1058 let snap = sup.service_lifecycle("named-svc").await;
1059 assert!(snap.is_some());
1060 assert_eq!(snap.unwrap().service_name, "named-svc");
1061
1062 assert!(sup.service_lifecycle("nonexistent").await.is_none());
1063
1064 sup.trigger_shutdown();
1065 sup.wait_for_drain().await;
1066 }
1067
1068 #[tokio::test]
1069 async fn test_metrics_snapshot() {
1070 let sup = Supervisor::with_defaults();
1071 let snap = sup.metrics().snapshot();
1072 assert_eq!(snap.restarts_total, 0);
1073 assert_eq!(snap.active_services, 0);
1074 assert_eq!(snap.spawned_total, 0);
1075 assert_eq!(snap.terminated_total, 0);
1076 assert_eq!(snap.circuit_breaker_trips, 0);
1077 }
1078
1079 #[tokio::test]
1080 async fn test_shutdown_timeout() {
1081 let config = SupervisorConfig::default()
1082 .without_signal_handler()
1083 .with_shutdown_timeout(Duration::from_millis(100));
1084 let sup = Supervisor::new(config);
1085
1086 struct HangingService;
1087 #[async_trait]
1088 impl TradingService for HangingService {
1089 fn name(&self) -> &str {
1090 "hanger"
1091 }
1092 async fn run(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
1093 tokio::time::sleep(Duration::from_secs(3600)).await;
1094 Ok(())
1095 }
1096 }
1097
1098 sup.spawn_service(Box::new(HangingService));
1099 tokio::time::sleep(Duration::from_millis(20)).await;
1100 sup.trigger_shutdown();
1101
1102 let start = std::time::Instant::now();
1103 sup.wait_for_drain().await;
1104 let elapsed = start.elapsed();
1105
1106 assert!(
1107 elapsed < Duration::from_secs(1),
1108 "drain took too long: {elapsed:?}"
1109 );
1110 }
1111
1112 #[tokio::test]
1113 async fn test_spawn_with_custom_backoff() {
1114 let config = SupervisorConfig::default().without_signal_handler();
1115 let sup = Supervisor::new(config);
1116
1117 let (svc, attempts) = AlwaysFailService::new("custom-backoff");
1118 let custom_backoff =
1119 BackoffConfig::new(Duration::from_millis(5), Duration::from_millis(10))
1120 .with_circuit_breaker(2, Duration::from_secs(60));
1121
1122 sup.spawn_service_with_options(Box::new(svc), SpawnOptions::with_backoff(custom_backoff));
1123
1124 tokio::time::sleep(Duration::from_millis(200)).await;
1125
1126 assert!(attempts.load(Ordering::SeqCst) >= 2);
1127 let snap = sup.metrics().snapshot();
1128 assert_eq!(snap.circuit_breaker_trips, 1);
1129
1130 sup.trigger_shutdown();
1131 sup.wait_for_drain().await;
1132 }
1133
1134 #[tokio::test]
1135 async fn test_restart_policy_always_on_clean_exit() {
1136 let config = SupervisorConfig::default()
1137 .without_signal_handler()
1138 .with_default_backoff(
1139 BackoffConfig::new(Duration::from_millis(10), Duration::from_millis(50))
1140 .without_circuit_breaker(),
1141 );
1142 let sup = Supervisor::new(config);
1143
1144 struct ExitImmediately {
1145 count: Arc<AtomicU64>,
1146 }
1147 #[async_trait]
1148 impl TradingService for ExitImmediately {
1149 fn name(&self) -> &str {
1150 "exit-immediately"
1151 }
1152 fn restart_policy(&self) -> RestartPolicy {
1153 RestartPolicy::Always
1154 }
1155 async fn run(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
1156 self.count.fetch_add(1, Ordering::SeqCst);
1157 tokio::time::sleep(Duration::from_millis(1)).await;
1158 Ok(())
1159 }
1160 }
1161
1162 let count = Arc::new(AtomicU64::new(0));
1163 let svc = ExitImmediately {
1164 count: count.clone(),
1165 };
1166 sup.spawn_service(Box::new(svc));
1167
1168 tokio::time::sleep(Duration::from_millis(500)).await;
1169
1170 let runs = count.load(Ordering::SeqCst);
1171 assert!(
1172 runs >= 2,
1173 "expected service to run multiple times with Always policy, got {runs}"
1174 );
1175
1176 sup.trigger_shutdown();
1177 sup.wait_for_drain().await;
1178 }
1179
1180 #[tokio::test]
1181 async fn test_is_shutting_down() {
1182 let sup = Supervisor::with_defaults();
1183 assert!(!sup.is_shutting_down());
1184 sup.trigger_shutdown();
1185 assert!(sup.is_shutting_down());
1186 }
1187
1188 #[tokio::test]
1189 async fn test_config_builder() {
1190 let config = SupervisorConfig::default()
1191 .with_shutdown_timeout(Duration::from_secs(10))
1192 .with_default_backoff(BackoffConfig::new(
1193 Duration::from_millis(200),
1194 Duration::from_secs(30),
1195 ))
1196 .without_signal_handler();
1197
1198 assert_eq!(config.shutdown_timeout, Duration::from_secs(10));
1199 assert!(!config.install_signal_handler);
1200 assert_eq!(
1201 config.default_backoff.base_delay,
1202 Duration::from_millis(200)
1203 );
1204 assert_eq!(config.default_backoff.max_delay, Duration::from_secs(30));
1205 }
1206
1207 struct ChaosService {
1215 name: String,
1216 fail_times: u32,
1217 current: Arc<AtomicU32>,
1218 attempt_times: Arc<tokio::sync::Mutex<Vec<std::time::Instant>>>,
1219 policy: RestartPolicy,
1220 }
1221
1222 impl ChaosService {
1223 fn new(name: &str, fail_times: u32, policy: RestartPolicy) -> Self {
1224 Self {
1225 name: name.to_string(),
1226 fail_times,
1227 current: Arc::new(AtomicU32::new(0)),
1228 attempt_times: Arc::new(tokio::sync::Mutex::new(Vec::new())),
1229 policy,
1230 }
1231 }
1232 }
1233
1234 #[async_trait]
1235 impl TradingService for ChaosService {
1236 fn name(&self) -> &str {
1237 &self.name
1238 }
1239 fn restart_policy(&self) -> RestartPolicy {
1240 self.policy
1241 }
1242 async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
1243 {
1244 let mut ts = self.attempt_times.lock().await;
1245 ts.push(std::time::Instant::now());
1246 }
1247 let n = self.current.fetch_add(1, Ordering::SeqCst);
1248 if n < self.fail_times {
1249 anyhow::bail!("chaos failure #{}", n + 1);
1250 }
1251 cancel.cancelled().await;
1252 Ok(())
1253 }
1254 }
1255
1256 #[tokio::test]
1258 async fn test_chaos_exponential_backoff() {
1259 let backoff = BackoffConfig::new(Duration::from_millis(20), Duration::from_secs(2))
1260 .without_circuit_breaker();
1261 let config = SupervisorConfig::default()
1262 .with_shutdown_timeout(Duration::from_secs(5))
1263 .with_default_backoff(backoff)
1264 .without_signal_handler();
1265 let sup = Supervisor::new(config);
1266
1267 let chaos = ChaosService::new("chaos-backoff", 3, RestartPolicy::OnFailure);
1268 let attempts_arc = chaos.attempt_times.clone();
1269 let current_arc = chaos.current.clone();
1270
1271 sup.spawn_service(Box::new(chaos));
1272
1273 let deadline = tokio::time::Instant::now() + Duration::from_secs(10);
1274 loop {
1275 let count = current_arc.load(Ordering::SeqCst);
1276 if count >= 4 {
1277 break;
1278 }
1279 if tokio::time::Instant::now() > deadline {
1280 panic!("chaos service did not recover; attempts={count}");
1281 }
1282 tokio::time::sleep(Duration::from_millis(50)).await;
1283 }
1284
1285 tokio::time::sleep(Duration::from_millis(100)).await;
1286
1287 let timestamps = attempts_arc.lock().await;
1288 assert!(
1289 timestamps.len() >= 4,
1290 "expected >= 4 attempts, got {}",
1291 timestamps.len()
1292 );
1293
1294 let delays: Vec<Duration> = timestamps
1295 .windows(2)
1296 .map(|w| w[1].duration_since(w[0]))
1297 .collect();
1298
1299 for (i, d) in delays.iter().enumerate().skip(1) {
1303 assert!(
1304 *d >= Duration::from_millis(1),
1305 "delay[{i}] too short: {d:?} — backoff may not be working"
1306 );
1307 }
1308
1309 let metrics = sup.metrics().snapshot();
1310 assert!(
1311 metrics.restarts_total >= 3,
1312 "expected >= 3 restarts, got {}",
1313 metrics.restarts_total
1314 );
1315
1316 sup.trigger_shutdown();
1317 sup.wait_for_drain().await;
1318
1319 let snap = sup.service_lifecycle("chaos-backoff").await.unwrap();
1320 assert_eq!(snap.phase, ServicePhase::Terminated);
1321 }
1322
1323 #[tokio::test]
1325 async fn test_chaos_circuit_breaker_trips() {
1326 let backoff = BackoffConfig::new(Duration::from_millis(10), Duration::from_millis(50))
1327 .with_circuit_breaker(3, Duration::from_secs(60));
1328 let config = SupervisorConfig::default()
1329 .with_shutdown_timeout(Duration::from_secs(5))
1330 .with_default_backoff(backoff)
1331 .without_signal_handler();
1332 let sup = Supervisor::new(config);
1333
1334 let chaos = ChaosService::new("chaos-cb", 1000, RestartPolicy::OnFailure);
1335 let current_arc = chaos.current.clone();
1336 sup.spawn_service(Box::new(chaos));
1337
1338 let deadline = tokio::time::Instant::now() + Duration::from_secs(10);
1339 loop {
1340 if let Some(snap) = sup.service_lifecycle("chaos-cb").await
1341 && snap.phase == ServicePhase::Terminated
1342 {
1343 break;
1344 }
1345 if tokio::time::Instant::now() > deadline {
1346 panic!(
1347 "circuit breaker did not trip; attempts={}",
1348 current_arc.load(Ordering::SeqCst)
1349 );
1350 }
1351 tokio::time::sleep(Duration::from_millis(50)).await;
1352 }
1353
1354 let snap = sup.service_lifecycle("chaos-cb").await.unwrap();
1355 assert_eq!(snap.phase, ServicePhase::Terminated);
1356 let reason = snap
1357 .termination_reason
1358 .as_deref()
1359 .expect("termination reason");
1360 assert!(
1361 reason.contains("circuit breaker"),
1362 "expected circuit breaker termination, got: {reason}"
1363 );
1364
1365 let metrics = sup.metrics().snapshot();
1366 assert!(metrics.circuit_breaker_trips >= 1);
1367
1368 let attempts_at_trip = current_arc.load(Ordering::SeqCst);
1369 tokio::time::sleep(Duration::from_millis(200)).await;
1370 let attempts_after = current_arc.load(Ordering::SeqCst);
1371 assert_eq!(
1372 attempts_at_trip, attempts_after,
1373 "service should NOT restart after circuit breaker trips"
1374 );
1375
1376 sup.trigger_shutdown();
1377 sup.wait_for_drain().await;
1378 }
1379
1380 struct LifecycleTracer {
1382 name: String,
1383 log: Arc<tokio::sync::Mutex<Vec<String>>>,
1384 policy: RestartPolicy,
1385 }
1386
1387 impl LifecycleTracer {
1388 fn new(
1389 name: &str,
1390 log: Arc<tokio::sync::Mutex<Vec<String>>>,
1391 policy: RestartPolicy,
1392 ) -> Self {
1393 Self {
1394 name: name.to_string(),
1395 log,
1396 policy,
1397 }
1398 }
1399 }
1400
1401 #[async_trait]
1402 impl TradingService for LifecycleTracer {
1403 fn name(&self) -> &str {
1404 &self.name
1405 }
1406 fn restart_policy(&self) -> RestartPolicy {
1407 self.policy
1408 }
1409 async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
1410 {
1411 let mut l = self.log.lock().await;
1412 l.push(format!("{}:started", self.name));
1413 }
1414 cancel.cancelled().await;
1415 {
1416 let mut l = self.log.lock().await;
1417 l.push(format!("{}:stopped", self.name));
1418 }
1419 Ok(())
1420 }
1421 }
1422
1423 #[tokio::test]
1425 async fn test_chaos_mixed_fleet() {
1426 let backoff = BackoffConfig::new(Duration::from_millis(10), Duration::from_millis(100))
1427 .with_circuit_breaker(3, Duration::from_secs(60));
1428 let config = SupervisorConfig::default()
1429 .with_shutdown_timeout(Duration::from_secs(5))
1430 .with_default_backoff(backoff)
1431 .without_signal_handler();
1432 let sup = Supervisor::new(config);
1433
1434 let log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
1435
1436 sup.spawn_service(Box::new(LifecycleTracer::new(
1437 "healthy-api",
1438 log.clone(),
1439 RestartPolicy::OnFailure,
1440 )));
1441
1442 let chaos = ChaosService::new("bad-data", 1000, RestartPolicy::OnFailure);
1443 sup.spawn_service(Box::new(chaos));
1444
1445 let recovering = ChaosService::new("flaky-cns", 2, RestartPolicy::OnFailure);
1446 let recovering_attempts = recovering.current.clone();
1447 sup.spawn_service(Box::new(recovering));
1448
1449 let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
1450 loop {
1451 if sup.service_count().await == 3 {
1452 break;
1453 }
1454 if tokio::time::Instant::now() > deadline {
1455 panic!(
1456 "timed out waiting for 3 services to register; got {}",
1457 sup.service_count().await
1458 );
1459 }
1460 tokio::time::sleep(Duration::from_millis(10)).await;
1461 }
1462
1463 let deadline = tokio::time::Instant::now() + Duration::from_secs(10);
1464 loop {
1465 let bad_done = sup
1466 .service_lifecycle("bad-data")
1467 .await
1468 .is_some_and(|s| s.phase == ServicePhase::Terminated);
1469 let recovered = recovering_attempts.load(Ordering::SeqCst) >= 3;
1470 if bad_done && recovered {
1471 break;
1472 }
1473 if tokio::time::Instant::now() > deadline {
1474 panic!("mixed fleet did not reach expected state");
1475 }
1476 tokio::time::sleep(Duration::from_millis(50)).await;
1477 }
1478
1479 let healthy_snap = sup.service_lifecycle("healthy-api").await.unwrap();
1480 assert!(healthy_snap.phase.is_alive());
1481
1482 let bad_snap = sup.service_lifecycle("bad-data").await.unwrap();
1483 assert_eq!(bad_snap.phase, ServicePhase::Terminated);
1484 assert!(
1485 bad_snap
1486 .termination_reason
1487 .as_deref()
1488 .is_some_and(|r| r.contains("circuit breaker"))
1489 );
1490
1491 let flaky_snap = sup.service_lifecycle("flaky-cns").await.unwrap();
1492 assert!(flaky_snap.phase.is_alive());
1493 assert!(flaky_snap.start_count >= 3);
1494
1495 sup.trigger_shutdown();
1496 sup.wait_for_drain().await;
1497
1498 for name in &["healthy-api", "bad-data", "flaky-cns"] {
1499 let snap = sup.service_lifecycle(name).await.unwrap();
1500 assert_eq!(
1501 snap.phase,
1502 ServicePhase::Terminated,
1503 "service '{name}' should be Terminated after shutdown"
1504 );
1505 }
1506
1507 let metrics = sup.metrics().snapshot();
1508 assert_eq!(metrics.active_services, 0);
1509 assert_eq!(metrics.spawned_total, 3);
1510 assert_eq!(metrics.terminated_total, 3);
1511 }
1512}