1use std::collections::HashMap;
47use std::future::Future;
48use std::pin::Pin;
49use std::sync::Arc;
50use std::time::{Duration, Instant};
51
52use tokio::sync::{mpsc, oneshot};
53use tokio::task::AbortHandle;
54use tokio_util::sync::CancellationToken;
55use tracing::Instrument as _;
56
57use crate::BlockingSpawner;
58
/// Policy deciding what the supervisor does once a task's future has finished.
///
/// Only a panic ever triggers a restart: a task that returns normally or is
/// cancelled is removed from the registry whatever its policy is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RestartPolicy {
    /// Run the task a single time; it is never restarted.
    RunOnce,
    /// Restart the task after a panic, at most `max` times.
    ///
    /// The wait before restart attempt `n` (1-based) is `base_delay * 2^(n - 1)`,
    /// capped at [`MAX_RESTART_DELAY`].
    Restart { max: u32, base_delay: Duration },
}
91
/// Upper bound applied to the exponential back-off between restart attempts.
pub const MAX_RESTART_DELAY: Duration = Duration::from_mins(1);

/// How long the reap driver keeps draining completion events after the
/// cancellation token has fired.
const SHUTDOWN_DRAIN_TIMEOUT: Duration = Duration::from_secs(5);
100
/// Everything [`TaskSupervisor::spawn`] needs to start (and possibly restart) a task.
pub struct TaskDescriptor<F> {
    /// Registry key for the task; also recorded as `task.name` on its tracing span.
    pub name: &'static str,
    /// What to do when the task's future finishes.
    pub restart: RestartPolicy,
    /// Produces the task's future. It is called once at spawn time and once more
    /// for every restart, so it must be able to build a fresh future each time.
    pub factory: F,
}
116
117#[derive(Debug, Clone)]
121pub struct TaskHandle {
122 name: &'static str,
123 abort: AbortHandle,
124}
125
126impl TaskHandle {
127 pub fn abort(&self) {
129 tracing::debug!(task.name = self.name, "task aborted via handle");
130 self.abort.abort();
131 }
132
133 #[must_use]
135 pub const fn name(&self) -> &'static str {
136 self.name
137 }
138}
139
/// Reason a [`BlockingHandle`] could not deliver the task's return value.
#[derive(Debug, PartialEq, Eq)]
pub enum BlockingError {
    /// The task panicked while running.
    Panicked,
    /// The result sender was dropped before a value was sent.
    SupervisorDropped,
}

impl std::fmt::Display for BlockingError {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            Self::Panicked => "supervised blocking task panicked",
            Self::SupervisorDropped => "supervisor dropped before task completed",
        };
        f.write_str(description)
    }
}

impl std::error::Error for BlockingError {}
159
160pub struct BlockingHandle<R> {
169 rx: oneshot::Receiver<Result<R, BlockingError>>,
170 abort: AbortHandle,
171}
172
173impl<R> BlockingHandle<R> {
174 pub async fn join(self) -> Result<R, BlockingError> {
182 self.rx
183 .await
184 .unwrap_or(Err(BlockingError::SupervisorDropped))
185 }
186
187 pub fn try_join(mut self) -> Result<Result<R, BlockingError>, Self> {
219 match self.rx.try_recv() {
220 Ok(result) => Ok(result),
221 Err(tokio::sync::oneshot::error::TryRecvError::Empty) => Err(self),
222 Err(tokio::sync::oneshot::error::TryRecvError::Closed) => {
223 Ok(Err(BlockingError::SupervisorDropped))
224 }
225 }
226 }
227
228 pub fn abort(&self) {
230 self.abort.abort();
231 }
232}
233
/// Lifecycle state of a registry entry, as reported by [`TaskSupervisor::snapshot`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskStatus {
    /// The task's future is currently spawned.
    Running,
    /// The task panicked and is waiting out its back-off before restart
    /// number `attempt` out of at most `max`.
    Restarting { attempt: u32, max: u32 },
    /// The task finished; such entries are removed from the registry right away.
    Completed,
    /// The task was force-aborted because a shutdown timed out.
    Aborted,
    /// The task will not be restarted again; `reason` says why.
    Failed { reason: String },
}
248
/// Point-in-time copy of one registry entry.
#[derive(Debug, Clone)]
pub struct TaskSnapshot {
    /// Name the task is registered under.
    pub name: Arc<str>,
    /// Lifecycle state at the time the snapshot was taken.
    pub status: TaskStatus,
    /// When the entry was registered; restarts do not update this value.
    pub started_at: Instant,
    /// Number of restarts performed so far.
    pub restart_count: u32,
}
270
/// Type-erased, sendable task future.
type BoxFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Type-erased factory that builds a fresh [`BoxFuture`] on every call.
type BoxFactory = Box<dyn Fn() -> BoxFuture + Send + Sync>;
275
/// Registry record for one supervised task.
struct TaskEntry {
    /// Registry key, duplicated here so snapshots can be built from values alone.
    name: Arc<str>,
    /// Current lifecycle state.
    status: TaskStatus,
    /// When the entry was registered.
    started_at: Instant,
    /// Number of restarts performed so far.
    restart_count: u32,
    /// Policy applied when the task finishes.
    restart_policy: RestartPolicy,
    /// Abort handle of the currently running instance.
    abort_handle: AbortHandle,
    /// Factory used for restarts; `None` for tasks that are never restarted.
    factory: Option<BoxFactory>,
}
286
/// How a supervised task's Tokio task ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CompletionKind {
    /// The future ran to completion.
    Normal,
    /// The future panicked.
    Panicked,
    /// The task was aborted or stopped by the cancellation token.
    Cancelled,
}
297
/// Event sent to the reap driver when a supervised task ends.
struct Completion {
    /// Registry key of the task that ended.
    name: Arc<str>,
    /// How it ended.
    kind: CompletionKind,
}
302
/// Mutable supervisor state guarded by the mutex in [`Inner`].
struct SupervisorState {
    /// Registry of known tasks, keyed by task name.
    tasks: HashMap<Arc<str>, TaskEntry>,
}
306
/// State shared by every clone of a [`TaskSupervisor`] and by its reap driver.
struct Inner {
    /// Task registry.
    state: parking_lot::Mutex<SupervisorState>,
    /// Sending side of the completion channel consumed by the reap driver.
    completion_tx: mpsc::UnboundedSender<Completion>,
    /// Token that stops all supervised tasks and the reap driver's main loop.
    cancel: CancellationToken,
    /// Bounds how many `spawn_blocking` closures run at the same time.
    blocking_semaphore: Arc<tokio::sync::Semaphore>,
    /// Signalled after each handled completion so `shutdown_all` can re-check
    /// the registry.
    shutdown_notify: Arc<tokio::sync::Notify>,
}
320
/// Spawns, tracks, restarts and shuts down named Tokio tasks.
///
/// Cloning is cheap: all clones share the same registry and cancellation token.
#[derive(Clone)]
pub struct TaskSupervisor {
    /// Shared state; see [`Inner`].
    inner: Arc<Inner>,
}
360
361impl TaskSupervisor {
362 #[must_use]
374 pub fn new(cancel: CancellationToken) -> Self {
375 let (completion_tx, completion_rx) = mpsc::unbounded_channel();
379 let inner = Arc::new(Inner {
380 state: parking_lot::Mutex::new(SupervisorState {
381 tasks: HashMap::new(),
382 }),
383 completion_tx,
384 cancel: cancel.clone(),
385 blocking_semaphore: Arc::new(tokio::sync::Semaphore::new(8)),
386 shutdown_notify: Arc::new(tokio::sync::Notify::new()),
387 });
388
389 if tokio::runtime::Handle::try_current().is_ok() {
394 Self::start_reap_driver(Arc::clone(&inner), completion_rx, cancel);
395 }
396
397 Self { inner }
398 }
399
400 pub fn spawn<F, Fut>(&self, desc: TaskDescriptor<F>) -> TaskHandle
425 where
426 F: Fn() -> Fut + Send + Sync + 'static,
427 Fut: Future<Output = ()> + Send + 'static,
428 {
429 let factory: BoxFactory = Box::new(move || Box::pin((desc.factory)()));
430 let cancel = self.inner.cancel.clone();
431 let completion_tx = self.inner.completion_tx.clone();
432 let name: Arc<str> = Arc::from(desc.name);
433
434 let (abort_handle, jh) = Self::do_spawn(desc.name, &factory, cancel);
435 Self::wire_completion_reporter(Arc::clone(&name), jh, completion_tx);
436
437 let entry = TaskEntry {
438 name: Arc::clone(&name),
439 status: TaskStatus::Running,
440 started_at: Instant::now(),
441 restart_count: 0,
442 restart_policy: desc.restart,
443 abort_handle: abort_handle.clone(),
444 factory: match desc.restart {
445 RestartPolicy::RunOnce => None,
446 RestartPolicy::Restart { .. } => Some(factory),
447 },
448 };
449
450 {
451 let mut state = self.inner.state.lock();
452 if let Some(old) = state.tasks.remove(&name) {
453 old.abort_handle.abort();
454 }
455 state.tasks.insert(Arc::clone(&name), entry);
456 }
457
458 TaskHandle {
459 name: desc.name,
460 abort: abort_handle,
461 }
462 }
463
464 #[allow(clippy::needless_pass_by_value)] pub fn spawn_blocking<F, R>(&self, name: Arc<str>, f: F) -> BlockingHandle<R>
509 where
510 F: FnOnce() -> R + Send + 'static,
511 R: Send + 'static,
512 {
513 let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
514 let span = tracing::info_span!(
515 "supervised_blocking_task",
516 task.name = %name,
517 task.wall_time_ms = tracing::field::Empty,
518 task.cpu_time_ms = tracing::field::Empty,
519 );
520
521 let semaphore = Arc::clone(&self.inner.blocking_semaphore);
522 let inner = Arc::clone(&self.inner);
523 let name_clone = Arc::clone(&name);
524 let completion_tx = self.inner.completion_tx.clone();
525
526 let outer = tokio::spawn(async move {
529 let _permit = semaphore
530 .acquire_owned()
531 .await
532 .expect("blocking semaphore closed");
533
534 let name_for_measure = Arc::clone(&name_clone);
535 let join_handle = tokio::task::spawn_blocking(move || {
536 let _enter = span.enter();
537 measure_blocking(&name_for_measure, f)
538 });
539 let abort = join_handle.abort_handle();
540
541 {
543 let mut state = inner.state.lock();
544 if let Some(entry) = state.tasks.get_mut(&name_clone) {
545 entry.abort_handle = abort;
546 }
547 }
548
549 let kind = match join_handle.await {
550 Ok(val) => {
551 let _ = tx.send(Ok(val));
552 CompletionKind::Normal
553 }
554 Err(e) if e.is_panic() => {
555 let _ = tx.send(Err(BlockingError::Panicked));
556 CompletionKind::Panicked
557 }
558 Err(_) => {
559 CompletionKind::Cancelled
561 }
562 };
563 let _ = completion_tx.send(Completion {
565 name: name_clone,
566 kind,
567 });
568 });
569 let abort = outer.abort_handle();
570
571 {
573 let mut state = self.inner.state.lock();
574 if let Some(old) = state.tasks.remove(&name) {
575 old.abort_handle.abort();
576 }
577 state.tasks.insert(
578 Arc::clone(&name),
579 TaskEntry {
580 name: Arc::clone(&name),
581 status: TaskStatus::Running,
582 started_at: Instant::now(),
583 restart_count: 0,
584 restart_policy: RestartPolicy::RunOnce,
585 abort_handle: abort.clone(),
586 factory: None,
587 },
588 );
589 }
590
591 BlockingHandle { rx, abort }
592 }
593
594 pub fn spawn_oneshot<F, Fut, R>(&self, name: Arc<str>, factory: F) -> BlockingHandle<R>
622 where
623 F: FnOnce() -> Fut + Send + 'static,
624 Fut: Future<Output = R> + Send + 'static,
625 R: Send + 'static,
626 {
627 let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
628 let cancel = self.inner.cancel.clone();
629 let span = tracing::info_span!("supervised_task", task.name = %name);
630 let join_handle: tokio::task::JoinHandle<Option<R>> = tokio::spawn(
631 async move {
632 let fut = factory();
633 tokio::select! {
634 result = fut => Some(result),
635 () = cancel.cancelled() => None,
636 }
637 }
638 .instrument(span),
639 );
640 let abort = join_handle.abort_handle();
641
642 {
643 let mut state = self.inner.state.lock();
644 if let Some(old) = state.tasks.remove(&name) {
645 old.abort_handle.abort();
646 }
647 state.tasks.insert(
648 Arc::clone(&name),
649 TaskEntry {
650 name: Arc::clone(&name),
651 status: TaskStatus::Running,
652 started_at: Instant::now(),
653 restart_count: 0,
654 restart_policy: RestartPolicy::RunOnce,
655 abort_handle: abort.clone(),
656 factory: None,
657 },
658 );
659 }
660
661 let completion_tx = self.inner.completion_tx.clone();
662 tokio::spawn(async move {
663 let kind = match join_handle.await {
664 Ok(Some(val)) => {
665 let _ = tx.send(Ok(val));
666 CompletionKind::Normal
667 }
668 Err(e) if e.is_panic() => {
669 let _ = tx.send(Err(BlockingError::Panicked));
670 CompletionKind::Panicked
671 }
672 _ => CompletionKind::Cancelled,
673 };
674 let _ = completion_tx.send(Completion { name, kind });
675 });
676 BlockingHandle { rx, abort }
677 }
678
679 pub fn abort(&self, name: &'static str) {
681 let state = self.inner.state.lock();
682 let key: Arc<str> = Arc::from(name);
683 if let Some(entry) = state.tasks.get(&key) {
684 entry.abort_handle.abort();
685 tracing::debug!(task.name = name, "task aborted via supervisor");
686 }
687 }
688
689 pub async fn shutdown_all(&self, timeout: Duration) {
702 self.inner.cancel.cancel();
703 let sleep = tokio::time::sleep(timeout);
704 tokio::pin!(sleep);
705 loop {
706 let active = self.active_count();
707 if active == 0 {
708 break;
709 }
710 let notified = self.inner.shutdown_notify.notified();
713 tokio::select! {
714 biased;
715 () = notified => {
716 }
718 () = &mut sleep => {
719 let mut remaining_names: Vec<Arc<str>> = Vec::new();
720 {
721 let mut state = self.inner.state.lock();
722 for entry in state.tasks.values_mut() {
723 if matches!(
724 entry.status,
725 TaskStatus::Running | TaskStatus::Restarting { .. }
726 ) {
727 remaining_names.push(Arc::clone(&entry.name));
728 entry.abort_handle.abort();
729 entry.status = TaskStatus::Aborted;
730 }
731 }
732 }
733 tracing::warn!(
734 remaining = active,
735 tasks = ?remaining_names,
736 "shutdown timeout — aborting remaining tasks"
737 );
738 break;
739 }
740 }
741 }
742 }
743
744 #[must_use]
749 pub fn snapshot(&self) -> Vec<TaskSnapshot> {
750 let state = self.inner.state.lock();
751 let mut snaps: Vec<TaskSnapshot> = state
752 .tasks
753 .values()
754 .map(|e| TaskSnapshot {
755 name: Arc::clone(&e.name),
756 status: e.status.clone(),
757 started_at: e.started_at,
758 restart_count: e.restart_count,
759 })
760 .collect();
761 snaps.sort_by_key(|s| s.started_at);
762 snaps
763 }
764
765 #[must_use]
767 pub fn active_count(&self) -> usize {
768 let state = self.inner.state.lock();
769 state
770 .tasks
771 .values()
772 .filter(|e| {
773 matches!(
774 e.status,
775 TaskStatus::Running | TaskStatus::Restarting { .. }
776 )
777 })
778 .count()
779 }
780
781 #[must_use]
785 pub fn cancellation_token(&self) -> CancellationToken {
786 self.inner.cancel.clone()
787 }
788
789 fn do_spawn(
793 name: &'static str,
794 factory: &BoxFactory,
795 cancel: CancellationToken,
796 ) -> (AbortHandle, tokio::task::JoinHandle<()>) {
797 let fut = factory();
798 let span = tracing::info_span!("supervised_task", task.name = name);
799 let jh = tokio::spawn(
800 async move {
801 tokio::select! {
802 () = fut => {},
803 () = cancel.cancelled() => {},
804 }
805 }
806 .instrument(span),
807 );
808 let abort = jh.abort_handle();
809 (abort, jh)
810 }
811
812 fn wire_completion_reporter(
814 name: Arc<str>,
815 jh: tokio::task::JoinHandle<()>,
816 completion_tx: mpsc::UnboundedSender<Completion>,
817 ) {
818 tokio::spawn(async move {
819 let kind = match jh.await {
820 Ok(()) => CompletionKind::Normal,
821 Err(e) if e.is_panic() => CompletionKind::Panicked,
822 Err(_) => CompletionKind::Cancelled,
823 };
824 let _ = completion_tx.send(Completion { name, kind });
825 });
826 }
827
    /// Spawns the background task that consumes completion events.
    ///
    /// Phase 1 handles completions until `cancel` fires. Phase 2 keeps draining
    /// completions for up to [`SHUTDOWN_DRAIN_TIMEOUT`], stopping early once no
    /// task is active or the channel closes.
    fn start_reap_driver(
        inner: Arc<Inner>,
        mut completion_rx: mpsc::UnboundedReceiver<Completion>,
        cancel: CancellationToken,
    ) {
        tokio::spawn(async move {
            loop {
                // `biased` polls the channel first, so completions that are
                // already queued are handled before cancellation is observed.
                tokio::select! {
                    biased;
                    Some(completion) = completion_rx.recv() => {
                        Self::handle_completion(&inner, completion).await;
                    }
                    () = cancel.cancelled() => break,
                }
            }

            // Post-cancel drain: tasks stopped by the token still report their
            // completion, and those events must be processed to empty the registry.
            let drain_deadline = tokio::time::Instant::now() + SHUTDOWN_DRAIN_TIMEOUT;
            let active = Self::has_active_tasks(&inner);
            tracing::debug!(active, "reap driver entered post-cancel drain phase");
            loop {
                if !Self::has_active_tasks(&inner) {
                    break;
                }
                let remaining =
                    drain_deadline.saturating_duration_since(tokio::time::Instant::now());
                if remaining.is_zero() {
                    break;
                }
                match tokio::time::timeout(remaining, completion_rx.recv()).await {
                    Ok(Some(completion)) => Self::handle_completion(&inner, completion).await,
                    // Channel closed or drain deadline reached.
                    Ok(None) | Err(_) => break,
                }
            }
            tracing::debug!(
                active = Self::has_active_tasks(&inner),
                "reap driver drain phase complete"
            );
        });
    }
879
880 fn has_active_tasks(inner: &Arc<Inner>) -> bool {
882 let state = inner.state.lock();
883 state.tasks.values().any(|e| {
884 matches!(
885 e.status,
886 TaskStatus::Running | TaskStatus::Restarting { .. }
887 )
888 })
889 }
890
891 async fn handle_completion(inner: &Arc<Inner>, completion: Completion) {
897 if inner.cancel.is_cancelled() {
901 {
902 let mut state = inner.state.lock();
903 state.tasks.remove(&completion.name);
904 }
905 inner.shutdown_notify.notify_waiters();
906 return;
907 }
908
909 let Some((attempt, max, delay)) = Self::classify_completion(inner, &completion) else {
910 inner.shutdown_notify.notify_waiters();
913 return;
914 };
915
916 tracing::warn!(
917 task.name = %completion.name,
918 attempt,
919 max,
920 delay_ms = delay.as_millis(),
921 "restarting supervised task"
922 );
923
924 if !delay.is_zero() {
925 tokio::time::sleep(delay).await;
926 }
927
928 Self::do_restart(inner, &completion.name, attempt);
929 }
930
931 fn classify_completion(
935 inner: &Arc<Inner>,
936 completion: &Completion,
937 ) -> Option<(u32, u32, Duration)> {
938 let mut state = inner.state.lock();
939 let entry = state.tasks.get_mut(&completion.name)?;
940
941 match completion.kind {
942 CompletionKind::Panicked => {
943 tracing::warn!(task.name = %completion.name, "supervised task panicked");
944 }
945 CompletionKind::Normal => {
946 tracing::info!(task.name = %completion.name, "supervised task completed");
947 }
948 CompletionKind::Cancelled => {
949 tracing::debug!(task.name = %completion.name, "supervised task cancelled");
950 }
951 }
952
953 match entry.restart_policy {
954 RestartPolicy::RunOnce => {
955 entry.status = TaskStatus::Completed;
956 state.tasks.remove(&completion.name);
957 None
958 }
959 RestartPolicy::Restart { max, base_delay } => {
960 if completion.kind != CompletionKind::Panicked {
962 entry.status = TaskStatus::Completed;
963 state.tasks.remove(&completion.name);
964 return None;
965 }
966 if entry.restart_count >= max {
967 let reason = format!("panicked after {max} restart(s)");
968 tracing::error!(
969 task.name = %completion.name,
970 attempts = max,
971 "task failed permanently"
972 );
973 entry.status = TaskStatus::Failed { reason };
974 None
975 } else {
976 let attempt = entry.restart_count + 1;
977 entry.status = TaskStatus::Restarting { attempt, max };
978 let multiplier = 1_u32
980 .checked_shl(attempt.saturating_sub(1))
981 .unwrap_or(u32::MAX);
982 let delay = base_delay.saturating_mul(multiplier).min(MAX_RESTART_DELAY);
983 Some((attempt, max, delay))
984 }
985 }
986 }
987 }
989
    /// Spawns a fresh instance of the task `name` for restart number `attempt`.
    ///
    /// Nothing happens when the entry has disappeared, is no longer
    /// `Restarting`, or has no stored factory. A panic inside the factory is
    /// caught and marks the entry `Failed`. On success the entry is updated with
    /// the new restart count, `Running` status and abort handle, and a completion
    /// reporter is wired for the new instance.
    fn do_restart(inner: &Arc<Inner>, name: &Arc<str>, attempt: u32) {
        // Phase 1 (under the lock): validate the entry and build the new future.
        // NOTE(review): the user factory runs while the registry lock is held;
        // a factory that calls back into the supervisor would presumably
        // deadlock — confirm factories never do that.
        let spawn_params = {
            let mut state = inner.state.lock();
            let Some(entry) = state.tasks.get_mut(name.as_ref()) else {
                tracing::debug!(
                    task.name = %name,
                    "task removed during restart delay — skipping"
                );
                return;
            };
            // The status may have changed while the back-off was pending.
            if !matches!(entry.status, TaskStatus::Restarting { .. }) {
                return;
            }
            let Some(factory) = &entry.factory else {
                return;
            };
            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(factory)) {
                Err(_) => {
                    let reason = format!("factory panicked on restart attempt {attempt}");
                    tracing::error!(task.name = %name, attempt, "factory panicked during restart");
                    entry.status = TaskStatus::Failed { reason };
                    None
                }
                Ok(fut) => Some((
                    fut,
                    inner.cancel.clone(),
                    inner.completion_tx.clone(),
                    name.clone(),
                )),
            }
        };

        let Some((fut, cancel, completion_tx, name)) = spawn_params else {
            return;
        };

        // Phase 2 (lock released): spawn the new instance, raced against cancellation.
        let span = tracing::info_span!("supervised_task", task.name = %name);
        let jh = tokio::spawn(
            async move {
                tokio::select! {
                    () = fut => {},
                    () = cancel.cancelled() => {},
                }
            }
            .instrument(span),
        );
        let new_abort = jh.abort_handle();

        // Phase 3: publish the new instance in the registry.
        {
            let mut state = inner.state.lock();
            if let Some(entry) = state.tasks.get_mut(name.as_ref()) {
                entry.restart_count = attempt;
                entry.status = TaskStatus::Running;
                entry.abort_handle = new_abort;
            }
        }

        // Wired last so the completion is reported only after the entry is updated.
        Self::wire_completion_reporter(name, jh, completion_tx);
    }
1053}
1054
1055#[inline]
1059fn measure_blocking<F, R>(name: &str, f: F) -> R
1060where
1061 F: FnOnce() -> R,
1062{
1063 use cpu_time::ThreadTime;
1064 let wall_start = std::time::Instant::now();
1065 let cpu_start = ThreadTime::now();
1066 let result = f();
1067 let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
1068 let cpu_ms = cpu_start.elapsed().as_secs_f64() * 1000.0;
1069 metrics::histogram!("zeph.task.wall_time_ms", "task" => name.to_owned()).record(wall_ms);
1070 metrics::histogram!("zeph.task.cpu_time_ms", "task" => name.to_owned()).record(cpu_ms);
1071 tracing::Span::current().record("task.wall_time_ms", wall_ms);
1072 tracing::Span::current().record("task.cpu_time_ms", cpu_ms);
1073 result
1074}
1075
1076impl BlockingSpawner for TaskSupervisor {
1079 fn spawn_blocking_named(
1085 &self,
1086 name: Arc<str>,
1087 f: Box<dyn FnOnce() + Send + 'static>,
1088 ) -> tokio::task::JoinHandle<()> {
1089 let handle = self.spawn_blocking(Arc::clone(&name), f);
1090 tokio::spawn(async move {
1091 if let Err(e) = handle.join().await {
1092 tracing::error!(task.name = %name, error = %e, "supervised blocking task failed");
1093 }
1094 })
1095 }
1096}
1097
1098#[cfg(test)]
1101mod tests {
1102 use std::sync::Arc;
1103 use std::sync::atomic::{AtomicU32, Ordering};
1104 use std::time::Duration;
1105
1106 use tokio_util::sync::CancellationToken;
1107
1108 use super::*;
1109
1110 fn make_supervisor() -> (TaskSupervisor, CancellationToken) {
1111 let cancel = CancellationToken::new();
1112 let sup = TaskSupervisor::new(cancel.clone());
1113 (sup, cancel)
1114 }
1115
    /// A `RunOnce` task runs its future and is reaped from the registry afterwards.
    #[tokio::test]
    async fn test_spawn_and_complete() {
        let (sup, _cancel) = make_supervisor();

        let done = Arc::new(tokio::sync::Notify::new());
        let done2 = Arc::clone(&done);

        sup.spawn(TaskDescriptor {
            name: "simple",
            restart: RestartPolicy::RunOnce,
            factory: move || {
                let d = Arc::clone(&done2);
                async move {
                    d.notify_one();
                }
            },
        });

        tokio::time::timeout(Duration::from_secs(2), done.notified())
            .await
            .expect("task should complete");

        // Give the reap driver time to process the completion event.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "RunOnce task should be removed after completion"
        );
    }
1145
    /// A panicking `RunOnce` task is captured and its entry removed.
    #[tokio::test]
    async fn test_panic_capture() {
        let (sup, _cancel) = make_supervisor();

        sup.spawn(TaskDescriptor {
            name: "panicking",
            restart: RestartPolicy::RunOnce,
            factory: || async { panic!("intentional test panic") },
        });

        // Wait for the panic to be reported and reaped.
        tokio::time::sleep(Duration::from_millis(200)).await;

        let snaps = sup.snapshot();
        assert!(
            snaps.iter().all(|s| s.name.as_ref() != "panicking"),
            "entry should be reaped"
        );
        assert_eq!(
            sup.active_count(),
            0,
            "active count must be 0 after RunOnce panic"
        );
    }
1169
    /// Under a `Restart` policy a normal exit is not restarted, while a panic is
    /// restarted until the budget is exhausted and the task ends up `Failed`.
    #[tokio::test]
    async fn test_restart_only_on_panic() {
        let (sup, _cancel) = make_supervisor();

        // Part 1: normal exit must run exactly once.
        let normal_counter = Arc::new(AtomicU32::new(0));
        let nc = Arc::clone(&normal_counter);
        sup.spawn(TaskDescriptor {
            name: "normal-exit",
            restart: RestartPolicy::Restart {
                max: 3,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&nc);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(300)).await;
        assert_eq!(
            normal_counter.load(Ordering::SeqCst),
            1,
            "normal exit must not restart"
        );
        assert!(
            sup.snapshot()
                .iter()
                .all(|s| s.name.as_ref() != "normal-exit"),
            "entry removed after normal exit"
        );

        // Part 2: a panicking task runs once plus `max` restarts.
        let panic_counter = Arc::new(AtomicU32::new(0));
        let pc = Arc::clone(&panic_counter);
        sup.spawn(TaskDescriptor {
            name: "panic-exit",
            restart: RestartPolicy::Restart {
                max: 2,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&pc);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("test panic");
                }
            },
        });

        // Back-off is 10 ms then 20 ms, so 500 ms covers all attempts.
        tokio::time::sleep(Duration::from_millis(500)).await;
        assert!(
            panic_counter.load(Ordering::SeqCst) >= 3,
            "panicking task must restart max times"
        );
        let snap = sup
            .snapshot()
            .into_iter()
            .find(|s| s.name.as_ref() == "panic-exit");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "task must be Failed after exhausting restarts"
        );
    }
1240
    /// An always-panicking task with `max: 2` runs three times and then stays in
    /// the registry as `Failed`.
    #[tokio::test]
    async fn test_restart_policy() {
        let (sup, _cancel) = make_supervisor();

        let counter = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::clone(&counter);

        sup.spawn(TaskDescriptor {
            name: "restartable",
            restart: RestartPolicy::Restart {
                max: 2,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&counter2);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("always panic");
                }
            },
        });

        // Long enough for the initial run and both restarts.
        tokio::time::sleep(Duration::from_millis(500)).await;

        let runs = counter.load(Ordering::SeqCst);
        assert!(
            runs >= 3,
            "expected at least 3 invocations (initial + 2 restarts), got {runs}"
        );

        let snaps = sup.snapshot();
        let snap = snaps.iter().find(|s| s.name.as_ref() == "restartable");
        assert!(snap.is_some(), "failed task should remain in registry");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "task should be Failed after exhausting retries"
        );
    }
1279
1280 #[tokio::test]
1282 async fn test_exponential_backoff() {
1283 let (sup, _cancel) = make_supervisor();
1284
1285 let timestamps = Arc::new(parking_lot::Mutex::new(Vec::<std::time::Instant>::new()));
1286 let ts = Arc::clone(×tamps);
1287
1288 sup.spawn(TaskDescriptor {
1289 name: "backoff-task",
1290 restart: RestartPolicy::Restart {
1291 max: 3,
1292 base_delay: Duration::from_millis(50),
1293 },
1294 factory: move || {
1295 let t = Arc::clone(&ts);
1296 async move {
1297 t.lock().push(std::time::Instant::now());
1298 panic!("always panic");
1299 }
1300 },
1301 });
1302
1303 tokio::time::sleep(Duration::from_millis(800)).await;
1305
1306 let ts = timestamps.lock();
1307 assert!(
1308 ts.len() >= 3,
1309 "expected at least 3 invocations, got {}",
1310 ts.len()
1311 );
1312
1313 if ts.len() >= 3 {
1315 let d1 = ts[1].duration_since(ts[0]);
1316 let d2 = ts[2].duration_since(ts[1]);
1317 assert!(
1319 d2 >= d1.mul_f64(1.5),
1320 "expected exponential backoff: d1={d1:?} d2={d2:?}"
1321 );
1322 }
1323 }
1324
    /// `shutdown_all` returns promptly when tasks stop on cancellation.
    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (sup, _cancel) = make_supervisor();

        for name in ["svc-a", "svc-b", "svc-c"] {
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: || async {
                    tokio::time::sleep(Duration::from_mins(1)).await;
                },
            });
        }

        assert_eq!(sup.active_count(), 3);

        // The outer timeout fails the test if shutdown hangs.
        tokio::time::timeout(
            Duration::from_secs(2),
            sup.shutdown_all(Duration::from_secs(1)),
        )
        .await
        .expect("shutdown should complete within timeout");
    }
1348
    /// A task still registered after a timed-out shutdown must be `Aborted`.
    #[tokio::test]
    async fn test_force_abort_marks_aborted() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        sup.spawn(TaskDescriptor {
            name: "stubborn-for-abort",
            restart: RestartPolicy::RunOnce,
            factory: || async {
                std::future::pending::<()>().await;
            },
        });

        sup.shutdown_all(Duration::from_millis(1)).await;

        let snaps = sup.snapshot();
        // The entry may already have been reaped, in which case nothing is
        // asserted; only a surviving entry is checked.
        if let Some(snap) = snaps
            .iter()
            .find(|s| s.name.as_ref() == "stubborn-for-abort")
        {
            assert_eq!(
                snap.status,
                TaskStatus::Aborted,
                "force-aborted task must have Aborted status"
            );
        }
    }
1381
1382 #[tokio::test]
1383 async fn test_registry_snapshot() {
1384 let (sup, _cancel) = make_supervisor();
1385
1386 for name in ["alpha", "beta"] {
1387 sup.spawn(TaskDescriptor {
1388 name,
1389 restart: RestartPolicy::RunOnce,
1390 factory: || async {
1391 tokio::time::sleep(Duration::from_secs(10)).await;
1392 },
1393 });
1394 }
1395
1396 let snaps = sup.snapshot();
1397 assert_eq!(snaps.len(), 2);
1398 let names: Vec<&str> = snaps.iter().map(|s| s.name.as_ref()).collect();
1399 assert!(names.contains(&"alpha"));
1400 assert!(names.contains(&"beta"));
1401 assert!(snaps.iter().all(|s| s.status == TaskStatus::Running));
1402 }
1403
1404 #[tokio::test]
1405 async fn test_blocking_returns_value() {
1406 let (sup, cancel) = make_supervisor();
1407
1408 let handle: BlockingHandle<u32> = sup.spawn_blocking(Arc::from("compute"), || 42_u32);
1409 let result = handle.join().await.expect("should return value");
1410 assert_eq!(result, 42);
1411 cancel.cancel();
1412 }
1413
1414 #[tokio::test]
1415 async fn test_blocking_panic() {
1416 let (sup, _cancel) = make_supervisor();
1417
1418 let handle: BlockingHandle<u32> =
1419 sup.spawn_blocking(Arc::from("panicking-compute"), || panic!("intentional"));
1420 let err = handle
1421 .join()
1422 .await
1423 .expect_err("should return error on panic");
1424 assert_eq!(err, BlockingError::Panicked);
1425 }
1426
    /// A blocking task is counted as active while it runs and is removed once done.
    #[tokio::test]
    async fn test_blocking_registered_in_registry() {
        let (sup, cancel) = make_supervisor();

        // The closure blocks on this channel until the test releases it.
        let (tx, rx) = std::sync::mpsc::channel::<()>();
        let _handle: BlockingHandle<()> =
            sup.spawn_blocking(Arc::from("blocking-task"), move || {
                let _ = rx.recv();
            });

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            sup.active_count(),
            1,
            "blocking task must appear in active_count"
        );

        // Release the closure and give the reap driver time to process it.
        let _ = tx.send(());
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "blocking task must be removed after completion"
        );

        cancel.cancel();
    }
1456
    /// A oneshot task is counted as active while it runs and is removed once done.
    #[tokio::test]
    async fn test_oneshot_registered_in_registry() {
        let (sup, cancel) = make_supervisor();

        // The task waits on this channel until the test releases it.
        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        let _handle: BlockingHandle<()> =
            sup.spawn_oneshot(Arc::from("oneshot-task"), move || async move {
                let _ = rx.await;
            });

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            sup.active_count(),
            1,
            "oneshot task must appear in active_count"
        );

        // Release the task and give the reap driver time to process it.
        let _ = tx.send(());
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "oneshot task must be removed after completion"
        );

        cancel.cancel();
    }
1485
    /// With `max: 0` a panicking task is never restarted and ends up `Failed`.
    #[tokio::test]
    async fn test_restart_max_zero() {
        let (sup, _cancel) = make_supervisor();

        let counter = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::clone(&counter);

        sup.spawn(TaskDescriptor {
            name: "zero-max",
            restart: RestartPolicy::Restart {
                max: 0,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&counter2);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(200)).await;

        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "max=0 should not restart"
        );

        let snaps = sup.snapshot();
        let snap = snaps.iter().find(|s| s.name.as_ref() == "zero-max");
        assert!(snap.is_some(), "entry should remain as Failed");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "status should be Failed"
        );
    }
1524
    /// Fifty short tasks spawned back to back all run and are all reaped.
    #[tokio::test]
    async fn test_concurrent_spawns() {
        // Task names must be `&'static str`, hence the static table.
        static NAMES: [&str; 50] = [
            "t00", "t01", "t02", "t03", "t04", "t05", "t06", "t07", "t08", "t09", "t10", "t11",
            "t12", "t13", "t14", "t15", "t16", "t17", "t18", "t19", "t20", "t21", "t22", "t23",
            "t24", "t25", "t26", "t27", "t28", "t29", "t30", "t31", "t32", "t33", "t34", "t35",
            "t36", "t37", "t38", "t39", "t40", "t41", "t42", "t43", "t44", "t45", "t46", "t47",
            "t48", "t49",
        ];
        let (sup, cancel) = make_supervisor();

        let completed = Arc::new(AtomicU32::new(0));
        for name in &NAMES {
            let c = Arc::clone(&completed);
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: move || {
                    let c = Arc::clone(&c);
                    async move {
                        c.fetch_add(1, Ordering::SeqCst);
                    }
                },
            });
        }

        // Poll until every task has run, bounded by a 5 s timeout.
        tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                if completed.load(Ordering::SeqCst) == 50 {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("all 50 tasks should complete");

        // Give the reap driver time to drain the completion events.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(sup.active_count(), 0, "all tasks must be reaped");

        cancel.cancel();
    }
1571
    /// `shutdown_all` returns and leaves the token cancelled; the outer 2 s
    /// timeout fails the test if it hangs.
    #[tokio::test]
    async fn test_shutdown_timeout_expiry() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        sup.spawn(TaskDescriptor {
            name: "stubborn",
            restart: RestartPolicy::RunOnce,
            factory: || async {
                tokio::time::sleep(Duration::from_mins(1)).await;
            },
        });

        assert_eq!(sup.active_count(), 1);

        tokio::time::timeout(
            Duration::from_secs(2),
            sup.shutdown_all(Duration::from_millis(50)),
        )
        .await
        .expect("shutdown_all should return even on timeout expiry");

        assert!(
            cancel.is_cancelled(),
            "cancel token must be cancelled after shutdown"
        );
    }
1599
1600 #[tokio::test]
1601 async fn test_cancellation_token() {
1602 let cancel = CancellationToken::new();
1603 let sup = TaskSupervisor::new(cancel.clone());
1604
1605 assert!(!sup.cancellation_token().is_cancelled());
1606
1607 sup.shutdown_all(Duration::from_millis(100)).await;
1608
1609 assert!(
1610 sup.cancellation_token().is_cancelled(),
1611 "token must be cancelled after shutdown"
1612 );
1613 }
1614
    /// Tasks that finish only after cancellation are still reaped by `shutdown_all`.
    #[tokio::test]
    async fn test_shutdown_drains_post_cancel_completions() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        for name in [
            "loop-1", "loop-2", "loop-3", "loop-4", "loop-5", "loop-6", "loop-7",
        ] {
            let cancel_inner = cancel.clone();
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: move || {
                    let c = cancel_inner.clone();
                    async move {
                        c.cancelled().await;
                        // Yield repeatedly after the token fires.
                        for _ in 0..64 {
                            tokio::task::yield_now().await;
                        }
                    }
                },
            });
        }
        assert_eq!(sup.active_count(), 7);

        sup.shutdown_all(Duration::from_secs(2)).await;

        assert_eq!(
            sup.active_count(),
            0,
            "all tasks must be reaped after shutdown (#3161)"
        );
    }
1654
    /// Tasks started through the `BlockingSpawner` trait show up in `snapshot`.
    #[tokio::test]
    async fn test_blocking_spawner_task_appears_in_snapshot() {
        use crate::BlockingSpawner;

        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel);

        // `ready` tells the test the closure is running; `release` lets it finish.
        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<()>();
        let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>();

        let handle = sup.spawn_blocking_named(
            Arc::from("chunk_file"),
            Box::new(move || {
                let _ = ready_tx.send(());
                let _ = release_rx.blocking_recv();
            }),
        );

        ready_rx.await.expect("task should start");

        let snapshot = sup.snapshot();
        assert!(
            snapshot.iter().any(|t| t.name.as_ref() == "chunk_file"),
            "chunk_file task must appear in supervisor snapshot"
        );

        let _ = release_tx.send(());
        handle.await.expect("task should complete");
    }
1689
1690 #[test]
1696 fn test_measure_blocking_emits_metrics() {
1697 use metrics_util::debugging::DebuggingRecorder;
1698
1699 let recorder = DebuggingRecorder::new();
1700 let snapshotter = recorder.snapshotter();
1701
1702 metrics::with_local_recorder(&recorder, || {
1705 measure_blocking("test_task", || std::hint::black_box(42_u64));
1706 });
1707
1708 let snapshot = snapshotter.snapshot();
1709 let metric_names: Vec<String> = snapshot
1710 .into_vec()
1711 .into_iter()
1712 .map(|(k, _, _, _)| k.key().name().to_owned())
1713 .collect();
1714
1715 assert!(
1716 metric_names.iter().any(|n| n == "zeph.task.wall_time_ms"),
1717 "expected zeph.task.wall_time_ms histogram; got: {metric_names:?}"
1718 );
1719 assert!(
1720 metric_names.iter().any(|n| n == "zeph.task.cpu_time_ms"),
1721 "expected zeph.task.cpu_time_ms histogram; got: {metric_names:?}"
1722 );
1723 }
1724
1725 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1731 async fn test_spawn_blocking_semaphore_cap() {
1732 let (sup, _cancel) = make_supervisor();
1733 let concurrent = Arc::new(AtomicU32::new(0));
1734 let max_concurrent = Arc::new(AtomicU32::new(0));
1735 let barrier = Arc::new(std::sync::Barrier::new(1)); let mut handles = Vec::new();
1738 for i in 0u32..16 {
1739 let c = Arc::clone(&concurrent);
1740 let m = Arc::clone(&max_concurrent);
1741 let name: Arc<str> = Arc::from(format!("blocking-{i}").as_str());
1742 let h = sup.spawn_blocking(name, move || {
1743 let prev = c.fetch_add(1, Ordering::SeqCst);
1744 let mut cur_max = m.load(Ordering::SeqCst);
1746 while prev + 1 > cur_max {
1747 match m.compare_exchange(cur_max, prev + 1, Ordering::SeqCst, Ordering::SeqCst)
1748 {
1749 Ok(_) => break,
1750 Err(x) => cur_max = x,
1751 }
1752 }
1753 std::thread::sleep(std::time::Duration::from_millis(20));
1755 c.fetch_sub(1, Ordering::SeqCst);
1756 });
1757 handles.push(h);
1758 }
1759
1760 for h in handles {
1761 h.join().await.expect("blocking task should succeed");
1762 }
1763 drop(barrier);
1764
1765 let observed = max_concurrent.load(Ordering::SeqCst);
1766 assert!(
1767 observed <= 8,
1768 "observed {observed} concurrent blocking tasks; expected ≤ 8 (semaphore cap)"
1769 );
1770 }
1771}