use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};

use tokio::sync::{mpsc, oneshot};
use tokio::task::AbortHandle;
use tokio_util::sync::CancellationToken;
use tracing::Instrument as _;

use crate::BlockingSpawner;

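/// Supervision policy applied when a task's future completes.
///
/// `RunOnce` entries are dropped from the registry on any exit. `Restart`
/// entries are respawned only after a panic, at most `max` times, with an
/// exponential backoff of `base_delay * 2^(attempt - 1)` capped at
/// [`MAX_RESTART_DELAY`]; a normal exit never triggers a restart.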
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RestartPolicy {
    RunOnce,
    Restart { max: u32, base_delay: Duration },
}

pub const MAX_RESTART_DELAY: Duration = Duration::from_secs(60);

const SHUTDOWN_DRAIN_TIMEOUT: Duration = Duration::from_secs(5);

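/// Everything [`TaskSupervisor::spawn`] needs: a stable name (which doubles
/// as the registry key), a [`RestartPolicy`], and a factory that produces a
/// fresh future for the initial run and for every restart.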
pub struct TaskDescriptor<F> {
    pub name: &'static str,
    pub restart: RestartPolicy,
    pub factory: F,
}

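/// Handle returned by [`TaskSupervisor::spawn`]; `abort` cancels only the
/// underlying tokio task and leaves the registry entry to the reap driver.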
#[derive(Debug, Clone)]
pub struct TaskHandle {
    name: &'static str,
    abort: AbortHandle,
}

impl TaskHandle {
    pub fn abort(&self) {
        tracing::debug!(task.name = self.name, "task aborted via handle");
        self.abort.abort();
    }

    #[must_use]
    pub fn name(&self) -> &'static str {
        self.name
    }
}

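/// Failure modes surfaced by [`BlockingHandle::join`]: the supervised
/// closure panicked, or the result channel closed before a value arrived.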
#[derive(Debug, PartialEq, Eq)]
pub enum BlockingError {
    Panicked,
    SupervisorDropped,
}

impl std::fmt::Display for BlockingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Panicked => write!(f, "supervised blocking task panicked"),
            Self::SupervisorDropped => write!(f, "supervisor dropped before task completed"),
        }
    }
}

impl std::error::Error for BlockingError {}

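/// Await (or poll) the result of a task started via
/// [`TaskSupervisor::spawn_blocking`] or [`TaskSupervisor::spawn_oneshot`].
///
/// A minimal usage sketch, assuming a running tokio runtime
/// (`expensive_hash` is a placeholder for any CPU-bound closure):
///
/// ```ignore
/// let handle = supervisor.spawn_blocking(Arc::from("hash"), || expensive_hash());
/// match handle.join().await {
///     Ok(value) => tracing::info!(?value, "hash finished"),
///     Err(e) => tracing::error!(error = %e, "hash task failed"),
/// }
/// ```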
pub struct BlockingHandle<R> {
    rx: oneshot::Receiver<Result<R, BlockingError>>,
    abort: AbortHandle,
}

impl<R> BlockingHandle<R> {
    pub async fn join(self) -> Result<R, BlockingError> {
        self.rx
            .await
            .unwrap_or(Err(BlockingError::SupervisorDropped))
    }

    pub fn try_join(mut self) -> Result<Result<R, BlockingError>, Self> {
        match self.rx.try_recv() {
            Ok(result) => Ok(result),
            Err(tokio::sync::oneshot::error::TryRecvError::Empty) => Err(self),
            Err(tokio::sync::oneshot::error::TryRecvError::Closed) => {
                Ok(Err(BlockingError::SupervisorDropped))
            }
        }
    }

    pub fn abort(&self) {
        self.abort.abort();
    }
}

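/// Lifecycle states reported in [`TaskSnapshot`]: `Restarting` carries the
/// upcoming attempt number and the policy's maximum, and `Failed` records
/// why the supervisor gave up.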
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskStatus {
    Running,
    Restarting { attempt: u32, max: u32 },
    Completed,
    Aborted,
    Failed { reason: String },
}

#[derive(Debug, Clone)]
pub struct TaskSnapshot {
    pub name: Arc<str>,
    pub status: TaskStatus,
    pub started_at: Instant,
    pub restart_count: u32,
}

type BoxFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
type BoxFactory = Box<dyn Fn() -> BoxFuture + Send + Sync>;

struct TaskEntry {
    name: Arc<str>,
    status: TaskStatus,
    started_at: Instant,
    restart_count: u32,
    restart_policy: RestartPolicy,
    abort_handle: AbortHandle,
    factory: Option<BoxFactory>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CompletionKind {
    Normal,
    Panicked,
    Cancelled,
}

struct Completion {
    name: Arc<str>,
    kind: CompletionKind,
}

struct SupervisorState {
    tasks: HashMap<Arc<str>, TaskEntry>,
}

struct Inner {
    state: parking_lot::Mutex<SupervisorState>,
    completion_tx: mpsc::UnboundedSender<Completion>,
    cancel: CancellationToken,
    blocking_semaphore: Arc<tokio::sync::Semaphore>,
}

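/// Registry and reaper for named, cancellable tokio tasks. Cloning is cheap
/// (an `Arc` bump) and every clone shares the same registry.
///
/// An illustrative sketch, assuming a tokio runtime; `worker_loop` is a
/// placeholder for any `async fn() -> ()`:
///
/// ```ignore
/// let supervisor = TaskSupervisor::new(CancellationToken::new());
/// supervisor.spawn(TaskDescriptor {
///     name: "worker",
///     restart: RestartPolicy::Restart {
///         max: 3,
///         base_delay: Duration::from_millis(100),
///     },
///     factory: || worker_loop(),
/// });
/// // ... later, give tasks five seconds to drain:
/// supervisor.shutdown_all(Duration::from_secs(5)).await;
/// ```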
#[derive(Clone)]
pub struct TaskSupervisor {
    inner: Arc<Inner>,
}

impl TaskSupervisor {
    #[must_use]
    pub fn new(cancel: CancellationToken) -> Self {
        let (completion_tx, completion_rx) = mpsc::unbounded_channel();
        let inner = Arc::new(Inner {
            state: parking_lot::Mutex::new(SupervisorState {
                tasks: HashMap::new(),
            }),
            completion_tx,
            cancel: cancel.clone(),
            blocking_semaphore: Arc::new(tokio::sync::Semaphore::new(8)),
        });

        if tokio::runtime::Handle::try_current().is_ok() {
            Self::start_reap_driver(Arc::clone(&inner), completion_rx, cancel);
        }

        Self { inner }
    }

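    /// Spawns a named supervised task. The factory is retained only for
    /// `Restart` policies so panicked tasks can be respawned; spawning under
    /// a name that is already registered aborts and replaces the old task.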
    pub fn spawn<F, Fut>(&self, desc: TaskDescriptor<F>) -> TaskHandle
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let factory: BoxFactory = Box::new(move || Box::pin((desc.factory)()));
        let cancel = self.inner.cancel.clone();
        let completion_tx = self.inner.completion_tx.clone();
        let name: Arc<str> = Arc::from(desc.name);

        let (abort_handle, jh) = Self::do_spawn(desc.name, &factory, cancel);
        Self::wire_completion_reporter(Arc::clone(&name), jh, completion_tx);

        let entry = TaskEntry {
            name: Arc::clone(&name),
            status: TaskStatus::Running,
            started_at: Instant::now(),
            restart_count: 0,
            restart_policy: desc.restart,
            abort_handle: abort_handle.clone(),
            factory: match desc.restart {
                RestartPolicy::RunOnce => None,
                RestartPolicy::Restart { .. } => Some(factory),
            },
        };

        {
            let mut state = self.inner.state.lock();
            if let Some(old) = state.tasks.remove(&name) {
                old.abort_handle.abort();
            }
            state.tasks.insert(Arc::clone(&name), entry);
        }

        TaskHandle {
            name: desc.name,
            abort: abort_handle,
        }
    }

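    /// Runs `f` on tokio's blocking pool under supervision: admission is
    /// gated by the semaphore created in [`Self::new`] (capacity 8), the
    /// task is registered as `RunOnce`, and a panic in `f` surfaces as
    /// [`BlockingError::Panicked`] through the returned handle.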
    #[allow(clippy::needless_pass_by_value)]
    pub fn spawn_blocking<F, R>(&self, name: Arc<str>, f: F) -> BlockingHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
        #[cfg(feature = "task-metrics")]
        let span = tracing::info_span!(
            "supervised_blocking_task",
            task.name = %name,
            task.wall_time_ms = tracing::field::Empty,
            task.cpu_time_ms = tracing::field::Empty,
        );
        #[cfg(not(feature = "task-metrics"))]
        let span = tracing::info_span!("supervised_blocking_task", task.name = %name);

        let semaphore = Arc::clone(&self.inner.blocking_semaphore);
        let inner = Arc::clone(&self.inner);
        let name_clone = Arc::clone(&name);
        let completion_tx = self.inner.completion_tx.clone();

        let outer = tokio::spawn(async move {
            let _permit = semaphore
                .acquire_owned()
                .await
                .expect("blocking semaphore closed");

            let name_for_measure = Arc::clone(&name_clone);
            let join_handle = tokio::task::spawn_blocking(move || {
                let _enter = span.enter();
                measure_blocking(&name_for_measure, f)
            });
            let abort = join_handle.abort_handle();

            {
                let mut state = inner.state.lock();
                if let Some(entry) = state.tasks.get_mut(&name_clone) {
                    entry.abort_handle = abort;
                }
            }

            let kind = match join_handle.await {
                Ok(val) => {
                    let _ = tx.send(Ok(val));
                    CompletionKind::Normal
                }
                Err(e) if e.is_panic() => {
                    let _ = tx.send(Err(BlockingError::Panicked));
                    CompletionKind::Panicked
                }
                Err(_) => CompletionKind::Cancelled,
            };
            let _ = completion_tx.send(Completion {
                name: name_clone,
                kind,
            });
        });
        let abort = outer.abort_handle();

        {
            let mut state = self.inner.state.lock();
            if let Some(old) = state.tasks.remove(&name) {
                old.abort_handle.abort();
            }
            state.tasks.insert(
                Arc::clone(&name),
                TaskEntry {
                    name: Arc::clone(&name),
                    status: TaskStatus::Running,
                    started_at: Instant::now(),
                    restart_count: 0,
                    restart_policy: RestartPolicy::RunOnce,
                    abort_handle: abort.clone(),
                    factory: None,
                },
            );
        }

        BlockingHandle { rx, abort }
    }

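    /// Spawns a single-shot async task whose future races the supervisor's
    /// cancellation token; on cancellation the result channel is dropped
    /// without a value, which [`BlockingHandle::join`] reports as
    /// [`BlockingError::SupervisorDropped`].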
    pub fn spawn_oneshot<F, Fut, R>(&self, name: Arc<str>, factory: F) -> BlockingHandle<R>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: Send + 'static,
    {
        let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
        let cancel = self.inner.cancel.clone();
        let span = tracing::info_span!("supervised_task", task.name = %name);
        let join_handle: tokio::task::JoinHandle<Option<R>> = tokio::spawn(
            async move {
                let fut = factory();
                tokio::select! {
                    result = fut => Some(result),
                    () = cancel.cancelled() => None,
                }
            }
            .instrument(span),
        );
        let abort = join_handle.abort_handle();

        {
            let mut state = self.inner.state.lock();
            if let Some(old) = state.tasks.remove(&name) {
                old.abort_handle.abort();
            }
            state.tasks.insert(
                Arc::clone(&name),
                TaskEntry {
                    name: Arc::clone(&name),
                    status: TaskStatus::Running,
                    started_at: Instant::now(),
                    restart_count: 0,
                    restart_policy: RestartPolicy::RunOnce,
                    abort_handle: abort.clone(),
                    factory: None,
                },
            );
        }

        let completion_tx = self.inner.completion_tx.clone();
        tokio::spawn(async move {
            let kind = match join_handle.await {
                Ok(Some(val)) => {
                    let _ = tx.send(Ok(val));
                    CompletionKind::Normal
                }
                Err(e) if e.is_panic() => {
                    let _ = tx.send(Err(BlockingError::Panicked));
                    CompletionKind::Panicked
                }
                _ => CompletionKind::Cancelled,
            };
            let _ = completion_tx.send(Completion { name, kind });
        });
        BlockingHandle { rx, abort }
    }

    pub fn abort(&self, name: &'static str) {
        let state = self.inner.state.lock();
        // `HashMap<Arc<str>, _>` can be queried with `&str` via `Borrow`,
        // so no temporary `Arc` allocation is needed for the lookup.
        if let Some(entry) = state.tasks.get(name) {
            entry.abort_handle.abort();
            tracing::debug!(task.name = name, "task aborted via supervisor");
        }
    }

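    /// Cancels the shared token, then polls the registry every 50ms until
    /// all tasks drain or `timeout` elapses; any stragglers are force-aborted
    /// and marked [`TaskStatus::Aborted`].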
    pub async fn shutdown_all(&self, timeout: Duration) {
        self.inner.cancel.cancel();
        let deadline = tokio::time::Instant::now() + timeout;
        loop {
            let active = self.active_count();
            if active == 0 {
                break;
            }
            if tokio::time::Instant::now() >= deadline {
                let mut remaining_names: Vec<Arc<str>> = Vec::new();
                {
                    let mut state = self.inner.state.lock();
                    for entry in state.tasks.values_mut() {
                        if matches!(
                            entry.status,
                            TaskStatus::Running | TaskStatus::Restarting { .. }
                        ) {
                            remaining_names.push(Arc::clone(&entry.name));
                            entry.abort_handle.abort();
                            entry.status = TaskStatus::Aborted;
                        }
                    }
                }
                tracing::warn!(
                    remaining = active,
                    tasks = ?remaining_names,
                    "shutdown timeout — aborting remaining tasks"
                );
                break;
            }
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
    }

    #[must_use]
    pub fn snapshot(&self) -> Vec<TaskSnapshot> {
        let state = self.inner.state.lock();
        let mut snaps: Vec<TaskSnapshot> = state
            .tasks
            .values()
            .map(|e| TaskSnapshot {
                name: Arc::clone(&e.name),
                status: e.status.clone(),
                started_at: e.started_at,
                restart_count: e.restart_count,
            })
            .collect();
        snaps.sort_by_key(|s| s.started_at);
        snaps
    }

    #[must_use]
    pub fn active_count(&self) -> usize {
        let state = self.inner.state.lock();
        state
            .tasks
            .values()
            .filter(|e| {
                matches!(
                    e.status,
                    TaskStatus::Running | TaskStatus::Restarting { .. }
                )
            })
            .count()
    }

    #[must_use]
    pub fn cancellation_token(&self) -> CancellationToken {
        self.inner.cancel.clone()
    }

    fn do_spawn(
        name: &'static str,
        factory: &BoxFactory,
        cancel: CancellationToken,
    ) -> (AbortHandle, tokio::task::JoinHandle<()>) {
        let fut = factory();
        let span = tracing::info_span!("supervised_task", task.name = name);
        let jh = tokio::spawn(
            async move {
                tokio::select! {
                    () = fut => {},
                    () = cancel.cancelled() => {},
                }
            }
            .instrument(span),
        );
        let abort = jh.abort_handle();
        (abort, jh)
    }

    fn wire_completion_reporter(
        name: Arc<str>,
        jh: tokio::task::JoinHandle<()>,
        completion_tx: mpsc::UnboundedSender<Completion>,
    ) {
        tokio::spawn(async move {
            let kind = match jh.await {
                Ok(()) => CompletionKind::Normal,
                Err(e) if e.is_panic() => CompletionKind::Panicked,
                Err(_) => CompletionKind::Cancelled,
            };
            let _ = completion_tx.send(Completion { name, kind });
        });
    }

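    /// Drives completion handling: consumes `Completion` events until the
    /// token is cancelled, then keeps draining for up to
    /// [`SHUTDOWN_DRAIN_TIMEOUT`] so completions that arrive after
    /// cancellation are still reaped.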
    fn start_reap_driver(
        inner: Arc<Inner>,
        mut completion_rx: mpsc::UnboundedReceiver<Completion>,
        cancel: CancellationToken,
    ) {
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    biased;
                    Some(completion) = completion_rx.recv() => {
                        Self::handle_completion(&inner, completion).await;
                    }
                    () = cancel.cancelled() => break,
                }
            }

            let drain_deadline = tokio::time::Instant::now() + SHUTDOWN_DRAIN_TIMEOUT;
            let active = Self::has_active_tasks(&inner);
            tracing::debug!(active, "reap driver entered post-cancel drain phase");
            loop {
                if !Self::has_active_tasks(&inner) {
                    break;
                }
                let remaining =
                    drain_deadline.saturating_duration_since(tokio::time::Instant::now());
                if remaining.is_zero() {
                    break;
                }
                match tokio::time::timeout(remaining, completion_rx.recv()).await {
                    Ok(Some(completion)) => Self::handle_completion(&inner, completion).await,
                    Ok(None) | Err(_) => break,
                }
            }
            tracing::debug!(
                active = Self::has_active_tasks(&inner),
                "reap driver drain phase complete"
            );
        });
    }

    fn has_active_tasks(inner: &Arc<Inner>) -> bool {
        let state = inner.state.lock();
        state.tasks.values().any(|e| {
            matches!(
                e.status,
                TaskStatus::Running | TaskStatus::Restarting { .. }
            )
        })
    }

    async fn handle_completion(inner: &Arc<Inner>, completion: Completion) {
        if inner.cancel.is_cancelled() {
            let mut state = inner.state.lock();
            state.tasks.remove(&completion.name);
            return;
        }

        let Some((attempt, max, delay)) = Self::classify_completion(inner, &completion) else {
            return;
        };

        tracing::warn!(
            task.name = %completion.name,
            attempt,
            max,
            delay_ms = delay.as_millis(),
            "restarting supervised task"
        );

        if !delay.is_zero() {
            tokio::time::sleep(delay).await;
        }

        Self::do_restart(inner, &completion.name, attempt);
    }

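    /// Decides whether a completion warrants a restart, returning
    /// `(attempt, max, delay)` when it does. The delay doubles per attempt:
    /// with `base_delay = 100ms`, attempts 1..=4 wait 100ms, 200ms, 400ms,
    /// and 800ms, saturating at [`MAX_RESTART_DELAY`].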
    fn classify_completion(
        inner: &Arc<Inner>,
        completion: &Completion,
    ) -> Option<(u32, u32, Duration)> {
        let mut state = inner.state.lock();
        let entry = state.tasks.get_mut(&completion.name)?;

        match completion.kind {
            CompletionKind::Panicked => {
                tracing::warn!(task.name = %completion.name, "supervised task panicked");
            }
            CompletionKind::Normal => {
                tracing::info!(task.name = %completion.name, "supervised task completed");
            }
            CompletionKind::Cancelled => {
                tracing::debug!(task.name = %completion.name, "supervised task cancelled");
            }
        }

        match entry.restart_policy {
            RestartPolicy::RunOnce => {
                entry.status = TaskStatus::Completed;
                state.tasks.remove(&completion.name);
                None
            }
            RestartPolicy::Restart { max, base_delay } => {
                if completion.kind != CompletionKind::Panicked {
                    entry.status = TaskStatus::Completed;
                    state.tasks.remove(&completion.name);
                    return None;
                }
                if entry.restart_count >= max {
                    let reason = format!("panicked after {max} restart(s)");
                    tracing::error!(
                        task.name = %completion.name,
                        attempts = max,
                        "task failed permanently"
                    );
                    entry.status = TaskStatus::Failed { reason };
                    None
                } else {
                    let attempt = entry.restart_count + 1;
                    entry.status = TaskStatus::Restarting { attempt, max };
                    let multiplier = 1_u32
                        .checked_shl(attempt.saturating_sub(1))
                        .unwrap_or(u32::MAX);
                    let delay = base_delay.saturating_mul(multiplier).min(MAX_RESTART_DELAY);
                    Some((attempt, max, delay))
                }
            }
        }
    }

    fn do_restart(inner: &Arc<Inner>, name: &Arc<str>, attempt: u32) {
        let spawn_params = {
            let mut state = inner.state.lock();
            let Some(entry) = state.tasks.get_mut(name.as_ref()) else {
                tracing::debug!(
                    task.name = %name,
                    "task removed during restart delay — skipping"
                );
                return;
            };
            if !matches!(entry.status, TaskStatus::Restarting { .. }) {
                return;
            }
            let Some(factory) = &entry.factory else {
                return;
            };
            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(factory)) {
                Err(_) => {
                    let reason = format!("factory panicked on restart attempt {attempt}");
                    tracing::error!(task.name = %name, attempt, "factory panicked during restart");
                    entry.status = TaskStatus::Failed { reason };
                    None
                }
                Ok(fut) => Some((
                    fut,
                    inner.cancel.clone(),
                    inner.completion_tx.clone(),
                    name.clone(),
                )),
            }
        };

        let Some((fut, cancel, completion_tx, name)) = spawn_params else {
            return;
        };

        let span = tracing::info_span!("supervised_task", task.name = %name);
        let jh = tokio::spawn(
            async move {
                tokio::select! {
                    () = fut => {},
                    () = cancel.cancelled() => {},
                }
            }
            .instrument(span),
        );
        let new_abort = jh.abort_handle();

        {
            let mut state = inner.state.lock();
            if let Some(entry) = state.tasks.get_mut(name.as_ref()) {
                entry.restart_count = attempt;
                entry.status = TaskStatus::Running;
                entry.abort_handle = new_abort;
            }
        }

        Self::wire_completion_reporter(name.clone(), jh, completion_tx);
    }
}

#[cfg(feature = "task-metrics")]
#[inline]
fn measure_blocking<F, R>(name: &str, f: F) -> R
where
    F: FnOnce() -> R,
{
    use cpu_time::ThreadTime;
    let wall_start = std::time::Instant::now();
    let cpu_start = ThreadTime::now();
    let result = f();
    let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
    let cpu_ms = cpu_start.elapsed().as_secs_f64() * 1000.0;
    metrics::histogram!("zeph.task.wall_time_ms", "task" => name.to_owned()).record(wall_ms);
    metrics::histogram!("zeph.task.cpu_time_ms", "task" => name.to_owned()).record(cpu_ms);
    tracing::Span::current().record("task.wall_time_ms", wall_ms);
    tracing::Span::current().record("task.cpu_time_ms", cpu_ms);
    result
}

#[cfg(not(feature = "task-metrics"))]
#[inline]
fn measure_blocking<F, R>(_name: &str, f: F) -> R
where
    F: FnOnce() -> R,
{
    f()
}

impl BlockingSpawner for TaskSupervisor {
    fn spawn_blocking_named(
        &self,
        name: Arc<str>,
        f: Box<dyn FnOnce() + Send + 'static>,
    ) -> tokio::task::JoinHandle<()> {
        let handle = self.spawn_blocking(Arc::clone(&name), f);
        tokio::spawn(async move {
            if let Err(e) = handle.join().await {
                tracing::error!(task.name = %name, error = %e, "supervised blocking task failed");
            }
        })
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    use tokio_util::sync::CancellationToken;

    use super::*;

    fn make_supervisor() -> (TaskSupervisor, CancellationToken) {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());
        (sup, cancel)
    }

    #[tokio::test]
    async fn test_spawn_and_complete() {
        let (sup, _cancel) = make_supervisor();

        let done = Arc::new(tokio::sync::Notify::new());
        let done2 = Arc::clone(&done);

        sup.spawn(TaskDescriptor {
            name: "simple",
            restart: RestartPolicy::RunOnce,
            factory: move || {
                let d = Arc::clone(&done2);
                async move {
                    d.notify_one();
                }
            },
        });

        tokio::time::timeout(Duration::from_secs(2), done.notified())
            .await
            .expect("task should complete");

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "RunOnce task should be removed after completion"
        );
    }

    #[tokio::test]
    async fn test_panic_capture() {
        let (sup, _cancel) = make_supervisor();

        sup.spawn(TaskDescriptor {
            name: "panicking",
            restart: RestartPolicy::RunOnce,
            factory: || async { panic!("intentional test panic") },
        });

        tokio::time::sleep(Duration::from_millis(200)).await;

        let snaps = sup.snapshot();
        assert!(
            snaps.iter().all(|s| s.name.as_ref() != "panicking"),
            "entry should be reaped"
        );
        assert_eq!(
            sup.active_count(),
            0,
            "active count must be 0 after RunOnce panic"
        );
    }

    #[tokio::test]
    async fn test_restart_only_on_panic() {
        let (sup, _cancel) = make_supervisor();

        let normal_counter = Arc::new(AtomicU32::new(0));
        let nc = Arc::clone(&normal_counter);
        sup.spawn(TaskDescriptor {
            name: "normal-exit",
            restart: RestartPolicy::Restart {
                max: 3,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&nc);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(300)).await;
        assert_eq!(
            normal_counter.load(Ordering::SeqCst),
            1,
            "normal exit must not restart"
        );
        assert!(
            sup.snapshot()
                .iter()
                .all(|s| s.name.as_ref() != "normal-exit"),
            "entry removed after normal exit"
        );

        let panic_counter = Arc::new(AtomicU32::new(0));
        let pc = Arc::clone(&panic_counter);
        sup.spawn(TaskDescriptor {
            name: "panic-exit",
            restart: RestartPolicy::Restart {
                max: 2,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&pc);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("test panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        assert!(
            panic_counter.load(Ordering::SeqCst) >= 3,
            "panicking task must restart max times"
        );
        let snap = sup
            .snapshot()
            .into_iter()
            .find(|s| s.name.as_ref() == "panic-exit");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "task must be Failed after exhausting restarts"
        );
    }

    #[tokio::test]
    async fn test_restart_policy() {
        let (sup, _cancel) = make_supervisor();

        let counter = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::clone(&counter);

        sup.spawn(TaskDescriptor {
            name: "restartable",
            restart: RestartPolicy::Restart {
                max: 2,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&counter2);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(500)).await;

        let runs = counter.load(Ordering::SeqCst);
        assert!(
            runs >= 3,
            "expected at least 3 invocations (initial + 2 restarts), got {runs}"
        );

        let snaps = sup.snapshot();
        let snap = snaps.iter().find(|s| s.name.as_ref() == "restartable");
        assert!(snap.is_some(), "failed task should remain in registry");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "task should be Failed after exhausting retries"
        );
    }

    #[tokio::test]
    async fn test_exponential_backoff() {
        let (sup, _cancel) = make_supervisor();

        let timestamps = Arc::new(parking_lot::Mutex::new(Vec::<std::time::Instant>::new()));
        let ts = Arc::clone(&timestamps);

        sup.spawn(TaskDescriptor {
            name: "backoff-task",
            restart: RestartPolicy::Restart {
                max: 3,
                base_delay: Duration::from_millis(50),
            },
            factory: move || {
                let t = Arc::clone(&ts);
                async move {
                    t.lock().push(std::time::Instant::now());
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(800)).await;

        let ts = timestamps.lock();
        assert!(
            ts.len() >= 3,
            "expected at least 3 invocations, got {}",
            ts.len()
        );

        if ts.len() >= 3 {
            let d1 = ts[1].duration_since(ts[0]);
            let d2 = ts[2].duration_since(ts[1]);
            assert!(
                d2 >= d1.mul_f64(1.5),
                "expected exponential backoff: d1={d1:?} d2={d2:?}"
            );
        }
    }

    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (sup, _cancel) = make_supervisor();

        for name in ["svc-a", "svc-b", "svc-c"] {
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: || async {
                    tokio::time::sleep(Duration::from_secs(60)).await;
                },
            });
        }

        assert_eq!(sup.active_count(), 3);

        tokio::time::timeout(
            Duration::from_secs(2),
            sup.shutdown_all(Duration::from_secs(1)),
        )
        .await
        .expect("shutdown should complete within timeout");
    }

    #[tokio::test]
    async fn test_force_abort_marks_aborted() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        sup.spawn(TaskDescriptor {
            name: "stubborn-for-abort",
            restart: RestartPolicy::RunOnce,
            factory: || async {
                std::future::pending::<()>().await;
            },
        });

        sup.shutdown_all(Duration::from_millis(1)).await;

        let snaps = sup.snapshot();
        if let Some(snap) = snaps
            .iter()
            .find(|s| s.name.as_ref() == "stubborn-for-abort")
        {
            assert_eq!(
                snap.status,
                TaskStatus::Aborted,
                "force-aborted task must have Aborted status"
            );
        }
    }

    #[tokio::test]
    async fn test_registry_snapshot() {
        let (sup, _cancel) = make_supervisor();

        for name in ["alpha", "beta"] {
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: || async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                },
            });
        }

        let snaps = sup.snapshot();
        assert_eq!(snaps.len(), 2);
        let names: Vec<&str> = snaps.iter().map(|s| s.name.as_ref()).collect();
        assert!(names.contains(&"alpha"));
        assert!(names.contains(&"beta"));
        assert!(snaps.iter().all(|s| s.status == TaskStatus::Running));
    }

    #[tokio::test]
    async fn test_blocking_returns_value() {
        let (sup, cancel) = make_supervisor();

        let handle: BlockingHandle<u32> = sup.spawn_blocking(Arc::from("compute"), || 42_u32);
        let result = handle.join().await.expect("should return value");
        assert_eq!(result, 42);
        cancel.cancel();
    }

    #[tokio::test]
    async fn test_blocking_panic() {
        let (sup, _cancel) = make_supervisor();

        let handle: BlockingHandle<u32> =
            sup.spawn_blocking(Arc::from("panicking-compute"), || panic!("intentional"));
        let err = handle
            .join()
            .await
            .expect_err("should return error on panic");
        assert_eq!(err, BlockingError::Panicked);
    }

    #[tokio::test]
    async fn test_blocking_registered_in_registry() {
        let (sup, cancel) = make_supervisor();

        let (tx, rx) = std::sync::mpsc::channel::<()>();
        let _handle: BlockingHandle<()> =
            sup.spawn_blocking(Arc::from("blocking-task"), move || {
                let _ = rx.recv();
            });

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            sup.active_count(),
            1,
            "blocking task must appear in active_count"
        );

        let _ = tx.send(());
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "blocking task must be removed after completion"
        );

        cancel.cancel();
    }

    #[tokio::test]
    async fn test_oneshot_registered_in_registry() {
        let (sup, cancel) = make_supervisor();

        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        let _handle: BlockingHandle<()> =
            sup.spawn_oneshot(Arc::from("oneshot-task"), move || async move {
                let _ = rx.await;
            });

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            sup.active_count(),
            1,
            "oneshot task must appear in active_count"
        );

        let _ = tx.send(());
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "oneshot task must be removed after completion"
        );

        cancel.cancel();
    }

    #[tokio::test]
    async fn test_restart_max_zero() {
        let (sup, _cancel) = make_supervisor();

        let counter = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::clone(&counter);

        sup.spawn(TaskDescriptor {
            name: "zero-max",
            restart: RestartPolicy::Restart {
                max: 0,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&counter2);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(200)).await;

        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "max=0 should not restart"
        );

        let snaps = sup.snapshot();
        let snap = snaps.iter().find(|s| s.name.as_ref() == "zero-max");
        assert!(snap.is_some(), "entry should remain as Failed");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "status should be Failed"
        );
    }

    #[tokio::test]
    async fn test_concurrent_spawns() {
        static NAMES: [&str; 50] = [
            "t00", "t01", "t02", "t03", "t04", "t05", "t06", "t07", "t08", "t09", "t10", "t11",
            "t12", "t13", "t14", "t15", "t16", "t17", "t18", "t19", "t20", "t21", "t22", "t23",
            "t24", "t25", "t26", "t27", "t28", "t29", "t30", "t31", "t32", "t33", "t34", "t35",
            "t36", "t37", "t38", "t39", "t40", "t41", "t42", "t43", "t44", "t45", "t46", "t47",
            "t48", "t49",
        ];
        let (sup, cancel) = make_supervisor();

        let completed = Arc::new(AtomicU32::new(0));
        // Iterate the array by value so each `name` is a plain `&'static str`,
        // matching `TaskDescriptor::name`.
        for name in NAMES {
            let c = Arc::clone(&completed);
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: move || {
                    let c = Arc::clone(&c);
                    async move {
                        c.fetch_add(1, Ordering::SeqCst);
                    }
                },
            });
        }

        tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                if completed.load(Ordering::SeqCst) == 50 {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("all 50 tasks should complete");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(sup.active_count(), 0, "all tasks must be reaped");

        cancel.cancel();
    }

    #[tokio::test]
    async fn test_shutdown_timeout_expiry() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        sup.spawn(TaskDescriptor {
            name: "stubborn",
            restart: RestartPolicy::RunOnce,
            factory: || async {
                tokio::time::sleep(Duration::from_secs(60)).await;
            },
        });

        assert_eq!(sup.active_count(), 1);

        tokio::time::timeout(
            Duration::from_secs(2),
            sup.shutdown_all(Duration::from_millis(50)),
        )
        .await
        .expect("shutdown_all should return even on timeout expiry");

        assert!(
            cancel.is_cancelled(),
            "cancel token must be cancelled after shutdown"
        );
    }

    #[tokio::test]
    async fn test_cancellation_token() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        assert!(!sup.cancellation_token().is_cancelled());

        sup.shutdown_all(Duration::from_millis(100)).await;

        assert!(
            sup.cancellation_token().is_cancelled(),
            "token must be cancelled after shutdown"
        );
    }

    #[tokio::test]
    async fn test_shutdown_drains_post_cancel_completions() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        for name in [
            "loop-1", "loop-2", "loop-3", "loop-4", "loop-5", "loop-6", "loop-7",
        ] {
            let cancel_inner = cancel.clone();
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: move || {
                    let c = cancel_inner.clone();
                    async move {
                        c.cancelled().await;
                        for _ in 0..64 {
                            tokio::task::yield_now().await;
                        }
                    }
                },
            });
        }
        assert_eq!(sup.active_count(), 7);

        sup.shutdown_all(Duration::from_secs(2)).await;

        assert_eq!(
            sup.active_count(),
            0,
            "all tasks must be reaped after shutdown (#3161)"
        );
    }

    #[tokio::test]
    async fn test_blocking_spawner_task_appears_in_snapshot() {
        use crate::BlockingSpawner;

        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel);

        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<()>();
        let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>();

        let handle = sup.spawn_blocking_named(
            Arc::from("chunk_file"),
            Box::new(move || {
                let _ = ready_tx.send(());
                let _ = release_rx.blocking_recv();
            }),
        );

        ready_rx.await.expect("task should start");

        let snapshot = sup.snapshot();
        assert!(
            snapshot.iter().any(|t| t.name.as_ref() == "chunk_file"),
            "chunk_file task must appear in supervisor snapshot"
        );

        let _ = release_tx.send(());
        handle.await.expect("task should complete");
    }

    #[cfg(feature = "task-metrics")]
    #[test]
    fn test_measure_blocking_emits_metrics() {
        use metrics_util::debugging::DebuggingRecorder;

        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();

        metrics::with_local_recorder(&recorder, || {
            measure_blocking("test_task", || std::hint::black_box(42_u64));
        });

        let snapshot = snapshotter.snapshot();
        let metric_names: Vec<String> = snapshot
            .into_vec()
            .into_iter()
            .map(|(k, _, _, _)| k.key().name().to_owned())
            .collect();

        assert!(
            metric_names.iter().any(|n| n == "zeph.task.wall_time_ms"),
            "expected zeph.task.wall_time_ms histogram; got: {metric_names:?}"
        );
        assert!(
            metric_names.iter().any(|n| n == "zeph.task.cpu_time_ms"),
            "expected zeph.task.cpu_time_ms histogram; got: {metric_names:?}"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_spawn_blocking_semaphore_cap() {
        let (sup, _cancel) = make_supervisor();
        let concurrent = Arc::new(AtomicU32::new(0));
        let max_concurrent = Arc::new(AtomicU32::new(0));

        let mut handles = Vec::new();
        for i in 0u32..16 {
            let c = Arc::clone(&concurrent);
            let m = Arc::clone(&max_concurrent);
            let name: Arc<str> = Arc::from(format!("blocking-{i}").as_str());
            let h = sup.spawn_blocking(name, move || {
                let prev = c.fetch_add(1, Ordering::SeqCst);
                // Record the high-water mark of concurrently running closures.
                let mut cur_max = m.load(Ordering::SeqCst);
                while prev + 1 > cur_max {
                    match m.compare_exchange(cur_max, prev + 1, Ordering::SeqCst, Ordering::SeqCst)
                    {
                        Ok(_) => break,
                        Err(x) => cur_max = x,
                    }
                }
                std::thread::sleep(std::time::Duration::from_millis(20));
                c.fetch_sub(1, Ordering::SeqCst);
            });
            handles.push(h);
        }

        for h in handles {
            h.join().await.expect("blocking task should succeed");
        }

        let observed = max_concurrent.load(Ordering::SeqCst);
        assert!(
            observed <= 8,
            "observed {observed} concurrent blocking tasks; expected ≤ 8 (semaphore cap)"
        );
    }
}