1use std::collections::HashMap;
47use std::future::Future;
48use std::pin::Pin;
49use std::sync::Arc;
50use std::time::{Duration, Instant};
51
52use tokio::sync::{mpsc, oneshot};
53use tokio::task::AbortHandle;
54use tokio_util::sync::CancellationToken;
55use tracing::Instrument as _;
56
57use crate::BlockingSpawner;
58
/// What the supervisor does when a task spawned via [`TaskSupervisor::spawn`] exits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RestartPolicy {
    /// Never restart: the entry is dropped from the registry once the task ends.
    RunOnce,
    /// Restart only after a panic, at most `max` times.
    ///
    /// Restart `attempt` (1-based) waits `base_delay * 2^(attempt - 1)`, capped at
    /// [`MAX_RESTART_DELAY`]. Normal or cancelled exits are never restarted.
    Restart { max: u32, base_delay: Duration },
}
91
/// Upper bound on the exponential restart backoff (one minute).
pub const MAX_RESTART_DELAY: Duration = Duration::from_secs(60);
94
/// How long the reap driver keeps draining completions after cancellation (5 s).
const SHUTDOWN_DRAIN_TIMEOUT: Duration = Duration::from_millis(5_000);
100
101pub struct TaskDescriptor<F> {
106 pub name: &'static str,
111 pub restart: RestartPolicy,
113 pub factory: F,
115}
116
117#[derive(Debug, Clone)]
121pub struct TaskHandle {
122 name: &'static str,
123 abort: AbortHandle,
124}
125
126impl TaskHandle {
127 pub fn abort(&self) {
129 tracing::debug!(task.name = self.name, "task aborted via handle");
130 self.abort.abort();
131 }
132
133 #[must_use]
135 pub fn name(&self) -> &'static str {
136 self.name
137 }
138}
139
/// Failure modes surfaced by [`BlockingHandle::join`] and [`BlockingHandle::try_join`].
#[derive(Debug, PartialEq, Eq)]
pub enum BlockingError {
    /// The job panicked. The panic payload is not preserved.
    Panicked,
    /// The result channel closed without a value, e.g. because the job was cancelled
    /// or aborted before it produced one.
    SupervisorDropped,
}

impl std::fmt::Display for BlockingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Panicked => "supervised blocking task panicked",
            Self::SupervisorDropped => "supervisor dropped before task completed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for BlockingError {}
159
160pub struct BlockingHandle<R> {
169 rx: oneshot::Receiver<Result<R, BlockingError>>,
170 abort: AbortHandle,
171}
172
173impl<R> BlockingHandle<R> {
174 pub async fn join(self) -> Result<R, BlockingError> {
182 self.rx
183 .await
184 .unwrap_or(Err(BlockingError::SupervisorDropped))
185 }
186
187 pub fn try_join(mut self) -> Result<Result<R, BlockingError>, Self> {
219 match self.rx.try_recv() {
220 Ok(result) => Ok(result),
221 Err(tokio::sync::oneshot::error::TryRecvError::Empty) => Err(self),
222 Err(tokio::sync::oneshot::error::TryRecvError::Closed) => {
223 Ok(Err(BlockingError::SupervisorDropped))
224 }
225 }
226 }
227
228 pub fn abort(&self) {
230 self.abort.abort();
231 }
232}
233
/// Lifecycle state of a registry entry, as exposed through [`TaskSnapshot`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TaskStatus {
    /// The task is live.
    Running,
    /// The task panicked and is waiting for restart `attempt` of `max`.
    Restarting { attempt: u32, max: u32 },
    /// The task finished.
    ///
    /// NOTE(review): finished entries are normally removed from the registry
    /// rather than left in this state.
    Completed,
    /// Force-aborted by `shutdown_all` after its timeout elapsed.
    Aborted,
    /// Gave up: restarts were exhausted or the factory panicked during a restart.
    Failed { reason: String },
}
248
249#[derive(Debug, Clone)]
251pub struct TaskSnapshot {
261 pub name: Arc<str>,
263 pub status: TaskStatus,
265 pub started_at: Instant,
267 pub restart_count: u32,
269}
270
/// Type-erased, sendable unit future that the supervisor spawns.
type BoxFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Type-erased factory kept for restartable tasks so a fresh future can be built per run.
type BoxFactory = Box<dyn Fn() -> BoxFuture + Send + Sync>;
275
276struct TaskEntry {
277 name: Arc<str>,
278 status: TaskStatus,
279 started_at: Instant,
280 restart_count: u32,
281 restart_policy: RestartPolicy,
282 abort_handle: AbortHandle,
283 factory: Option<BoxFactory>,
285}
286
/// How a supervised task's Tokio task ended, as reported to the reap driver.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CompletionKind {
    /// The future ran to completion, or its cancellation-token branch won.
    Normal,
    /// The task panicked.
    Panicked,
    /// The task was aborted or otherwise ended without a value.
    Cancelled,
}
297
298struct Completion {
299 name: Arc<str>,
300 kind: CompletionKind,
301}
302
303struct SupervisorState {
304 tasks: HashMap<Arc<str>, TaskEntry>,
305}
306
307struct Inner {
308 state: parking_lot::Mutex<SupervisorState>,
309 completion_tx: mpsc::UnboundedSender<Completion>,
313 cancel: CancellationToken,
314 blocking_semaphore: Arc<tokio::sync::Semaphore>,
317}
318
319#[derive(Clone)]
355pub struct TaskSupervisor {
356 inner: Arc<Inner>,
357}
358
359impl TaskSupervisor {
360 #[must_use]
372 pub fn new(cancel: CancellationToken) -> Self {
373 let (completion_tx, completion_rx) = mpsc::unbounded_channel();
377 let inner = Arc::new(Inner {
378 state: parking_lot::Mutex::new(SupervisorState {
379 tasks: HashMap::new(),
380 }),
381 completion_tx,
382 cancel: cancel.clone(),
383 blocking_semaphore: Arc::new(tokio::sync::Semaphore::new(8)),
384 });
385
386 if tokio::runtime::Handle::try_current().is_ok() {
391 Self::start_reap_driver(Arc::clone(&inner), completion_rx, cancel);
392 }
393
394 Self { inner }
395 }
396
397 pub fn spawn<F, Fut>(&self, desc: TaskDescriptor<F>) -> TaskHandle
422 where
423 F: Fn() -> Fut + Send + Sync + 'static,
424 Fut: Future<Output = ()> + Send + 'static,
425 {
426 let factory: BoxFactory = Box::new(move || Box::pin((desc.factory)()));
427 let cancel = self.inner.cancel.clone();
428 let completion_tx = self.inner.completion_tx.clone();
429 let name: Arc<str> = Arc::from(desc.name);
430
431 let (abort_handle, jh) = Self::do_spawn(desc.name, &factory, cancel);
432 Self::wire_completion_reporter(Arc::clone(&name), jh, completion_tx);
433
434 let entry = TaskEntry {
435 name: Arc::clone(&name),
436 status: TaskStatus::Running,
437 started_at: Instant::now(),
438 restart_count: 0,
439 restart_policy: desc.restart,
440 abort_handle: abort_handle.clone(),
441 factory: match desc.restart {
442 RestartPolicy::RunOnce => None,
443 RestartPolicy::Restart { .. } => Some(factory),
444 },
445 };
446
447 {
448 let mut state = self.inner.state.lock();
449 if let Some(old) = state.tasks.remove(&name) {
450 old.abort_handle.abort();
451 }
452 state.tasks.insert(Arc::clone(&name), entry);
453 }
454
455 TaskHandle {
456 name: desc.name,
457 abort: abort_handle,
458 }
459 }
460
461 #[allow(clippy::needless_pass_by_value)] pub fn spawn_blocking<F, R>(&self, name: Arc<str>, f: F) -> BlockingHandle<R>
506 where
507 F: FnOnce() -> R + Send + 'static,
508 R: Send + 'static,
509 {
510 let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
511 let span = tracing::info_span!(
512 "supervised_blocking_task",
513 task.name = %name,
514 task.wall_time_ms = tracing::field::Empty,
515 task.cpu_time_ms = tracing::field::Empty,
516 );
517
518 let semaphore = Arc::clone(&self.inner.blocking_semaphore);
519 let inner = Arc::clone(&self.inner);
520 let name_clone = Arc::clone(&name);
521 let completion_tx = self.inner.completion_tx.clone();
522
523 let outer = tokio::spawn(async move {
526 let _permit = semaphore
527 .acquire_owned()
528 .await
529 .expect("blocking semaphore closed");
530
531 let name_for_measure = Arc::clone(&name_clone);
532 let join_handle = tokio::task::spawn_blocking(move || {
533 let _enter = span.enter();
534 measure_blocking(&name_for_measure, f)
535 });
536 let abort = join_handle.abort_handle();
537
538 {
540 let mut state = inner.state.lock();
541 if let Some(entry) = state.tasks.get_mut(&name_clone) {
542 entry.abort_handle = abort;
543 }
544 }
545
546 let kind = match join_handle.await {
547 Ok(val) => {
548 let _ = tx.send(Ok(val));
549 CompletionKind::Normal
550 }
551 Err(e) if e.is_panic() => {
552 let _ = tx.send(Err(BlockingError::Panicked));
553 CompletionKind::Panicked
554 }
555 Err(_) => {
556 CompletionKind::Cancelled
558 }
559 };
560 let _ = completion_tx.send(Completion {
562 name: name_clone,
563 kind,
564 });
565 });
566 let abort = outer.abort_handle();
567
568 {
570 let mut state = self.inner.state.lock();
571 if let Some(old) = state.tasks.remove(&name) {
572 old.abort_handle.abort();
573 }
574 state.tasks.insert(
575 Arc::clone(&name),
576 TaskEntry {
577 name: Arc::clone(&name),
578 status: TaskStatus::Running,
579 started_at: Instant::now(),
580 restart_count: 0,
581 restart_policy: RestartPolicy::RunOnce,
582 abort_handle: abort.clone(),
583 factory: None,
584 },
585 );
586 }
587
588 BlockingHandle { rx, abort }
589 }
590
591 pub fn spawn_oneshot<F, Fut, R>(&self, name: Arc<str>, factory: F) -> BlockingHandle<R>
619 where
620 F: FnOnce() -> Fut + Send + 'static,
621 Fut: Future<Output = R> + Send + 'static,
622 R: Send + 'static,
623 {
624 let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
625 let cancel = self.inner.cancel.clone();
626 let span = tracing::info_span!("supervised_task", task.name = %name);
627 let join_handle: tokio::task::JoinHandle<Option<R>> = tokio::spawn(
628 async move {
629 let fut = factory();
630 tokio::select! {
631 result = fut => Some(result),
632 () = cancel.cancelled() => None,
633 }
634 }
635 .instrument(span),
636 );
637 let abort = join_handle.abort_handle();
638
639 {
640 let mut state = self.inner.state.lock();
641 if let Some(old) = state.tasks.remove(&name) {
642 old.abort_handle.abort();
643 }
644 state.tasks.insert(
645 Arc::clone(&name),
646 TaskEntry {
647 name: Arc::clone(&name),
648 status: TaskStatus::Running,
649 started_at: Instant::now(),
650 restart_count: 0,
651 restart_policy: RestartPolicy::RunOnce,
652 abort_handle: abort.clone(),
653 factory: None,
654 },
655 );
656 }
657
658 let completion_tx = self.inner.completion_tx.clone();
659 tokio::spawn(async move {
660 let kind = match join_handle.await {
661 Ok(Some(val)) => {
662 let _ = tx.send(Ok(val));
663 CompletionKind::Normal
664 }
665 Err(e) if e.is_panic() => {
666 let _ = tx.send(Err(BlockingError::Panicked));
667 CompletionKind::Panicked
668 }
669 _ => CompletionKind::Cancelled,
670 };
671 let _ = completion_tx.send(Completion { name, kind });
672 });
673 BlockingHandle { rx, abort }
674 }
675
676 pub fn abort(&self, name: &'static str) {
678 let state = self.inner.state.lock();
679 let key: Arc<str> = Arc::from(name);
680 if let Some(entry) = state.tasks.get(&key) {
681 entry.abort_handle.abort();
682 tracing::debug!(task.name = name, "task aborted via supervisor");
683 }
684 }
685
686 pub async fn shutdown_all(&self, timeout: Duration) {
699 self.inner.cancel.cancel();
700 let deadline = tokio::time::Instant::now() + timeout;
701 loop {
702 let active = self.active_count();
703 if active == 0 {
704 break;
705 }
706 if tokio::time::Instant::now() >= deadline {
707 let mut remaining_names: Vec<Arc<str>> = Vec::new();
708 {
709 let mut state = self.inner.state.lock();
710 for entry in state.tasks.values_mut() {
711 if matches!(
712 entry.status,
713 TaskStatus::Running | TaskStatus::Restarting { .. }
714 ) {
715 remaining_names.push(Arc::clone(&entry.name));
716 entry.abort_handle.abort();
717 entry.status = TaskStatus::Aborted;
718 }
719 }
720 }
721 tracing::warn!(
722 remaining = active,
723 tasks = ?remaining_names,
724 "shutdown timeout — aborting remaining tasks"
725 );
726 break;
727 }
728 tokio::time::sleep(Duration::from_millis(50)).await;
729 }
730 }
731
732 #[must_use]
737 pub fn snapshot(&self) -> Vec<TaskSnapshot> {
738 let state = self.inner.state.lock();
739 let mut snaps: Vec<TaskSnapshot> = state
740 .tasks
741 .values()
742 .map(|e| TaskSnapshot {
743 name: Arc::clone(&e.name),
744 status: e.status.clone(),
745 started_at: e.started_at,
746 restart_count: e.restart_count,
747 })
748 .collect();
749 snaps.sort_by_key(|s| s.started_at);
750 snaps
751 }
752
753 #[must_use]
755 pub fn active_count(&self) -> usize {
756 let state = self.inner.state.lock();
757 state
758 .tasks
759 .values()
760 .filter(|e| {
761 matches!(
762 e.status,
763 TaskStatus::Running | TaskStatus::Restarting { .. }
764 )
765 })
766 .count()
767 }
768
769 #[must_use]
773 pub fn cancellation_token(&self) -> CancellationToken {
774 self.inner.cancel.clone()
775 }
776
777 fn do_spawn(
781 name: &'static str,
782 factory: &BoxFactory,
783 cancel: CancellationToken,
784 ) -> (AbortHandle, tokio::task::JoinHandle<()>) {
785 let fut = factory();
786 let span = tracing::info_span!("supervised_task", task.name = name);
787 let jh = tokio::spawn(
788 async move {
789 tokio::select! {
790 () = fut => {},
791 () = cancel.cancelled() => {},
792 }
793 }
794 .instrument(span),
795 );
796 let abort = jh.abort_handle();
797 (abort, jh)
798 }
799
800 fn wire_completion_reporter(
802 name: Arc<str>,
803 jh: tokio::task::JoinHandle<()>,
804 completion_tx: mpsc::UnboundedSender<Completion>,
805 ) {
806 tokio::spawn(async move {
807 let kind = match jh.await {
808 Ok(()) => CompletionKind::Normal,
809 Err(e) if e.is_panic() => CompletionKind::Panicked,
810 Err(_) => CompletionKind::Cancelled,
811 };
812 let _ = completion_tx.send(Completion { name, kind });
813 });
814 }
815
816 fn start_reap_driver(
823 inner: Arc<Inner>,
824 mut completion_rx: mpsc::UnboundedReceiver<Completion>,
825 cancel: CancellationToken,
826 ) {
827 tokio::spawn(async move {
828 loop {
830 tokio::select! {
831 biased;
832 Some(completion) = completion_rx.recv() => {
833 Self::handle_completion(&inner, completion).await;
834 }
835 () = cancel.cancelled() => break,
836 }
837 }
838
839 let drain_deadline = tokio::time::Instant::now() + SHUTDOWN_DRAIN_TIMEOUT;
844 let active = Self::has_active_tasks(&inner);
845 tracing::debug!(active, "reap driver entered post-cancel drain phase");
846 loop {
847 if !Self::has_active_tasks(&inner) {
848 break;
849 }
850 let remaining =
851 drain_deadline.saturating_duration_since(tokio::time::Instant::now());
852 if remaining.is_zero() {
853 break;
854 }
855 match tokio::time::timeout(remaining, completion_rx.recv()).await {
856 Ok(Some(completion)) => Self::handle_completion(&inner, completion).await,
857 Ok(None) | Err(_) => break,
859 }
860 }
861 tracing::debug!(
862 active = Self::has_active_tasks(&inner),
863 "reap driver drain phase complete"
864 );
865 });
866 }
867
868 fn has_active_tasks(inner: &Arc<Inner>) -> bool {
870 let state = inner.state.lock();
871 state.tasks.values().any(|e| {
872 matches!(
873 e.status,
874 TaskStatus::Running | TaskStatus::Restarting { .. }
875 )
876 })
877 }
878
879 async fn handle_completion(inner: &Arc<Inner>, completion: Completion) {
885 if inner.cancel.is_cancelled() {
889 let mut state = inner.state.lock();
890 state.tasks.remove(&completion.name);
891 return;
892 }
893
894 let Some((attempt, max, delay)) = Self::classify_completion(inner, &completion) else {
895 return;
896 };
897
898 tracing::warn!(
899 task.name = %completion.name,
900 attempt,
901 max,
902 delay_ms = delay.as_millis(),
903 "restarting supervised task"
904 );
905
906 if !delay.is_zero() {
907 tokio::time::sleep(delay).await;
908 }
909
910 Self::do_restart(inner, &completion.name, attempt);
911 }
912
913 fn classify_completion(
917 inner: &Arc<Inner>,
918 completion: &Completion,
919 ) -> Option<(u32, u32, Duration)> {
920 let mut state = inner.state.lock();
921 let entry = state.tasks.get_mut(&completion.name)?;
922
923 match completion.kind {
924 CompletionKind::Panicked => {
925 tracing::warn!(task.name = %completion.name, "supervised task panicked");
926 }
927 CompletionKind::Normal => {
928 tracing::info!(task.name = %completion.name, "supervised task completed");
929 }
930 CompletionKind::Cancelled => {
931 tracing::debug!(task.name = %completion.name, "supervised task cancelled");
932 }
933 }
934
935 match entry.restart_policy {
936 RestartPolicy::RunOnce => {
937 entry.status = TaskStatus::Completed;
938 state.tasks.remove(&completion.name);
939 None
940 }
941 RestartPolicy::Restart { max, base_delay } => {
942 if completion.kind != CompletionKind::Panicked {
944 entry.status = TaskStatus::Completed;
945 state.tasks.remove(&completion.name);
946 return None;
947 }
948 if entry.restart_count >= max {
949 let reason = format!("panicked after {max} restart(s)");
950 tracing::error!(
951 task.name = %completion.name,
952 attempts = max,
953 "task failed permanently"
954 );
955 entry.status = TaskStatus::Failed { reason };
956 None
957 } else {
958 let attempt = entry.restart_count + 1;
959 entry.status = TaskStatus::Restarting { attempt, max };
960 let multiplier = 1_u32
962 .checked_shl(attempt.saturating_sub(1))
963 .unwrap_or(u32::MAX);
964 let delay = base_delay.saturating_mul(multiplier).min(MAX_RESTART_DELAY);
965 Some((attempt, max, delay))
966 }
967 }
968 }
969 }
971
972 fn do_restart(inner: &Arc<Inner>, name: &Arc<str>, attempt: u32) {
974 let spawn_params = {
975 let mut state = inner.state.lock();
976 let Some(entry) = state.tasks.get_mut(name.as_ref()) else {
977 tracing::debug!(
978 task.name = %name,
979 "task removed during restart delay — skipping"
980 );
981 return;
982 };
983 if !matches!(entry.status, TaskStatus::Restarting { .. }) {
984 return;
985 }
986 let Some(factory) = &entry.factory else {
987 return;
988 };
989 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(factory)) {
992 Err(_) => {
993 let reason = format!("factory panicked on restart attempt {attempt}");
994 tracing::error!(task.name = %name, attempt, "factory panicked during restart");
995 entry.status = TaskStatus::Failed { reason };
996 None
997 }
998 Ok(fut) => Some((
999 fut,
1000 inner.cancel.clone(),
1001 inner.completion_tx.clone(),
1002 name.clone(),
1003 )),
1004 }
1005 };
1007
1008 let Some((fut, cancel, completion_tx, name)) = spawn_params else {
1009 return;
1010 };
1011
1012 let span = tracing::info_span!("supervised_task", task.name = %name);
1013 let jh = tokio::spawn(
1014 async move {
1015 tokio::select! {
1016 () = fut => {},
1017 () = cancel.cancelled() => {},
1018 }
1019 }
1020 .instrument(span),
1021 );
1022 let new_abort = jh.abort_handle();
1023
1024 {
1025 let mut state = inner.state.lock();
1026 if let Some(entry) = state.tasks.get_mut(name.as_ref()) {
1027 entry.restart_count = attempt;
1028 entry.status = TaskStatus::Running;
1029 entry.abort_handle = new_abort;
1030 }
1031 }
1032
1033 Self::wire_completion_reporter(name.clone(), jh, completion_tx);
1034 }
1035}
1036
1037#[inline]
1041fn measure_blocking<F, R>(name: &str, f: F) -> R
1042where
1043 F: FnOnce() -> R,
1044{
1045 use cpu_time::ThreadTime;
1046 let wall_start = std::time::Instant::now();
1047 let cpu_start = ThreadTime::now();
1048 let result = f();
1049 let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
1050 let cpu_ms = cpu_start.elapsed().as_secs_f64() * 1000.0;
1051 metrics::histogram!("zeph.task.wall_time_ms", "task" => name.to_owned()).record(wall_ms);
1052 metrics::histogram!("zeph.task.cpu_time_ms", "task" => name.to_owned()).record(cpu_ms);
1053 tracing::Span::current().record("task.wall_time_ms", wall_ms);
1054 tracing::Span::current().record("task.cpu_time_ms", cpu_ms);
1055 result
1056}
1057
1058impl BlockingSpawner for TaskSupervisor {
1061 fn spawn_blocking_named(
1067 &self,
1068 name: Arc<str>,
1069 f: Box<dyn FnOnce() + Send + 'static>,
1070 ) -> tokio::task::JoinHandle<()> {
1071 let handle = self.spawn_blocking(Arc::clone(&name), f);
1072 tokio::spawn(async move {
1073 if let Err(e) = handle.join().await {
1074 tracing::error!(task.name = %name, error = %e, "supervised blocking task failed");
1075 }
1076 })
1077 }
1078}
1079
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    use tokio_util::sync::CancellationToken;

    use super::*;

    /// Builds a supervisor together with the token it was created with.
    fn make_supervisor() -> (TaskSupervisor, CancellationToken) {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());
        (sup, cancel)
    }

    #[tokio::test]
    async fn test_spawn_and_complete() {
        let (sup, _cancel) = make_supervisor();

        let done = Arc::new(tokio::sync::Notify::new());
        let done2 = Arc::clone(&done);

        sup.spawn(TaskDescriptor {
            name: "simple",
            restart: RestartPolicy::RunOnce,
            factory: move || {
                let d = Arc::clone(&done2);
                async move {
                    d.notify_one();
                }
            },
        });

        tokio::time::timeout(Duration::from_secs(2), done.notified())
            .await
            .expect("task should complete");

        // Give the reap driver a moment to process the completion.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "RunOnce task should be removed after completion"
        );
    }

    #[tokio::test]
    async fn test_panic_capture() {
        let (sup, _cancel) = make_supervisor();

        sup.spawn(TaskDescriptor {
            name: "panicking",
            restart: RestartPolicy::RunOnce,
            factory: || async { panic!("intentional test panic") },
        });

        tokio::time::sleep(Duration::from_millis(200)).await;

        let snaps = sup.snapshot();
        assert!(
            snaps.iter().all(|s| s.name.as_ref() != "panicking"),
            "entry should be reaped"
        );
        assert_eq!(
            sup.active_count(),
            0,
            "active count must be 0 after RunOnce panic"
        );
    }

    #[tokio::test]
    async fn test_restart_only_on_panic() {
        let (sup, _cancel) = make_supervisor();

        let normal_counter = Arc::new(AtomicU32::new(0));
        let nc = Arc::clone(&normal_counter);
        sup.spawn(TaskDescriptor {
            name: "normal-exit",
            restart: RestartPolicy::Restart {
                max: 3,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&nc);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(300)).await;
        assert_eq!(
            normal_counter.load(Ordering::SeqCst),
            1,
            "normal exit must not restart"
        );
        assert!(
            sup.snapshot()
                .iter()
                .all(|s| s.name.as_ref() != "normal-exit"),
            "entry removed after normal exit"
        );

        let panic_counter = Arc::new(AtomicU32::new(0));
        let pc = Arc::clone(&panic_counter);
        sup.spawn(TaskDescriptor {
            name: "panic-exit",
            restart: RestartPolicy::Restart {
                max: 2,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&pc);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("test panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        assert!(
            panic_counter.load(Ordering::SeqCst) >= 3,
            "panicking task must restart max times"
        );
        let snap = sup
            .snapshot()
            .into_iter()
            .find(|s| s.name.as_ref() == "panic-exit");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "task must be Failed after exhausting restarts"
        );
    }

    #[tokio::test]
    async fn test_restart_policy() {
        let (sup, _cancel) = make_supervisor();

        let counter = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::clone(&counter);

        sup.spawn(TaskDescriptor {
            name: "restartable",
            restart: RestartPolicy::Restart {
                max: 2,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&counter2);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(500)).await;

        let runs = counter.load(Ordering::SeqCst);
        assert!(
            runs >= 3,
            "expected at least 3 invocations (initial + 2 restarts), got {runs}"
        );

        let snaps = sup.snapshot();
        let snap = snaps.iter().find(|s| s.name.as_ref() == "restartable");
        assert!(snap.is_some(), "failed task should remain in registry");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "task should be Failed after exhausting retries"
        );
    }

    #[tokio::test]
    async fn test_exponential_backoff() {
        let (sup, _cancel) = make_supervisor();

        let timestamps = Arc::new(parking_lot::Mutex::new(Vec::<std::time::Instant>::new()));
        let ts = Arc::clone(&timestamps);

        sup.spawn(TaskDescriptor {
            name: "backoff-task",
            restart: RestartPolicy::Restart {
                max: 3,
                base_delay: Duration::from_millis(50),
            },
            factory: move || {
                let t = Arc::clone(&ts);
                async move {
                    t.lock().push(std::time::Instant::now());
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(800)).await;

        let ts = timestamps.lock();
        assert!(
            ts.len() >= 3,
            "expected at least 3 invocations, got {}",
            ts.len()
        );

        // Delays are base * 2^(attempt - 1), so each gap should roughly double.
        let d1 = ts[1].duration_since(ts[0]);
        let d2 = ts[2].duration_since(ts[1]);
        assert!(
            d2 >= d1.mul_f64(1.5),
            "expected exponential backoff: d1={d1:?} d2={d2:?}"
        );
    }

    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (sup, _cancel) = make_supervisor();

        for name in ["svc-a", "svc-b", "svc-c"] {
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: || async {
                    tokio::time::sleep(Duration::from_mins(1)).await;
                },
            });
        }

        assert_eq!(sup.active_count(), 3);

        tokio::time::timeout(
            Duration::from_secs(2),
            sup.shutdown_all(Duration::from_secs(1)),
        )
        .await
        .expect("shutdown should complete within timeout");
    }

    #[tokio::test]
    async fn test_force_abort_marks_aborted() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        sup.spawn(TaskDescriptor {
            name: "stubborn-for-abort",
            restart: RestartPolicy::RunOnce,
            factory: || async {
                std::future::pending::<()>().await;
            },
        });

        sup.shutdown_all(Duration::from_millis(1)).await;

        // The entry may already have been reaped via cancellation; only check it if present.
        let snaps = sup.snapshot();
        if let Some(snap) = snaps
            .iter()
            .find(|s| s.name.as_ref() == "stubborn-for-abort")
        {
            assert_eq!(
                snap.status,
                TaskStatus::Aborted,
                "force-aborted task must have Aborted status"
            );
        }
    }

    #[tokio::test]
    async fn test_registry_snapshot() {
        let (sup, _cancel) = make_supervisor();

        for name in ["alpha", "beta"] {
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: || async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                },
            });
        }

        let snaps = sup.snapshot();
        assert_eq!(snaps.len(), 2);
        let names: Vec<&str> = snaps.iter().map(|s| s.name.as_ref()).collect();
        assert!(names.contains(&"alpha"));
        assert!(names.contains(&"beta"));
        assert!(snaps.iter().all(|s| s.status == TaskStatus::Running));
    }

    #[tokio::test]
    async fn test_blocking_returns_value() {
        let (sup, cancel) = make_supervisor();

        let handle: BlockingHandle<u32> = sup.spawn_blocking(Arc::from("compute"), || 42_u32);
        let result = handle.join().await.expect("should return value");
        assert_eq!(result, 42);
        cancel.cancel();
    }

    #[tokio::test]
    async fn test_blocking_panic() {
        let (sup, _cancel) = make_supervisor();

        let handle: BlockingHandle<u32> =
            sup.spawn_blocking(Arc::from("panicking-compute"), || panic!("intentional"));
        let err = handle
            .join()
            .await
            .expect_err("should return error on panic");
        assert_eq!(err, BlockingError::Panicked);
    }

    #[tokio::test]
    async fn test_blocking_registered_in_registry() {
        let (sup, cancel) = make_supervisor();

        let (tx, rx) = std::sync::mpsc::channel::<()>();
        let _handle: BlockingHandle<()> =
            sup.spawn_blocking(Arc::from("blocking-task"), move || {
                let _ = rx.recv();
            });

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            sup.active_count(),
            1,
            "blocking task must appear in active_count"
        );

        let _ = tx.send(());
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "blocking task must be removed after completion"
        );

        cancel.cancel();
    }

    #[tokio::test]
    async fn test_oneshot_registered_in_registry() {
        let (sup, cancel) = make_supervisor();

        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        let _handle: BlockingHandle<()> =
            sup.spawn_oneshot(Arc::from("oneshot-task"), move || async move {
                let _ = rx.await;
            });

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            sup.active_count(),
            1,
            "oneshot task must appear in active_count"
        );

        let _ = tx.send(());
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "oneshot task must be removed after completion"
        );

        cancel.cancel();
    }

    #[tokio::test]
    async fn test_restart_max_zero() {
        let (sup, _cancel) = make_supervisor();

        let counter = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::clone(&counter);

        sup.spawn(TaskDescriptor {
            name: "zero-max",
            restart: RestartPolicy::Restart {
                max: 0,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&counter2);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(200)).await;

        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "max=0 should not restart"
        );

        let snaps = sup.snapshot();
        let snap = snaps.iter().find(|s| s.name.as_ref() == "zero-max");
        assert!(snap.is_some(), "entry should remain as Failed");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "status should be Failed"
        );
    }

    #[tokio::test]
    async fn test_concurrent_spawns() {
        static NAMES: [&str; 50] = [
            "t00", "t01", "t02", "t03", "t04", "t05", "t06", "t07", "t08", "t09", "t10", "t11",
            "t12", "t13", "t14", "t15", "t16", "t17", "t18", "t19", "t20", "t21", "t22", "t23",
            "t24", "t25", "t26", "t27", "t28", "t29", "t30", "t31", "t32", "t33", "t34", "t35",
            "t36", "t37", "t38", "t39", "t40", "t41", "t42", "t43", "t44", "t45", "t46", "t47",
            "t48", "t49",
        ];
        let (sup, cancel) = make_supervisor();

        let completed = Arc::new(AtomicU32::new(0));
        for name in &NAMES {
            let c = Arc::clone(&completed);
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: move || {
                    let c = Arc::clone(&c);
                    async move {
                        c.fetch_add(1, Ordering::SeqCst);
                    }
                },
            });
        }

        tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                if completed.load(Ordering::SeqCst) == 50 {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("all 50 tasks should complete");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(sup.active_count(), 0, "all tasks must be reaped");

        cancel.cancel();
    }

    #[tokio::test]
    async fn test_shutdown_timeout_expiry() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        sup.spawn(TaskDescriptor {
            name: "stubborn",
            restart: RestartPolicy::RunOnce,
            factory: || async {
                tokio::time::sleep(Duration::from_mins(1)).await;
            },
        });

        assert_eq!(sup.active_count(), 1);

        tokio::time::timeout(
            Duration::from_secs(2),
            sup.shutdown_all(Duration::from_millis(50)),
        )
        .await
        .expect("shutdown_all should return even on timeout expiry");

        assert!(
            cancel.is_cancelled(),
            "cancel token must be cancelled after shutdown"
        );
    }

    #[tokio::test]
    async fn test_cancellation_token() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        assert!(!sup.cancellation_token().is_cancelled());

        sup.shutdown_all(Duration::from_millis(100)).await;

        assert!(
            sup.cancellation_token().is_cancelled(),
            "token must be cancelled after shutdown"
        );
    }

    #[tokio::test]
    async fn test_shutdown_drains_post_cancel_completions() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        for name in [
            "loop-1", "loop-2", "loop-3", "loop-4", "loop-5", "loop-6", "loop-7",
        ] {
            let cancel_inner = cancel.clone();
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: move || {
                    let c = cancel_inner.clone();
                    async move {
                        c.cancelled().await;
                        // Keep running for a while after cancellation.
                        for _ in 0..64 {
                            tokio::task::yield_now().await;
                        }
                    }
                },
            });
        }
        assert_eq!(sup.active_count(), 7);

        sup.shutdown_all(Duration::from_secs(2)).await;

        assert_eq!(
            sup.active_count(),
            0,
            "all tasks must be reaped after shutdown (#3161)"
        );
    }

    #[tokio::test]
    async fn test_blocking_spawner_task_appears_in_snapshot() {
        use crate::BlockingSpawner;

        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel);

        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<()>();
        let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>();

        let handle = sup.spawn_blocking_named(
            Arc::from("chunk_file"),
            Box::new(move || {
                let _ = ready_tx.send(());
                let _ = release_rx.blocking_recv();
            }),
        );

        ready_rx.await.expect("task should start");

        let snapshot = sup.snapshot();
        assert!(
            snapshot.iter().any(|t| t.name.as_ref() == "chunk_file"),
            "chunk_file task must appear in supervisor snapshot"
        );

        let _ = release_tx.send(());
        handle.await.expect("task should complete");
    }

    #[test]
    fn test_measure_blocking_emits_metrics() {
        use metrics_util::debugging::DebuggingRecorder;

        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();

        metrics::with_local_recorder(&recorder, || {
            measure_blocking("test_task", || std::hint::black_box(42_u64));
        });

        let snapshot = snapshotter.snapshot();
        let metric_names: Vec<String> = snapshot
            .into_vec()
            .into_iter()
            .map(|(k, _, _, _)| k.key().name().to_owned())
            .collect();

        assert!(
            metric_names.iter().any(|n| n == "zeph.task.wall_time_ms"),
            "expected zeph.task.wall_time_ms histogram; got: {metric_names:?}"
        );
        assert!(
            metric_names.iter().any(|n| n == "zeph.task.cpu_time_ms"),
            "expected zeph.task.cpu_time_ms histogram; got: {metric_names:?}"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_spawn_blocking_semaphore_cap() {
        let (sup, _cancel) = make_supervisor();
        let concurrent = Arc::new(AtomicU32::new(0));
        let max_concurrent = Arc::new(AtomicU32::new(0));
        let mut handles = Vec::new();
        for i in 0u32..16 {
            let c = Arc::clone(&concurrent);
            let m = Arc::clone(&max_concurrent);
            let name: Arc<str> = Arc::from(format!("blocking-{i}").as_str());
            let h = sup.spawn_blocking(name, move || {
                let prev = c.fetch_add(1, Ordering::SeqCst);
                let mut cur_max = m.load(Ordering::SeqCst);
                // Record the highest concurrency level observed.
                while prev + 1 > cur_max {
                    match m.compare_exchange(cur_max, prev + 1, Ordering::SeqCst, Ordering::SeqCst)
                    {
                        Ok(_) => break,
                        Err(x) => cur_max = x,
                    }
                }
                std::thread::sleep(std::time::Duration::from_millis(20));
                c.fetch_sub(1, Ordering::SeqCst);
            });
            handles.push(h);
        }

        for h in handles {
            h.join().await.expect("blocking task should succeed");
        }

        let observed = max_concurrent.load(Ordering::SeqCst);
        assert!(
            observed <= 8,
            "observed {observed} concurrent blocking tasks; expected ≤ 8 (semaphore cap)"
        );
    }
}