1use std::collections::HashMap;
47use std::future::Future;
48use std::pin::Pin;
49use std::sync::Arc;
50use std::time::{Duration, Instant};
51
52use tokio::sync::{mpsc, oneshot};
53use tokio::task::AbortHandle;
54use tokio_util::sync::CancellationToken;
55use tracing::Instrument as _;
56use zeph_common::BlockingSpawner;
57
/// How the supervisor reacts when a supervised task exits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RestartPolicy {
    /// Run the task to completion once; never restart it.
    RunOnce,
    /// Restart after a panic, at most `max` times, with exponential backoff
    /// starting at `base_delay` (doubling per attempt, capped at
    /// `MAX_RESTART_DELAY`). Non-panic exits do not trigger a restart.
    Restart { max: u32, base_delay: Duration },
}

/// Upper bound applied to the computed restart backoff delay.
pub const MAX_RESTART_DELAY: Duration = Duration::from_secs(60);
93
/// Everything needed to spawn one supervised task.
pub struct TaskDescriptor<F> {
    /// Registry key; spawning again with the same name replaces (and aborts)
    /// the previous task.
    pub name: &'static str,
    /// Restart behaviour applied when the task exits.
    pub restart: RestartPolicy,
    /// Produces the task future; invoked again on every restart.
    pub factory: F,
}
109
/// Cheap, cloneable handle for aborting a spawned supervised task.
///
/// NOTE(review): this holds the abort handle of the *original* spawn; after a
/// supervised restart the registry entry gets a fresh abort handle, so this
/// handle no longer targets the live task — abort via the supervisor instead.
#[derive(Debug, Clone)]
pub struct TaskHandle {
    // Name the task was registered under.
    name: &'static str,
    // Abort handle of the initially spawned tokio task.
    abort: AbortHandle,
}
118
119impl TaskHandle {
120 pub fn abort(&self) {
122 tracing::debug!(task.name = self.name, "task aborted via handle");
123 self.abort.abort();
124 }
125
126 #[must_use]
128 pub fn name(&self) -> &'static str {
129 self.name
130 }
131}
132
/// Failure modes of a supervised blocking / oneshot task.
#[derive(Debug, PartialEq, Eq)]
pub enum BlockingError {
    /// The supervised closure (or future) panicked.
    Panicked,
    /// The result channel closed before a value arrived.
    SupervisorDropped,
}

impl std::fmt::Display for BlockingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::Panicked => "supervised blocking task panicked",
            Self::SupervisorDropped => "supervisor dropped before task completed",
        };
        f.write_str(text)
    }
}

impl std::error::Error for BlockingError {}
152
/// Handle to a supervised blocking/oneshot task's eventual result.
pub struct BlockingHandle<R> {
    // Receives the task's result (or failure) exactly once.
    rx: oneshot::Receiver<Result<R, BlockingError>>,
    // Aborts the underlying tokio task.
    abort: AbortHandle,
}
165
166impl<R> BlockingHandle<R> {
167 pub async fn join(self) -> Result<R, BlockingError> {
175 self.rx
176 .await
177 .unwrap_or(Err(BlockingError::SupervisorDropped))
178 }
179
180 pub fn abort(&self) {
182 self.abort.abort();
183 }
184}
185
/// Lifecycle state of a supervised task as tracked by the registry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskStatus {
    /// Spawned and not yet observed to finish.
    Running,
    /// Panicked; waiting out the backoff before restart `attempt` of `max`.
    Restarting { attempt: u32, max: u32 },
    /// Exited without triggering a restart.
    Completed,
    /// Force-aborted during shutdown.
    Aborted,
    /// Permanently failed (restart budget exhausted, or the restart factory
    /// itself panicked).
    Failed { reason: String },
}

/// Read-only copy of one registry entry, as returned by
/// `TaskSupervisor::snapshot`.
#[derive(Debug, Clone)]
pub struct TaskSnapshot {
    /// Task name.
    pub name: Arc<str>,
    /// Status at snapshot time.
    pub status: TaskStatus,
    /// When the entry was first registered (not reset on restart).
    pub started_at: Instant,
    /// Number of restarts performed so far.
    pub restart_count: u32,
}
225
// Type-erased supervised task future.
type BoxFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
// Type-erased task factory; invoked once per (re)spawn.
type BoxFactory = Box<dyn Fn() -> BoxFuture + Send + Sync>;

/// Mutable registry record for one supervised task.
struct TaskEntry {
    name: Arc<str>,
    status: TaskStatus,
    started_at: Instant,
    restart_count: u32,
    restart_policy: RestartPolicy,
    // Always points at the most recently spawned incarnation of the task.
    abort_handle: AbortHandle,
    // `Some` only for restartable tasks; `None` disables respawning.
    factory: Option<BoxFactory>,
}

/// How a supervised task's join handle resolved.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CompletionKind {
    // Ran to completion (or lost the race against the cancel token).
    Normal,
    // The task panicked.
    Panicked,
    // The task was aborted.
    Cancelled,
}

/// Completion event sent from per-task reporter tasks to the reap driver.
struct Completion {
    name: Arc<str>,
    kind: CompletionKind,
}

/// Registry of supervised tasks, guarded by `Inner::state`.
struct SupervisorState {
    tasks: HashMap<Arc<str>, TaskEntry>,
}

/// State shared between supervisor clones and its background tasks.
struct Inner {
    state: parking_lot::Mutex<SupervisorState>,
    // Producer side of the completion channel; the reap driver owns the consumer.
    completion_tx: mpsc::UnboundedSender<Completion>,
    // Shared shutdown signal raced against every supervised future.
    cancel: CancellationToken,
    // Caps concurrently running supervised blocking closures (8 permits).
    blocking_semaphore: Arc<tokio::sync::Semaphore>,
}
273
/// Cloneable task supervisor: spawns named tasks, tracks their status in a
/// registry, and restarts panicked tasks according to their `RestartPolicy`.
#[derive(Clone)]
pub struct TaskSupervisor {
    inner: Arc<Inner>,
}
313
314impl TaskSupervisor {
315 #[must_use]
322 pub fn new(cancel: CancellationToken) -> Self {
323 let (completion_tx, completion_rx) = mpsc::unbounded_channel();
327 let inner = Arc::new(Inner {
328 state: parking_lot::Mutex::new(SupervisorState {
329 tasks: HashMap::new(),
330 }),
331 completion_tx,
332 cancel: cancel.clone(),
333 blocking_semaphore: Arc::new(tokio::sync::Semaphore::new(8)),
334 });
335
336 Self::start_reap_driver(Arc::clone(&inner), completion_rx, cancel);
337
338 Self { inner }
339 }
340
341 pub fn spawn<F, Fut>(&self, desc: TaskDescriptor<F>) -> TaskHandle
366 where
367 F: Fn() -> Fut + Send + Sync + 'static,
368 Fut: Future<Output = ()> + Send + 'static,
369 {
370 let factory: BoxFactory = Box::new(move || Box::pin((desc.factory)()));
371 let cancel = self.inner.cancel.clone();
372 let completion_tx = self.inner.completion_tx.clone();
373 let name: Arc<str> = Arc::from(desc.name);
374
375 let (abort_handle, jh) = Self::do_spawn(desc.name, &factory, cancel);
376 Self::wire_completion_reporter(Arc::clone(&name), jh, completion_tx);
377
378 let entry = TaskEntry {
379 name: Arc::clone(&name),
380 status: TaskStatus::Running,
381 started_at: Instant::now(),
382 restart_count: 0,
383 restart_policy: desc.restart,
384 abort_handle: abort_handle.clone(),
385 factory: match desc.restart {
386 RestartPolicy::RunOnce => None,
387 RestartPolicy::Restart { .. } => Some(factory),
388 },
389 };
390
391 {
392 let mut state = self.inner.state.lock();
393 if let Some(old) = state.tasks.remove(&name) {
394 old.abort_handle.abort();
395 }
396 state.tasks.insert(Arc::clone(&name), entry);
397 }
398
399 TaskHandle {
400 name: desc.name,
401 abort: abort_handle,
402 }
403 }
404
405 #[allow(clippy::needless_pass_by_value)] pub fn spawn_blocking<F, R>(&self, name: Arc<str>, f: F) -> BlockingHandle<R>
450 where
451 F: FnOnce() -> R + Send + 'static,
452 R: Send + 'static,
453 {
454 let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
455 #[cfg(feature = "task-metrics")]
456 let span = tracing::info_span!(
457 "supervised_blocking_task",
458 task.name = %name,
459 task.wall_time_ms = tracing::field::Empty,
460 task.cpu_time_ms = tracing::field::Empty,
461 );
462 #[cfg(not(feature = "task-metrics"))]
463 let span = tracing::info_span!("supervised_blocking_task", task.name = %name);
464
465 let semaphore = Arc::clone(&self.inner.blocking_semaphore);
466 let inner = Arc::clone(&self.inner);
467 let name_clone = Arc::clone(&name);
468 let completion_tx = self.inner.completion_tx.clone();
469
470 let outer = tokio::spawn(async move {
473 let _permit = semaphore
474 .acquire_owned()
475 .await
476 .expect("blocking semaphore closed");
477
478 let name_for_measure = Arc::clone(&name_clone);
479 let join_handle = tokio::task::spawn_blocking(move || {
480 let _enter = span.enter();
481 measure_blocking(&name_for_measure, f)
482 });
483 let abort = join_handle.abort_handle();
484
485 {
487 let mut state = inner.state.lock();
488 if let Some(entry) = state.tasks.get_mut(&name_clone) {
489 entry.abort_handle = abort;
490 }
491 }
492
493 let kind = match join_handle.await {
494 Ok(val) => {
495 let _ = tx.send(Ok(val));
496 CompletionKind::Normal
497 }
498 Err(e) if e.is_panic() => {
499 let _ = tx.send(Err(BlockingError::Panicked));
500 CompletionKind::Panicked
501 }
502 Err(_) => {
503 CompletionKind::Cancelled
505 }
506 };
507 let _ = completion_tx.send(Completion {
509 name: name_clone,
510 kind,
511 });
512 });
513 let abort = outer.abort_handle();
514
515 {
517 let mut state = self.inner.state.lock();
518 if let Some(old) = state.tasks.remove(&name) {
519 old.abort_handle.abort();
520 }
521 state.tasks.insert(
522 Arc::clone(&name),
523 TaskEntry {
524 name: Arc::clone(&name),
525 status: TaskStatus::Running,
526 started_at: Instant::now(),
527 restart_count: 0,
528 restart_policy: RestartPolicy::RunOnce,
529 abort_handle: abort.clone(),
530 factory: None,
531 },
532 );
533 }
534
535 BlockingHandle { rx, abort }
536 }
537
    /// Runs `factory`'s future once under supervision and returns a handle
    /// for awaiting its result.
    ///
    /// The future races against the supervisor's cancellation token; if the
    /// token fires first, the result sender is dropped and `join()` observes
    /// `BlockingError::SupervisorDropped`. Spawning with a `name` already
    /// present replaces (and aborts) the old entry.
    pub fn spawn_oneshot<F, Fut, R>(&self, name: Arc<str>, factory: F) -> BlockingHandle<R>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: Send + 'static,
    {
        let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
        let cancel = self.inner.cancel.clone();
        let span = tracing::info_span!("supervised_task", task.name = %name);
        // `None` means the cancellation token won the race.
        let join_handle: tokio::task::JoinHandle<Option<R>> = tokio::spawn(
            async move {
                let fut = factory();
                tokio::select! {
                    result = fut => Some(result),
                    () = cancel.cancelled() => None,
                }
            }
            .instrument(span),
        );
        let abort = join_handle.abort_handle();

        // Register before spawning the completion watcher below, so the
        // completion cannot arrive for a not-yet-registered entry.
        {
            let mut state = self.inner.state.lock();
            if let Some(old) = state.tasks.remove(&name) {
                old.abort_handle.abort();
            }
            state.tasks.insert(
                Arc::clone(&name),
                TaskEntry {
                    name: Arc::clone(&name),
                    status: TaskStatus::Running,
                    started_at: Instant::now(),
                    restart_count: 0,
                    restart_policy: RestartPolicy::RunOnce,
                    abort_handle: abort.clone(),
                    factory: None,
                },
            );
        }

        let completion_tx = self.inner.completion_tx.clone();
        tokio::spawn(async move {
            let kind = match join_handle.await {
                Ok(Some(val)) => {
                    let _ = tx.send(Ok(val));
                    CompletionKind::Normal
                }
                Err(e) if e.is_panic() => {
                    let _ = tx.send(Err(BlockingError::Panicked));
                    CompletionKind::Panicked
                }
                // Cancelled via token (Ok(None)) or aborted (non-panic Err).
                _ => CompletionKind::Cancelled,
            };
            let _ = completion_tx.send(Completion { name, kind });
        });
        BlockingHandle { rx, abort }
    }
622
623 pub fn abort(&self, name: &'static str) {
625 let state = self.inner.state.lock();
626 let key: Arc<str> = Arc::from(name);
627 if let Some(entry) = state.tasks.get(&key) {
628 entry.abort_handle.abort();
629 tracing::debug!(task.name = name, "task aborted via supervisor");
630 }
631 }
632
    /// Cancels the shared token, then waits up to `timeout` for all
    /// supervised tasks to drain.
    ///
    /// Polls `active_count` every 50ms. If the deadline passes first, every
    /// task still `Running`/`Restarting` is force-aborted and marked
    /// `Aborted`.
    pub async fn shutdown_all(&self, timeout: Duration) {
        self.inner.cancel.cancel();
        let deadline = tokio::time::Instant::now() + timeout;
        loop {
            let active = self.active_count();
            if active == 0 {
                break;
            }
            if tokio::time::Instant::now() >= deadline {
                tracing::warn!(
                    remaining = active,
                    "shutdown timeout — aborting remaining tasks"
                );
                let mut state = self.inner.state.lock();
                for entry in state.tasks.values_mut() {
                    if matches!(
                        entry.status,
                        TaskStatus::Running | TaskStatus::Restarting { .. }
                    ) {
                        entry.abort_handle.abort();
                        entry.status = TaskStatus::Aborted;
                    }
                }
                break;
            }
            // Coarse polling keeps this simple; exit granularity is 50ms.
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
    }
673
674 #[must_use]
679 pub fn snapshot(&self) -> Vec<TaskSnapshot> {
680 let state = self.inner.state.lock();
681 let mut snaps: Vec<TaskSnapshot> = state
682 .tasks
683 .values()
684 .map(|e| TaskSnapshot {
685 name: Arc::clone(&e.name),
686 status: e.status.clone(),
687 started_at: e.started_at,
688 restart_count: e.restart_count,
689 })
690 .collect();
691 snaps.sort_by_key(|s| s.started_at);
692 snaps
693 }
694
695 #[must_use]
697 pub fn active_count(&self) -> usize {
698 let state = self.inner.state.lock();
699 state
700 .tasks
701 .values()
702 .filter(|e| {
703 matches!(
704 e.status,
705 TaskStatus::Running | TaskStatus::Restarting { .. }
706 )
707 })
708 .count()
709 }
710
    /// Clone of the supervisor's shared cancellation token.
    #[must_use]
    pub fn cancellation_token(&self) -> CancellationToken {
        self.inner.cancel.clone()
    }
718
719 fn do_spawn(
723 name: &'static str,
724 factory: &BoxFactory,
725 cancel: CancellationToken,
726 ) -> (AbortHandle, tokio::task::JoinHandle<()>) {
727 let fut = factory();
728 let span = tracing::info_span!("supervised_task", task.name = name);
729 let jh = tokio::spawn(
730 async move {
731 tokio::select! {
732 () = fut => {},
733 () = cancel.cancelled() => {},
734 }
735 }
736 .instrument(span),
737 );
738 let abort = jh.abort_handle();
739 (abort, jh)
740 }
741
742 fn wire_completion_reporter(
744 name: Arc<str>,
745 jh: tokio::task::JoinHandle<()>,
746 completion_tx: mpsc::UnboundedSender<Completion>,
747 ) {
748 tokio::spawn(async move {
749 let kind = match jh.await {
750 Ok(()) => CompletionKind::Normal,
751 Err(e) if e.is_panic() => CompletionKind::Panicked,
752 Err(_) => CompletionKind::Cancelled,
753 };
754 let _ = completion_tx.send(Completion { name, kind });
755 });
756 }
757
    /// Spawns the background loop that receives `Completion` events and
    /// drives status updates and restarts.
    ///
    /// `biased` makes the select favour draining completions over reacting
    /// to cancellation; once the token fires, completions already queued are
    /// drained before the driver exits.
    fn start_reap_driver(
        inner: Arc<Inner>,
        mut completion_rx: mpsc::UnboundedReceiver<Completion>,
        cancel: CancellationToken,
    ) {
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    biased;
                    Some(completion) = completion_rx.recv() => {
                        Self::handle_completion(&inner, completion).await;
                    }
                    () = cancel.cancelled() => {
                        // Final drain: consume whatever is already buffered.
                        while let Ok(completion) = completion_rx.try_recv() {
                            Self::handle_completion(&inner, completion).await;
                        }
                        break;
                    }
                }
            }
        });
    }
787
788 async fn handle_completion(inner: &Arc<Inner>, completion: Completion) {
794 let Some((attempt, max, delay)) = Self::classify_completion(inner, &completion) else {
795 return;
796 };
797
798 tracing::warn!(
799 task.name = %completion.name,
800 attempt,
801 max,
802 delay_ms = delay.as_millis(),
803 "restarting supervised task"
804 );
805
806 if !delay.is_zero() {
807 tokio::time::sleep(delay).await;
808 }
809
810 Self::do_restart(inner, &completion.name, attempt);
811 }
812
    /// Updates registry state for a finished task and decides whether it
    /// should restart.
    ///
    /// Returns `Some((attempt, max, delay))` when a restart is due, `None`
    /// otherwise. `RunOnce` tasks — and restartable tasks that exited without
    /// panicking — are removed from the registry; a task that exhausted its
    /// restart budget stays registered with status `Failed`.
    fn classify_completion(
        inner: &Arc<Inner>,
        completion: &Completion,
    ) -> Option<(u32, u32, Duration)> {
        let mut state = inner.state.lock();
        let entry = state.tasks.get_mut(&completion.name)?;

        match completion.kind {
            CompletionKind::Panicked => {
                tracing::warn!(task.name = %completion.name, "supervised task panicked");
            }
            CompletionKind::Normal => {
                tracing::info!(task.name = %completion.name, "supervised task completed");
            }
            CompletionKind::Cancelled => {
                tracing::debug!(task.name = %completion.name, "supervised task cancelled");
            }
        }

        match entry.restart_policy {
            RestartPolicy::RunOnce => {
                entry.status = TaskStatus::Completed;
                state.tasks.remove(&completion.name);
                None
            }
            RestartPolicy::Restart { max, base_delay } => {
                // Only a panic consumes the restart budget; a clean exit or
                // cancellation retires the task.
                if completion.kind != CompletionKind::Panicked {
                    entry.status = TaskStatus::Completed;
                    state.tasks.remove(&completion.name);
                    return None;
                }
                if entry.restart_count >= max {
                    let reason = format!("panicked after {max} restart(s)");
                    tracing::error!(
                        task.name = %completion.name,
                        attempts = max,
                        "task failed permanently"
                    );
                    entry.status = TaskStatus::Failed { reason };
                    None
                } else {
                    let attempt = entry.restart_count + 1;
                    entry.status = TaskStatus::Restarting { attempt, max };
                    // Exponential backoff: base_delay * 2^(attempt - 1), with
                    // a saturating shift and a MAX_RESTART_DELAY cap.
                    let multiplier = 1_u32
                        .checked_shl(attempt.saturating_sub(1))
                        .unwrap_or(u32::MAX);
                    let delay = base_delay.saturating_mul(multiplier).min(MAX_RESTART_DELAY);
                    Some((attempt, max, delay))
                }
            }
        }
    }
871
    /// Second phase of a restart: re-invokes the task's factory and swaps
    /// the fresh task into the existing registry entry.
    ///
    /// Silently skips if the entry disappeared, or its status is no longer
    /// `Restarting` (e.g. it was superseded by a new `spawn` during the
    /// backoff delay). A panic inside the factory marks the entry `Failed`
    /// instead of propagating.
    fn do_restart(inner: &Arc<Inner>, name: &Arc<str>, attempt: u32) {
        let spawn_params = {
            let mut state = inner.state.lock();
            let Some(entry) = state.tasks.get_mut(name.as_ref()) else {
                tracing::debug!(
                    task.name = %name,
                    "task removed during restart delay — skipping"
                );
                return;
            };
            // Anything other than `Restarting` means the entry was replaced
            // or resolved while the backoff elapsed.
            if !matches!(entry.status, TaskStatus::Restarting { .. }) {
                return;
            }
            let Some(factory) = &entry.factory else {
                return;
            };
            // The factory is user code: contain a panic rather than letting
            // it poison the caller.
            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(factory)) {
                Err(_) => {
                    let reason = format!("factory panicked on restart attempt {attempt}");
                    tracing::error!(task.name = %name, attempt, "factory panicked during restart");
                    entry.status = TaskStatus::Failed { reason };
                    None
                }
                Ok(fut) => Some((
                    fut,
                    inner.cancel.clone(),
                    inner.completion_tx.clone(),
                    name.clone(),
                )),
            }
        };

        let Some((fut, cancel, completion_tx, name)) = spawn_params else {
            return;
        };

        let span = tracing::info_span!("supervised_task", task.name = %name);
        let jh = tokio::spawn(
            async move {
                tokio::select! {
                    () = fut => {},
                    () = cancel.cancelled() => {},
                }
            }
            .instrument(span),
        );
        let new_abort = jh.abort_handle();

        // Re-arm the entry: bump the restart count and point the abort
        // handle at the new incarnation.
        {
            let mut state = inner.state.lock();
            if let Some(entry) = state.tasks.get_mut(name.as_ref()) {
                entry.restart_count = attempt;
                entry.status = TaskStatus::Running;
                entry.abort_handle = new_abort;
            }
        }

        Self::wire_completion_reporter(name.clone(), jh, completion_tx);
    }
935}
936
/// Runs `f`, recording its wall-clock and per-thread CPU time both as
/// `metrics` histograms (`zeph.task.wall_time_ms` / `zeph.task.cpu_time_ms`,
/// labelled with the task name) and on the current tracing span's
/// `task.wall_time_ms` / `task.cpu_time_ms` fields.
#[cfg(feature = "task-metrics")]
#[inline]
fn measure_blocking<F, R>(name: &str, f: F) -> R
where
    F: FnOnce() -> R,
{
    use cpu_time::ThreadTime;
    let wall_start = std::time::Instant::now();
    let cpu_start = ThreadTime::now();
    let result = f();
    let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
    let cpu_ms = cpu_start.elapsed().as_secs_f64() * 1000.0;
    metrics::histogram!("zeph.task.wall_time_ms", "task" => name.to_owned()).record(wall_ms);
    metrics::histogram!("zeph.task.cpu_time_ms", "task" => name.to_owned()).record(cpu_ms);
    tracing::Span::current().record("task.wall_time_ms", wall_ms);
    tracing::Span::current().record("task.cpu_time_ms", cpu_ms);
    result
}
961
/// Metrics-disabled fallback: just runs `f` with no instrumentation.
#[cfg(not(feature = "task-metrics"))]
#[inline]
fn measure_blocking<F, R>(_name: &str, f: F) -> R
where
    F: FnOnce() -> R,
{
    f()
}
973
/// Adapter so the supervisor can serve as the crate-wide `BlockingSpawner`.
impl BlockingSpawner for TaskSupervisor {
    /// Runs `f` through [`TaskSupervisor::spawn_blocking`] (registry entry,
    /// semaphore throttling) and returns the handle of a wrapper task that
    /// awaits the result and logs any failure.
    fn spawn_blocking_named(
        &self,
        name: Arc<str>,
        f: Box<dyn FnOnce() + Send + 'static>,
    ) -> tokio::task::JoinHandle<()> {
        let handle = self.spawn_blocking(Arc::clone(&name), f);
        tokio::spawn(async move {
            if let Err(e) = handle.join().await {
                tracing::error!(task.name = %name, error = %e, "supervised blocking task failed");
            }
        })
    }
}
995
996#[cfg(test)]
999mod tests {
1000 use std::sync::Arc;
1001 use std::sync::atomic::{AtomicU32, Ordering};
1002 use std::time::Duration;
1003
1004 use tokio_util::sync::CancellationToken;
1005
1006 use super::*;
1007
    /// Builds a fresh supervisor and returns it with its cancellation token.
    fn make_supervisor() -> (TaskSupervisor, CancellationToken) {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());
        (sup, cancel)
    }

    // A RunOnce task runs to completion and is then reaped from the registry.
    #[tokio::test]
    async fn test_spawn_and_complete() {
        let (sup, _cancel) = make_supervisor();

        let done = Arc::new(tokio::sync::Notify::new());
        let done2 = Arc::clone(&done);

        sup.spawn(TaskDescriptor {
            name: "simple",
            restart: RestartPolicy::RunOnce,
            factory: move || {
                let d = Arc::clone(&done2);
                async move {
                    d.notify_one();
                }
            },
        });

        tokio::time::timeout(Duration::from_secs(2), done.notified())
            .await
            .expect("task should complete");

        // Give the reap driver a moment to process the completion event.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "RunOnce task should be removed after completion"
        );
    }

    // A panicking RunOnce task is captured and removed without restarting.
    #[tokio::test]
    async fn test_panic_capture() {
        let (sup, _cancel) = make_supervisor();

        sup.spawn(TaskDescriptor {
            name: "panicking",
            restart: RestartPolicy::RunOnce,
            factory: || async { panic!("intentional test panic") },
        });

        tokio::time::sleep(Duration::from_millis(200)).await;

        let snaps = sup.snapshot();
        assert!(
            snaps.iter().all(|s| s.name.as_ref() != "panicking"),
            "entry should be reaped"
        );
        assert_eq!(
            sup.active_count(),
            0,
            "active count must be 0 after RunOnce panic"
        );
    }
1067
    // Restart policy only applies to panics: a normal exit retires the task,
    // while a panic consumes restart attempts until the budget runs out.
    #[tokio::test]
    async fn test_restart_only_on_panic() {
        let (sup, _cancel) = make_supervisor();

        let normal_counter = Arc::new(AtomicU32::new(0));
        let nc = Arc::clone(&normal_counter);
        sup.spawn(TaskDescriptor {
            name: "normal-exit",
            restart: RestartPolicy::Restart {
                max: 3,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&nc);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(300)).await;
        assert_eq!(
            normal_counter.load(Ordering::SeqCst),
            1,
            "normal exit must not restart"
        );
        assert!(
            sup.snapshot()
                .iter()
                .all(|s| s.name.as_ref() != "normal-exit"),
            "entry removed after normal exit"
        );

        let panic_counter = Arc::new(AtomicU32::new(0));
        let pc = Arc::clone(&panic_counter);
        sup.spawn(TaskDescriptor {
            name: "panic-exit",
            restart: RestartPolicy::Restart {
                max: 2,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&pc);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("test panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        assert!(
            panic_counter.load(Ordering::SeqCst) >= 3,
            "panicking task must restart max times"
        );
        let snap = sup
            .snapshot()
            .into_iter()
            .find(|s| s.name.as_ref() == "panic-exit");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "task must be Failed after exhausting restarts"
        );
    }

    // An always-panicking task runs initial + max restarts, then stays in the
    // registry with status Failed.
    #[tokio::test]
    async fn test_restart_policy() {
        let (sup, _cancel) = make_supervisor();

        let counter = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::clone(&counter);

        sup.spawn(TaskDescriptor {
            name: "restartable",
            restart: RestartPolicy::Restart {
                max: 2,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&counter2);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(500)).await;

        let runs = counter.load(Ordering::SeqCst);
        assert!(
            runs >= 3,
            "expected at least 3 invocations (initial + 2 restarts), got {runs}"
        );

        let snaps = sup.snapshot();
        let snap = snaps.iter().find(|s| s.name.as_ref() == "restartable");
        assert!(snap.is_some(), "failed task should remain in registry");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "task should be Failed after exhausting retries"
        );
    }
1177
1178 #[tokio::test]
1180 async fn test_exponential_backoff() {
1181 let (sup, _cancel) = make_supervisor();
1182
1183 let timestamps = Arc::new(parking_lot::Mutex::new(Vec::<std::time::Instant>::new()));
1184 let ts = Arc::clone(×tamps);
1185
1186 sup.spawn(TaskDescriptor {
1187 name: "backoff-task",
1188 restart: RestartPolicy::Restart {
1189 max: 3,
1190 base_delay: Duration::from_millis(50),
1191 },
1192 factory: move || {
1193 let t = Arc::clone(&ts);
1194 async move {
1195 t.lock().push(std::time::Instant::now());
1196 panic!("always panic");
1197 }
1198 },
1199 });
1200
1201 tokio::time::sleep(Duration::from_millis(800)).await;
1203
1204 let ts = timestamps.lock();
1205 assert!(
1206 ts.len() >= 3,
1207 "expected at least 3 invocations, got {}",
1208 ts.len()
1209 );
1210
1211 if ts.len() >= 3 {
1213 let d1 = ts[1].duration_since(ts[0]);
1214 let d2 = ts[2].duration_since(ts[1]);
1215 assert!(
1217 d2 >= d1.mul_f64(1.5),
1218 "expected exponential backoff: d1={d1:?} d2={d2:?}"
1219 );
1220 }
1221 }
1222
    // shutdown_all cancels the token; tasks racing the token exit promptly
    // and shutdown returns within its timeout.
    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (sup, _cancel) = make_supervisor();

        for name in ["svc-a", "svc-b", "svc-c"] {
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: || async {
                    tokio::time::sleep(Duration::from_secs(60)).await;
                },
            });
        }

        assert_eq!(sup.active_count(), 3);

        tokio::time::timeout(
            Duration::from_secs(2),
            sup.shutdown_all(Duration::from_secs(1)),
        )
        .await
        .expect("shutdown should complete within timeout");
    }

    // A task that ignores cancellation is force-aborted at the shutdown
    // deadline and marked Aborted.
    #[tokio::test]
    async fn test_force_abort_marks_aborted() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        sup.spawn(TaskDescriptor {
            name: "stubborn-for-abort",
            restart: RestartPolicy::RunOnce,
            factory: || async {
                std::future::pending::<()>().await;
            },
        });

        sup.shutdown_all(Duration::from_millis(1)).await;

        let snaps = sup.snapshot();
        if let Some(snap) = snaps
            .iter()
            .find(|s| s.name.as_ref() == "stubborn-for-abort")
        {
            assert_eq!(
                snap.status,
                TaskStatus::Aborted,
                "force-aborted task must have Aborted status"
            );
        }
    }

    // snapshot() reports every live task, each with status Running.
    #[tokio::test]
    async fn test_registry_snapshot() {
        let (sup, _cancel) = make_supervisor();

        for name in ["alpha", "beta"] {
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: || async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                },
            });
        }

        let snaps = sup.snapshot();
        assert_eq!(snaps.len(), 2);
        let names: Vec<&str> = snaps.iter().map(|s| s.name.as_ref()).collect();
        assert!(names.contains(&"alpha"));
        assert!(names.contains(&"beta"));
        assert!(snaps.iter().all(|s| s.status == TaskStatus::Running));
    }
1301
    // spawn_blocking delivers the closure's return value through join().
    #[tokio::test]
    async fn test_blocking_returns_value() {
        let (sup, cancel) = make_supervisor();

        let handle: BlockingHandle<u32> = sup.spawn_blocking(Arc::from("compute"), || 42_u32);
        let result = handle.join().await.expect("should return value");
        assert_eq!(result, 42);
        cancel.cancel();
    }

    // A panicking blocking closure surfaces as BlockingError::Panicked.
    #[tokio::test]
    async fn test_blocking_panic() {
        let (sup, _cancel) = make_supervisor();

        let handle: BlockingHandle<u32> =
            sup.spawn_blocking(Arc::from("panicking-compute"), || panic!("intentional"));
        let err = handle
            .join()
            .await
            .expect_err("should return error on panic");
        assert_eq!(err, BlockingError::Panicked);
    }

    // Blocking tasks count as active while running and are reaped afterwards.
    #[tokio::test]
    async fn test_blocking_registered_in_registry() {
        let (sup, cancel) = make_supervisor();

        let (tx, rx) = std::sync::mpsc::channel::<()>();
        let _handle: BlockingHandle<()> =
            sup.spawn_blocking(Arc::from("blocking-task"), move || {
                let _ = rx.recv();
            });

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            sup.active_count(),
            1,
            "blocking task must appear in active_count"
        );

        let _ = tx.send(());
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "blocking task must be removed after completion"
        );

        cancel.cancel();
    }

    // Oneshot async tasks are likewise tracked and reaped via the registry.
    #[tokio::test]
    async fn test_oneshot_registered_in_registry() {
        let (sup, cancel) = make_supervisor();

        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        let _handle: BlockingHandle<()> =
            sup.spawn_oneshot(Arc::from("oneshot-task"), move || async move {
                let _ = rx.await;
            });

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            sup.active_count(),
            1,
            "oneshot task must appear in active_count"
        );

        let _ = tx.send(());
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "oneshot task must be removed after completion"
        );

        cancel.cancel();
    }
1383
    // Restart { max: 0 } means a panic is never retried; the entry stays
    // behind with status Failed.
    #[tokio::test]
    async fn test_restart_max_zero() {
        let (sup, _cancel) = make_supervisor();

        let counter = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::clone(&counter);

        sup.spawn(TaskDescriptor {
            name: "zero-max",
            restart: RestartPolicy::Restart {
                max: 0,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&counter2);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(200)).await;

        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "max=0 should not restart"
        );

        let snaps = sup.snapshot();
        let snap = snaps.iter().find(|s| s.name.as_ref() == "zero-max");
        assert!(snap.is_some(), "entry should remain as Failed");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "status should be Failed"
        );
    }

    // Fifty concurrently spawned RunOnce tasks all run and are all reaped.
    #[tokio::test]
    async fn test_concurrent_spawns() {
        let (sup, cancel) = make_supervisor();

        static NAMES: [&str; 50] = [
            "t00", "t01", "t02", "t03", "t04", "t05", "t06", "t07", "t08", "t09", "t10", "t11",
            "t12", "t13", "t14", "t15", "t16", "t17", "t18", "t19", "t20", "t21", "t22", "t23",
            "t24", "t25", "t26", "t27", "t28", "t29", "t30", "t31", "t32", "t33", "t34", "t35",
            "t36", "t37", "t38", "t39", "t40", "t41", "t42", "t43", "t44", "t45", "t46", "t47",
            "t48", "t49",
        ];

        let completed = Arc::new(AtomicU32::new(0));
        for name in &NAMES {
            let c = Arc::clone(&completed);
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: move || {
                    let c = Arc::clone(&c);
                    async move {
                        c.fetch_add(1, Ordering::SeqCst);
                    }
                },
            });
        }

        tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                if completed.load(Ordering::SeqCst) == 50 {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("all 50 tasks should complete");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(sup.active_count(), 0, "all tasks must be reaped");

        cancel.cancel();
    }
1470
    // shutdown_all must return even when tasks outlive the timeout, and it
    // must leave the token cancelled.
    #[tokio::test]
    async fn test_shutdown_timeout_expiry() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        sup.spawn(TaskDescriptor {
            name: "stubborn",
            restart: RestartPolicy::RunOnce,
            factory: || async {
                tokio::time::sleep(Duration::from_secs(60)).await;
            },
        });

        assert_eq!(sup.active_count(), 1);

        tokio::time::timeout(
            Duration::from_secs(2),
            sup.shutdown_all(Duration::from_millis(50)),
        )
        .await
        .expect("shutdown_all should return even on timeout expiry");

        assert!(
            cancel.is_cancelled(),
            "cancel token must be cancelled after shutdown"
        );
    }

    // cancellation_token() hands out the same token the supervisor cancels.
    #[tokio::test]
    async fn test_cancellation_token() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        assert!(!sup.cancellation_token().is_cancelled());

        sup.shutdown_all(Duration::from_millis(100)).await;

        assert!(
            sup.cancellation_token().is_cancelled(),
            "token must be cancelled after shutdown"
        );
    }

    // The BlockingSpawner adapter routes through spawn_blocking, so its
    // tasks show up in the supervisor snapshot.
    #[tokio::test]
    async fn test_blocking_spawner_task_appears_in_snapshot() {
        use zeph_common::BlockingSpawner;

        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel);

        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<()>();
        let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>();

        let handle = sup.spawn_blocking_named(
            Arc::from("chunk_file"),
            Box::new(move || {
                let _ = ready_tx.send(());
                let _ = release_rx.blocking_recv();
            }),
        );

        ready_rx.await.expect("task should start");

        let snapshot = sup.snapshot();
        assert!(
            snapshot.iter().any(|t| t.name.as_ref() == "chunk_file"),
            "chunk_file task must appear in supervisor snapshot"
        );

        let _ = release_tx.send(());
        handle.await.expect("task should complete");
    }
1548
    // With the task-metrics feature on, measure_blocking must record both the
    // wall-time and CPU-time histograms.
    #[cfg(feature = "task-metrics")]
    #[test]
    fn test_measure_blocking_emits_metrics() {
        use metrics_util::debugging::DebuggingRecorder;

        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();

        metrics::with_local_recorder(&recorder, || {
            measure_blocking("test_task", || std::hint::black_box(42_u64));
        });

        let snapshot = snapshotter.snapshot();
        let metric_names: Vec<String> = snapshot
            .into_vec()
            .into_iter()
            .map(|(k, _, _, _)| k.key().name().to_owned())
            .collect();

        assert!(
            metric_names.iter().any(|n| n == "zeph.task.wall_time_ms"),
            "expected zeph.task.wall_time_ms histogram; got: {metric_names:?}"
        );
        assert!(
            metric_names.iter().any(|n| n == "zeph.task.cpu_time_ms"),
            "expected zeph.task.cpu_time_ms histogram; got: {metric_names:?}"
        );
    }
1585
1586 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1592 async fn test_spawn_blocking_semaphore_cap() {
1593 let (sup, _cancel) = make_supervisor();
1594 let concurrent = Arc::new(AtomicU32::new(0));
1595 let max_concurrent = Arc::new(AtomicU32::new(0));
1596 let barrier = Arc::new(std::sync::Barrier::new(1)); let mut handles = Vec::new();
1599 for i in 0u32..16 {
1600 let c = Arc::clone(&concurrent);
1601 let m = Arc::clone(&max_concurrent);
1602 let name: Arc<str> = Arc::from(format!("blocking-{i}").as_str());
1603 let h = sup.spawn_blocking(name, move || {
1604 let prev = c.fetch_add(1, Ordering::SeqCst);
1605 let mut cur_max = m.load(Ordering::SeqCst);
1607 while prev + 1 > cur_max {
1608 match m.compare_exchange(cur_max, prev + 1, Ordering::SeqCst, Ordering::SeqCst)
1609 {
1610 Ok(_) => break,
1611 Err(x) => cur_max = x,
1612 }
1613 }
1614 std::thread::sleep(std::time::Duration::from_millis(20));
1616 c.fetch_sub(1, Ordering::SeqCst);
1617 });
1618 handles.push(h);
1619 }
1620
1621 for h in handles {
1622 h.join().await.expect("blocking task should succeed");
1623 }
1624 drop(barrier);
1625
1626 let observed = max_concurrent.load(Ordering::SeqCst);
1627 assert!(
1628 observed <= 8,
1629 "observed {observed} concurrent blocking tasks; expected ≤ 8 (semaphore cap)"
1630 );
1631 }
1632}