1use std::collections::HashMap;
47use std::future::Future;
48use std::pin::Pin;
49use std::sync::Arc;
50use std::time::{Duration, Instant};
51
52use tokio::sync::{mpsc, oneshot};
53use tokio::task::AbortHandle;
54use tokio_util::sync::CancellationToken;
55use tracing::Instrument as _;
56
57use crate::BlockingSpawner;
58
/// How the supervisor reacts when a task started with [`TaskSupervisor::spawn`] ends.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RestartPolicy {
    /// Run the task a single time; its registry entry is removed however the run ends.
    RunOnce,
    /// Re-run the task after it panics, at most `max` times.
    ///
    /// Restart number `n` (1-based) waits `base_delay * 2^(n-1)`, capped at
    /// [`MAX_RESTART_DELAY`]. A run that returns normally or is cancelled is not
    /// restarted. Once the budget is used up the entry stays in the registry as
    /// [`TaskStatus::Failed`].
    Restart { max: u32, base_delay: Duration },
}
92
/// Upper bound on the back-off delay between two restarts of a panicking task.
pub const MAX_RESTART_DELAY: Duration = Duration::from_secs(60);

/// How long the reap driver keeps processing completions after cancellation before exiting.
const SHUTDOWN_DRAIN_TIMEOUT: Duration = Duration::from_secs(5);
101
102pub struct TaskDescriptor<F> {
107 pub name: &'static str,
112 pub restart: RestartPolicy,
114 pub factory: F,
116}
117
118#[derive(Debug, Clone)]
122pub struct TaskHandle {
123 name: &'static str,
124 abort: AbortHandle,
125}
126
127impl TaskHandle {
128 pub fn abort(&self) {
130 tracing::debug!(task.name = self.name, "task aborted via handle");
131 self.abort.abort();
132 }
133
134 #[must_use]
136 pub const fn name(&self) -> &'static str {
137 self.name
138 }
139}
140
/// Why a [`BlockingHandle`] could not deliver a value.
#[derive(Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum BlockingError {
    /// The supervised closure or future panicked.
    Panicked,
    /// The task ended without sending a result, e.g. because it was aborted or cancelled
    /// before it finished.
    SupervisorDropped,
}

impl std::fmt::Display for BlockingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Panicked => "supervised blocking task panicked",
            Self::SupervisorDropped => "supervisor dropped before task completed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for BlockingError {}
161
162pub struct BlockingHandle<R> {
171 rx: oneshot::Receiver<Result<R, BlockingError>>,
172 abort: AbortHandle,
173}
174
175impl<R> BlockingHandle<R> {
176 pub async fn join(self) -> Result<R, BlockingError> {
184 self.rx
185 .await
186 .unwrap_or(Err(BlockingError::SupervisorDropped))
187 }
188
189 pub fn try_join(mut self) -> Result<Result<R, BlockingError>, Self> {
221 match self.rx.try_recv() {
222 Ok(result) => Ok(result),
223 Err(tokio::sync::oneshot::error::TryRecvError::Empty) => Err(self),
224 Err(tokio::sync::oneshot::error::TryRecvError::Closed) => {
225 Ok(Err(BlockingError::SupervisorDropped))
226 }
227 }
228 }
229
230 pub fn abort(&self) {
232 self.abort.abort();
233 }
234}
235
/// Lifecycle state of a registered task, as reported by [`TaskSupervisor::snapshot`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum TaskStatus {
    /// A run of the task is in progress.
    Running,
    /// The task panicked and restart `attempt` (of at most `max`) is pending.
    Restarting { attempt: u32, max: u32 },
    /// The task finished. Finished entries are normally dropped from the registry right
    /// away, so this is rarely observed.
    Completed,
    /// The task was force-aborted because `shutdown_all` ran out of time.
    Aborted,
    /// The task will not be restarted again; `reason` says why.
    Failed { reason: String },
}
251
252#[derive(Debug, Clone)]
254pub struct TaskSnapshot {
264 pub name: Arc<str>,
266 pub status: TaskStatus,
268 pub started_at: Instant,
270 pub restart_count: u32,
272}
273
/// Type-erased, heap-pinned future driven by one run of a supervised task.
type BoxFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Type-erased factory that builds a fresh [`BoxFuture`] for each run of a task.
type BoxFactory = Box<dyn Fn() -> BoxFuture + Send + Sync>;
278
279struct TaskEntry {
280 name: Arc<str>,
281 status: TaskStatus,
282 started_at: Instant,
283 restart_count: u32,
284 restart_policy: RestartPolicy,
285 abort_handle: AbortHandle,
286 factory: Option<BoxFactory>,
288}
289
/// How one run of a supervised task ended, as observed through its `JoinHandle`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CompletionKind {
    /// The run finished without panicking or being aborted. For tasks started with
    /// `spawn`, a run stopped by the cancellation token also ends up here.
    Normal,
    /// The run panicked.
    Panicked,
    /// The run was aborted (or, for `spawn_oneshot`, stopped by the cancellation token).
    Cancelled,
}
300
301struct Completion {
302 name: Arc<str>,
303 kind: CompletionKind,
304}
305
306struct SupervisorState {
307 tasks: HashMap<Arc<str>, TaskEntry>,
308}
309
310struct Inner {
311 state: parking_lot::Mutex<SupervisorState>,
312 completion_tx: mpsc::UnboundedSender<Completion>,
316 cancel: CancellationToken,
317 blocking_semaphore: Arc<tokio::sync::Semaphore>,
320 shutdown_notify: Arc<tokio::sync::Notify>,
322}
323
324#[derive(Clone)]
360pub struct TaskSupervisor {
361 inner: Arc<Inner>,
362}
363
364impl TaskSupervisor {
365 #[must_use]
377 pub fn new(cancel: CancellationToken) -> Self {
378 let (completion_tx, completion_rx) = mpsc::unbounded_channel();
382 let inner = Arc::new(Inner {
383 state: parking_lot::Mutex::new(SupervisorState {
384 tasks: HashMap::new(),
385 }),
386 completion_tx,
387 cancel: cancel.clone(),
388 blocking_semaphore: Arc::new(tokio::sync::Semaphore::new(8)),
389 shutdown_notify: Arc::new(tokio::sync::Notify::new()),
390 });
391
392 if tokio::runtime::Handle::try_current().is_ok() {
397 Self::start_reap_driver(Arc::clone(&inner), completion_rx, cancel);
398 }
399
400 Self { inner }
401 }
402
403 pub fn spawn<F, Fut>(&self, desc: TaskDescriptor<F>) -> TaskHandle
428 where
429 F: Fn() -> Fut + Send + Sync + 'static,
430 Fut: Future<Output = ()> + Send + 'static,
431 {
432 let factory: BoxFactory = Box::new(move || Box::pin((desc.factory)()));
433 let cancel = self.inner.cancel.clone();
434 let completion_tx = self.inner.completion_tx.clone();
435 let name: Arc<str> = Arc::from(desc.name);
436
437 let (abort_handle, jh) = Self::do_spawn(desc.name, &factory, cancel);
438 Self::wire_completion_reporter(Arc::clone(&name), jh, completion_tx);
439
440 let entry = TaskEntry {
441 name: Arc::clone(&name),
442 status: TaskStatus::Running,
443 started_at: Instant::now(),
444 restart_count: 0,
445 restart_policy: desc.restart,
446 abort_handle: abort_handle.clone(),
447 factory: match desc.restart {
448 RestartPolicy::RunOnce => None,
449 RestartPolicy::Restart { .. } => Some(factory),
450 },
451 };
452
453 {
454 let mut state = self.inner.state.lock();
455 if let Some(old) = state.tasks.remove(&name) {
456 old.abort_handle.abort();
457 }
458 state.tasks.insert(Arc::clone(&name), entry);
459 }
460
461 TaskHandle {
462 name: desc.name,
463 abort: abort_handle,
464 }
465 }
466
467 #[allow(clippy::needless_pass_by_value)] pub fn spawn_blocking<F, R>(&self, name: Arc<str>, f: F) -> BlockingHandle<R>
512 where
513 F: FnOnce() -> R + Send + 'static,
514 R: Send + 'static,
515 {
516 let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
517 let span = tracing::info_span!(
518 "supervised_blocking_task",
519 task.name = %name,
520 task.wall_time_ms = tracing::field::Empty,
521 task.cpu_time_ms = tracing::field::Empty,
522 );
523
524 let semaphore = Arc::clone(&self.inner.blocking_semaphore);
525 let inner = Arc::clone(&self.inner);
526 let name_clone = Arc::clone(&name);
527 let completion_tx = self.inner.completion_tx.clone();
528
529 let outer = tokio::spawn(async move {
532 let _permit = semaphore
533 .acquire_owned()
534 .await
535 .expect("blocking semaphore closed");
536
537 let name_for_measure = Arc::clone(&name_clone);
538 let join_handle = tokio::task::spawn_blocking(move || {
539 let _enter = span.enter();
540 measure_blocking(&name_for_measure, f)
541 });
542 let abort = join_handle.abort_handle();
543
544 {
546 let mut state = inner.state.lock();
547 if let Some(entry) = state.tasks.get_mut(&name_clone) {
548 entry.abort_handle = abort;
549 }
550 }
551
552 let kind = match join_handle.await {
553 Ok(val) => {
554 let _ = tx.send(Ok(val));
555 CompletionKind::Normal
556 }
557 Err(e) if e.is_panic() => {
558 let _ = tx.send(Err(BlockingError::Panicked));
559 CompletionKind::Panicked
560 }
561 Err(_) => {
562 CompletionKind::Cancelled
564 }
565 };
566 let _ = completion_tx.send(Completion {
568 name: name_clone,
569 kind,
570 });
571 });
572 let abort = outer.abort_handle();
573
574 {
576 let mut state = self.inner.state.lock();
577 if let Some(old) = state.tasks.remove(&name) {
578 old.abort_handle.abort();
579 }
580 state.tasks.insert(
581 Arc::clone(&name),
582 TaskEntry {
583 name: Arc::clone(&name),
584 status: TaskStatus::Running,
585 started_at: Instant::now(),
586 restart_count: 0,
587 restart_policy: RestartPolicy::RunOnce,
588 abort_handle: abort.clone(),
589 factory: None,
590 },
591 );
592 }
593
594 BlockingHandle { rx, abort }
595 }
596
597 pub fn spawn_oneshot<F, Fut, R>(&self, name: Arc<str>, factory: F) -> BlockingHandle<R>
625 where
626 F: FnOnce() -> Fut + Send + 'static,
627 Fut: Future<Output = R> + Send + 'static,
628 R: Send + 'static,
629 {
630 let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
631 let cancel = self.inner.cancel.clone();
632 let span = tracing::info_span!("supervised_task", task.name = %name);
633 let join_handle: tokio::task::JoinHandle<Option<R>> = tokio::spawn(
634 async move {
635 let fut = factory();
636 tokio::select! {
637 result = fut => Some(result),
638 () = cancel.cancelled() => None,
639 }
640 }
641 .instrument(span),
642 );
643 let abort = join_handle.abort_handle();
644
645 {
646 let mut state = self.inner.state.lock();
647 if let Some(old) = state.tasks.remove(&name) {
648 old.abort_handle.abort();
649 }
650 state.tasks.insert(
651 Arc::clone(&name),
652 TaskEntry {
653 name: Arc::clone(&name),
654 status: TaskStatus::Running,
655 started_at: Instant::now(),
656 restart_count: 0,
657 restart_policy: RestartPolicy::RunOnce,
658 abort_handle: abort.clone(),
659 factory: None,
660 },
661 );
662 }
663
664 let completion_tx = self.inner.completion_tx.clone();
665 tokio::spawn(async move {
666 let kind = match join_handle.await {
667 Ok(Some(val)) => {
668 let _ = tx.send(Ok(val));
669 CompletionKind::Normal
670 }
671 Err(e) if e.is_panic() => {
672 let _ = tx.send(Err(BlockingError::Panicked));
673 CompletionKind::Panicked
674 }
675 _ => CompletionKind::Cancelled,
676 };
677 let _ = completion_tx.send(Completion { name, kind });
678 });
679 BlockingHandle { rx, abort }
680 }
681
682 pub fn abort(&self, name: &'static str) {
684 let state = self.inner.state.lock();
685 let key: Arc<str> = Arc::from(name);
686 if let Some(entry) = state.tasks.get(&key) {
687 entry.abort_handle.abort();
688 tracing::debug!(task.name = name, "task aborted via supervisor");
689 }
690 }
691
692 pub async fn shutdown_all(&self, timeout: Duration) {
705 self.inner.cancel.cancel();
706 let sleep = tokio::time::sleep(timeout);
707 tokio::pin!(sleep);
708 loop {
709 let active = self.active_count();
710 if active == 0 {
711 break;
712 }
713 let notified = self.inner.shutdown_notify.notified();
716 tokio::select! {
717 biased;
718 () = notified => {
719 }
721 () = &mut sleep => {
722 let mut remaining_names: Vec<Arc<str>> = Vec::new();
723 {
724 let mut state = self.inner.state.lock();
725 for entry in state.tasks.values_mut() {
726 if matches!(
727 entry.status,
728 TaskStatus::Running | TaskStatus::Restarting { .. }
729 ) {
730 remaining_names.push(Arc::clone(&entry.name));
731 entry.abort_handle.abort();
732 entry.status = TaskStatus::Aborted;
733 }
734 }
735 }
736 tracing::warn!(
737 remaining = active,
738 tasks = ?remaining_names,
739 "shutdown timeout — aborting remaining tasks"
740 );
741 break;
742 }
743 }
744 }
745 }
746
747 #[must_use]
752 pub fn snapshot(&self) -> Vec<TaskSnapshot> {
753 let state = self.inner.state.lock();
754 let mut snaps: Vec<TaskSnapshot> = state
755 .tasks
756 .values()
757 .map(|e| TaskSnapshot {
758 name: Arc::clone(&e.name),
759 status: e.status.clone(),
760 started_at: e.started_at,
761 restart_count: e.restart_count,
762 })
763 .collect();
764 snaps.sort_by_key(|s| s.started_at);
765 snaps
766 }
767
768 #[must_use]
770 pub fn active_count(&self) -> usize {
771 let state = self.inner.state.lock();
772 state
773 .tasks
774 .values()
775 .filter(|e| {
776 matches!(
777 e.status,
778 TaskStatus::Running | TaskStatus::Restarting { .. }
779 )
780 })
781 .count()
782 }
783
784 #[must_use]
788 pub fn cancellation_token(&self) -> CancellationToken {
789 self.inner.cancel.clone()
790 }
791
792 fn do_spawn(
796 name: &'static str,
797 factory: &BoxFactory,
798 cancel: CancellationToken,
799 ) -> (AbortHandle, tokio::task::JoinHandle<()>) {
800 let fut = factory();
801 let span = tracing::info_span!("supervised_task", task.name = name);
802 let jh = tokio::spawn(
803 async move {
804 tokio::select! {
805 () = fut => {},
806 () = cancel.cancelled() => {},
807 }
808 }
809 .instrument(span),
810 );
811 let abort = jh.abort_handle();
812 (abort, jh)
813 }
814
815 fn wire_completion_reporter(
817 name: Arc<str>,
818 jh: tokio::task::JoinHandle<()>,
819 completion_tx: mpsc::UnboundedSender<Completion>,
820 ) {
821 tokio::spawn(async move {
822 let kind = match jh.await {
823 Ok(()) => CompletionKind::Normal,
824 Err(e) if e.is_panic() => CompletionKind::Panicked,
825 Err(_) => CompletionKind::Cancelled,
826 };
827 let _ = completion_tx.send(Completion { name, kind });
828 });
829 }
830
831 fn start_reap_driver(
838 inner: Arc<Inner>,
839 mut completion_rx: mpsc::UnboundedReceiver<Completion>,
840 cancel: CancellationToken,
841 ) {
842 tokio::spawn(async move {
843 loop {
845 tokio::select! {
846 biased;
847 Some(completion) = completion_rx.recv() => {
848 Self::handle_completion(&inner, completion).await;
849 }
850 () = cancel.cancelled() => break,
851 }
852 }
853
854 let drain_deadline = tokio::time::Instant::now() + SHUTDOWN_DRAIN_TIMEOUT;
859 let active = Self::has_active_tasks(&inner);
860 tracing::debug!(active, "reap driver entered post-cancel drain phase");
861 loop {
862 if !Self::has_active_tasks(&inner) {
863 break;
864 }
865 let remaining =
866 drain_deadline.saturating_duration_since(tokio::time::Instant::now());
867 if remaining.is_zero() {
868 break;
869 }
870 match tokio::time::timeout(remaining, completion_rx.recv()).await {
871 Ok(Some(completion)) => Self::handle_completion(&inner, completion).await,
872 Ok(None) | Err(_) => break,
874 }
875 }
876 tracing::debug!(
877 active = Self::has_active_tasks(&inner),
878 "reap driver drain phase complete"
879 );
880 });
881 }
882
883 fn has_active_tasks(inner: &Arc<Inner>) -> bool {
885 let state = inner.state.lock();
886 state.tasks.values().any(|e| {
887 matches!(
888 e.status,
889 TaskStatus::Running | TaskStatus::Restarting { .. }
890 )
891 })
892 }
893
894 async fn handle_completion(inner: &Arc<Inner>, completion: Completion) {
900 if inner.cancel.is_cancelled() {
904 {
905 let mut state = inner.state.lock();
906 state.tasks.remove(&completion.name);
907 }
908 inner.shutdown_notify.notify_waiters();
909 return;
910 }
911
912 let Some((attempt, max, delay)) = Self::classify_completion(inner, &completion) else {
913 inner.shutdown_notify.notify_waiters();
916 return;
917 };
918
919 tracing::warn!(
920 task.name = %completion.name,
921 attempt,
922 max,
923 delay_ms = delay.as_millis(),
924 "restarting supervised task"
925 );
926
927 if !delay.is_zero() {
928 tokio::time::sleep(delay).await;
929 }
930
931 Self::do_restart(inner, &completion.name, attempt);
932 }
933
934 fn classify_completion(
938 inner: &Arc<Inner>,
939 completion: &Completion,
940 ) -> Option<(u32, u32, Duration)> {
941 let mut state = inner.state.lock();
942 let entry = state.tasks.get_mut(&completion.name)?;
943
944 match completion.kind {
945 CompletionKind::Panicked => {
946 tracing::warn!(task.name = %completion.name, "supervised task panicked");
947 }
948 CompletionKind::Normal => {
949 tracing::info!(task.name = %completion.name, "supervised task completed");
950 }
951 CompletionKind::Cancelled => {
952 tracing::debug!(task.name = %completion.name, "supervised task cancelled");
953 }
954 }
955
956 match entry.restart_policy {
957 RestartPolicy::RunOnce => {
958 entry.status = TaskStatus::Completed;
959 state.tasks.remove(&completion.name);
960 None
961 }
962 RestartPolicy::Restart { max, base_delay } => {
963 if completion.kind != CompletionKind::Panicked {
965 entry.status = TaskStatus::Completed;
966 state.tasks.remove(&completion.name);
967 return None;
968 }
969 if entry.restart_count >= max {
970 let reason = format!("panicked after {max} restart(s)");
971 tracing::error!(
972 task.name = %completion.name,
973 attempts = max,
974 "task failed permanently"
975 );
976 entry.status = TaskStatus::Failed { reason };
977 None
978 } else {
979 let attempt = entry.restart_count + 1;
980 entry.status = TaskStatus::Restarting { attempt, max };
981 let multiplier = 1_u32
983 .checked_shl(attempt.saturating_sub(1))
984 .unwrap_or(u32::MAX);
985 let delay = base_delay.saturating_mul(multiplier).min(MAX_RESTART_DELAY);
986 Some((attempt, max, delay))
987 }
988 }
989 }
990 }
992
993 fn do_restart(inner: &Arc<Inner>, name: &Arc<str>, attempt: u32) {
995 let spawn_params = {
996 let mut state = inner.state.lock();
997 let Some(entry) = state.tasks.get_mut(name.as_ref()) else {
998 tracing::debug!(
999 task.name = %name,
1000 "task removed during restart delay — skipping"
1001 );
1002 return;
1003 };
1004 if !matches!(entry.status, TaskStatus::Restarting { .. }) {
1005 return;
1006 }
1007 let Some(factory) = &entry.factory else {
1008 return;
1009 };
1010 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(factory)) {
1013 Err(_) => {
1014 let reason = format!("factory panicked on restart attempt {attempt}");
1015 tracing::error!(task.name = %name, attempt, "factory panicked during restart");
1016 entry.status = TaskStatus::Failed { reason };
1017 None
1018 }
1019 Ok(fut) => Some((
1020 fut,
1021 inner.cancel.clone(),
1022 inner.completion_tx.clone(),
1023 name.clone(),
1024 )),
1025 }
1026 };
1028
1029 let Some((fut, cancel, completion_tx, name)) = spawn_params else {
1030 return;
1031 };
1032
1033 let span = tracing::info_span!("supervised_task", task.name = %name);
1034 let jh = tokio::spawn(
1035 async move {
1036 tokio::select! {
1037 () = fut => {},
1038 () = cancel.cancelled() => {},
1039 }
1040 }
1041 .instrument(span),
1042 );
1043 let new_abort = jh.abort_handle();
1044
1045 {
1046 let mut state = inner.state.lock();
1047 if let Some(entry) = state.tasks.get_mut(name.as_ref()) {
1048 entry.restart_count = attempt;
1049 entry.status = TaskStatus::Running;
1050 entry.abort_handle = new_abort;
1051 }
1052 }
1053
1054 Self::wire_completion_reporter(name, jh, completion_tx);
1055 }
1056}
1057
1058#[inline]
1062fn measure_blocking<F, R>(name: &str, f: F) -> R
1063where
1064 F: FnOnce() -> R,
1065{
1066 use cpu_time::ThreadTime;
1067 let wall_start = std::time::Instant::now();
1068 let cpu_start = ThreadTime::now();
1069 let result = f();
1070 let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
1071 let cpu_ms = cpu_start.elapsed().as_secs_f64() * 1000.0;
1072 metrics::histogram!("zeph.task.wall_time_ms", "task" => name.to_owned()).record(wall_ms);
1073 metrics::histogram!("zeph.task.cpu_time_ms", "task" => name.to_owned()).record(cpu_ms);
1074 tracing::Span::current().record("task.wall_time_ms", wall_ms);
1075 tracing::Span::current().record("task.cpu_time_ms", cpu_ms);
1076 result
1077}
1078
1079impl BlockingSpawner for TaskSupervisor {
1082 fn spawn_blocking_named(
1088 &self,
1089 name: Arc<str>,
1090 f: Box<dyn FnOnce() + Send + 'static>,
1091 ) -> tokio::task::JoinHandle<()> {
1092 let handle = self.spawn_blocking(Arc::clone(&name), f);
1093 tokio::spawn(async move {
1094 if let Err(e) = handle.join().await {
1095 tracing::error!(task.name = %name, error = %e, "supervised blocking task failed");
1096 }
1097 })
1098 }
1099}
1100
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    use tokio_util::sync::CancellationToken;

    use super::*;

    /// Builds a supervisor together with the token it was created from.
    fn make_supervisor() -> (TaskSupervisor, CancellationToken) {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());
        (sup, cancel)
    }

    /// A `RunOnce` task runs and its entry is reaped after it finishes.
    #[tokio::test]
    async fn test_spawn_and_complete() {
        let (sup, _cancel) = make_supervisor();

        let done = Arc::new(tokio::sync::Notify::new());
        let done2 = Arc::clone(&done);

        sup.spawn(TaskDescriptor {
            name: "simple",
            restart: RestartPolicy::RunOnce,
            factory: move || {
                let d = Arc::clone(&done2);
                async move {
                    d.notify_one();
                }
            },
        });

        tokio::time::timeout(Duration::from_secs(2), done.notified())
            .await
            .expect("task should complete");

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "RunOnce task should be removed after completion"
        );
    }

    /// A panicking `RunOnce` task is reaped rather than left in the registry.
    #[tokio::test]
    async fn test_panic_capture() {
        let (sup, _cancel) = make_supervisor();

        sup.spawn(TaskDescriptor {
            name: "panicking",
            restart: RestartPolicy::RunOnce,
            factory: || async { panic!("intentional test panic") },
        });

        tokio::time::sleep(Duration::from_millis(200)).await;

        let snaps = sup.snapshot();
        assert!(
            snaps.iter().all(|s| s.name.as_ref() != "panicking"),
            "entry should be reaped"
        );
        assert_eq!(
            sup.active_count(),
            0,
            "active count must be 0 after RunOnce panic"
        );
    }

    /// `Restart` only re-runs a task after a panic, never after a normal exit.
    #[tokio::test]
    async fn test_restart_only_on_panic() {
        let (sup, _cancel) = make_supervisor();

        let normal_counter = Arc::new(AtomicU32::new(0));
        let nc = Arc::clone(&normal_counter);
        sup.spawn(TaskDescriptor {
            name: "normal-exit",
            restart: RestartPolicy::Restart {
                max: 3,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&nc);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(300)).await;
        assert_eq!(
            normal_counter.load(Ordering::SeqCst),
            1,
            "normal exit must not restart"
        );
        assert!(
            sup.snapshot()
                .iter()
                .all(|s| s.name.as_ref() != "normal-exit"),
            "entry removed after normal exit"
        );

        let panic_counter = Arc::new(AtomicU32::new(0));
        let pc = Arc::clone(&panic_counter);
        sup.spawn(TaskDescriptor {
            name: "panic-exit",
            restart: RestartPolicy::Restart {
                max: 2,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&pc);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("test panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        assert!(
            panic_counter.load(Ordering::SeqCst) >= 3,
            "panicking task must restart max times"
        );
        let snap = sup
            .snapshot()
            .into_iter()
            .find(|s| s.name.as_ref() == "panic-exit");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "task must be Failed after exhausting restarts"
        );
    }

    /// An always-panicking task runs `1 + max` times and then stays registered as `Failed`.
    #[tokio::test]
    async fn test_restart_policy() {
        let (sup, _cancel) = make_supervisor();

        let counter = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::clone(&counter);

        sup.spawn(TaskDescriptor {
            name: "restartable",
            restart: RestartPolicy::Restart {
                max: 2,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&counter2);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(500)).await;

        let runs = counter.load(Ordering::SeqCst);
        assert!(
            runs >= 3,
            "expected at least 3 invocations (initial + 2 restarts), got {runs}"
        );

        let snaps = sup.snapshot();
        let snap = snaps.iter().find(|s| s.name.as_ref() == "restartable");
        assert!(snap.is_some(), "failed task should remain in registry");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "task should be Failed after exhausting retries"
        );
    }

    /// Consecutive restart delays grow roughly geometrically.
    #[tokio::test]
    async fn test_exponential_backoff() {
        let (sup, _cancel) = make_supervisor();

        let timestamps = Arc::new(parking_lot::Mutex::new(Vec::<std::time::Instant>::new()));
        let ts = Arc::clone(&timestamps);

        sup.spawn(TaskDescriptor {
            name: "backoff-task",
            restart: RestartPolicy::Restart {
                max: 3,
                base_delay: Duration::from_millis(50),
            },
            factory: move || {
                let t = Arc::clone(&ts);
                async move {
                    t.lock().push(std::time::Instant::now());
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(800)).await;

        let ts = timestamps.lock();
        assert!(
            ts.len() >= 3,
            "expected at least 3 invocations, got {}",
            ts.len()
        );

        if ts.len() >= 3 {
            let d1 = ts[1].duration_since(ts[0]);
            let d2 = ts[2].duration_since(ts[1]);
            // The second delay nominally doubles the first; 1.5x leaves room for jitter.
            assert!(
                d2 >= d1.mul_f64(1.5),
                "expected exponential backoff: d1={d1:?} d2={d2:?}"
            );
        }
    }

    /// `shutdown_all` returns promptly once cancellation stops long-running tasks.
    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (sup, _cancel) = make_supervisor();

        for name in ["svc-a", "svc-b", "svc-c"] {
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: || async {
                    tokio::time::sleep(Duration::from_mins(1)).await;
                },
            });
        }

        assert_eq!(sup.active_count(), 3);

        tokio::time::timeout(
            Duration::from_secs(2),
            sup.shutdown_all(Duration::from_secs(1)),
        )
        .await
        .expect("shutdown should complete within timeout");
    }

    /// A task still registered after the shutdown timeout is reported as `Aborted`.
    #[tokio::test]
    async fn test_force_abort_marks_aborted() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        sup.spawn(TaskDescriptor {
            name: "stubborn-for-abort",
            restart: RestartPolicy::RunOnce,
            factory: || async {
                std::future::pending::<()>().await;
            },
        });

        sup.shutdown_all(Duration::from_millis(1)).await;

        // The entry may already have been reaped; only check its status if it is still there.
        let snaps = sup.snapshot();
        if let Some(snap) = snaps
            .iter()
            .find(|s| s.name.as_ref() == "stubborn-for-abort")
        {
            assert_eq!(
                snap.status,
                TaskStatus::Aborted,
                "force-aborted task must have Aborted status"
            );
        }
    }

    /// `snapshot` lists every running task.
    #[tokio::test]
    async fn test_registry_snapshot() {
        let (sup, _cancel) = make_supervisor();

        for name in ["alpha", "beta"] {
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: || async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                },
            });
        }

        let snaps = sup.snapshot();
        assert_eq!(snaps.len(), 2);
        let names: Vec<&str> = snaps.iter().map(|s| s.name.as_ref()).collect();
        assert!(names.contains(&"alpha"));
        assert!(names.contains(&"beta"));
        assert!(snaps.iter().all(|s| s.status == TaskStatus::Running));
    }

    /// A blocking closure's return value reaches the handle.
    #[tokio::test]
    async fn test_blocking_returns_value() {
        let (sup, cancel) = make_supervisor();

        let handle: BlockingHandle<u32> = sup.spawn_blocking(Arc::from("compute"), || 42_u32);
        let result = handle.join().await.expect("should return value");
        assert_eq!(result, 42);
        cancel.cancel();
    }

    /// A panicking blocking closure surfaces as `BlockingError::Panicked`.
    #[tokio::test]
    async fn test_blocking_panic() {
        let (sup, _cancel) = make_supervisor();

        let handle: BlockingHandle<u32> =
            sup.spawn_blocking(Arc::from("panicking-compute"), || panic!("intentional"));
        let err = handle
            .join()
            .await
            .expect_err("should return error on panic");
        assert_eq!(err, BlockingError::Panicked);
    }

    /// Blocking tasks are registered while running and reaped afterwards.
    #[tokio::test]
    async fn test_blocking_registered_in_registry() {
        let (sup, cancel) = make_supervisor();

        let (tx, rx) = std::sync::mpsc::channel::<()>();
        let _handle: BlockingHandle<()> =
            sup.spawn_blocking(Arc::from("blocking-task"), move || {
                let _ = rx.recv();
            });

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            sup.active_count(),
            1,
            "blocking task must appear in active_count"
        );

        let _ = tx.send(());
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "blocking task must be removed after completion"
        );

        cancel.cancel();
    }

    /// One-shot async tasks are registered while running and reaped afterwards.
    #[tokio::test]
    async fn test_oneshot_registered_in_registry() {
        let (sup, cancel) = make_supervisor();

        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        let _handle: BlockingHandle<()> =
            sup.spawn_oneshot(Arc::from("oneshot-task"), move || async move {
                let _ = rx.await;
            });

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            sup.active_count(),
            1,
            "oneshot task must appear in active_count"
        );

        let _ = tx.send(());
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "oneshot task must be removed after completion"
        );

        cancel.cancel();
    }

    /// `Restart { max: 0 }` never restarts; the first panic marks the task `Failed`.
    #[tokio::test]
    async fn test_restart_max_zero() {
        let (sup, _cancel) = make_supervisor();

        let counter = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::clone(&counter);

        sup.spawn(TaskDescriptor {
            name: "zero-max",
            restart: RestartPolicy::Restart {
                max: 0,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&counter2);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(200)).await;

        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "max=0 should not restart"
        );

        let snaps = sup.snapshot();
        let snap = snaps.iter().find(|s| s.name.as_ref() == "zero-max");
        assert!(snap.is_some(), "entry should remain as Failed");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "status should be Failed"
        );
    }

    /// Many short tasks spawned back-to-back all run and are all reaped.
    #[tokio::test]
    async fn test_concurrent_spawns() {
        static NAMES: [&str; 50] = [
            "t00", "t01", "t02", "t03", "t04", "t05", "t06", "t07", "t08", "t09", "t10", "t11",
            "t12", "t13", "t14", "t15", "t16", "t17", "t18", "t19", "t20", "t21", "t22", "t23",
            "t24", "t25", "t26", "t27", "t28", "t29", "t30", "t31", "t32", "t33", "t34", "t35",
            "t36", "t37", "t38", "t39", "t40", "t41", "t42", "t43", "t44", "t45", "t46", "t47",
            "t48", "t49",
        ];
        let (sup, cancel) = make_supervisor();

        let completed = Arc::new(AtomicU32::new(0));
        for name in &NAMES {
            let c = Arc::clone(&completed);
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: move || {
                    let c = Arc::clone(&c);
                    async move {
                        c.fetch_add(1, Ordering::SeqCst);
                    }
                },
            });
        }

        tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                if completed.load(Ordering::SeqCst) == 50 {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("all 50 tasks should complete");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(sup.active_count(), 0, "all tasks must be reaped");

        cancel.cancel();
    }

    /// `shutdown_all` returns after its own timeout and always cancels the token.
    #[tokio::test]
    async fn test_shutdown_timeout_expiry() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        sup.spawn(TaskDescriptor {
            name: "stubborn",
            restart: RestartPolicy::RunOnce,
            factory: || async {
                tokio::time::sleep(Duration::from_mins(1)).await;
            },
        });

        assert_eq!(sup.active_count(), 1);

        tokio::time::timeout(
            Duration::from_secs(2),
            sup.shutdown_all(Duration::from_millis(50)),
        )
        .await
        .expect("shutdown_all should return even on timeout expiry");

        assert!(
            cancel.is_cancelled(),
            "cancel token must be cancelled after shutdown"
        );
    }

    /// `cancellation_token` exposes the token that `shutdown_all` cancels.
    #[tokio::test]
    async fn test_cancellation_token() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        assert!(!sup.cancellation_token().is_cancelled());

        sup.shutdown_all(Duration::from_millis(100)).await;

        assert!(
            sup.cancellation_token().is_cancelled(),
            "token must be cancelled after shutdown"
        );
    }

    /// Tasks that finish only after cancellation are still reaped by the drain phase.
    #[tokio::test]
    async fn test_shutdown_drains_post_cancel_completions() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        for name in [
            "loop-1", "loop-2", "loop-3", "loop-4", "loop-5", "loop-6", "loop-7",
        ] {
            let cancel_inner = cancel.clone();
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: move || {
                    let c = cancel_inner.clone();
                    async move {
                        c.cancelled().await;
                        // Keep running a little after cancellation so completions land in the drain phase.
                        for _ in 0..64 {
                            tokio::task::yield_now().await;
                        }
                    }
                },
            });
        }
        assert_eq!(sup.active_count(), 7);

        sup.shutdown_all(Duration::from_secs(2)).await;

        assert_eq!(
            sup.active_count(),
            0,
            "all tasks must be reaped after shutdown (#3161)"
        );
    }

    /// Work submitted via the `BlockingSpawner` trait is visible in the registry.
    #[tokio::test]
    async fn test_blocking_spawner_task_appears_in_snapshot() {
        use crate::BlockingSpawner;

        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel);

        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<()>();
        let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>();

        let handle = sup.spawn_blocking_named(
            Arc::from("chunk_file"),
            Box::new(move || {
                let _ = ready_tx.send(());
                let _ = release_rx.blocking_recv();
            }),
        );

        ready_rx.await.expect("task should start");

        let snapshot = sup.snapshot();
        assert!(
            snapshot.iter().any(|t| t.name.as_ref() == "chunk_file"),
            "chunk_file task must appear in supervisor snapshot"
        );

        let _ = release_tx.send(());
        handle.await.expect("task should complete");
    }

    /// `measure_blocking` emits both wall-clock and CPU-time histograms.
    #[test]
    fn test_measure_blocking_emits_metrics() {
        use metrics_util::debugging::DebuggingRecorder;

        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();

        metrics::with_local_recorder(&recorder, || {
            measure_blocking("test_task", || std::hint::black_box(42_u64));
        });

        let snapshot = snapshotter.snapshot();
        let metric_names: Vec<String> = snapshot
            .into_vec()
            .into_iter()
            .map(|(k, _, _, _)| k.key().name().to_owned())
            .collect();

        assert!(
            metric_names.iter().any(|n| n == "zeph.task.wall_time_ms"),
            "expected zeph.task.wall_time_ms histogram; got: {metric_names:?}"
        );
        assert!(
            metric_names.iter().any(|n| n == "zeph.task.cpu_time_ms"),
            "expected zeph.task.cpu_time_ms histogram; got: {metric_names:?}"
        );
    }

    /// No more than 8 blocking closures ever run concurrently.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_spawn_blocking_semaphore_cap() {
        let (sup, _cancel) = make_supervisor();
        let concurrent = Arc::new(AtomicU32::new(0));
        let max_concurrent = Arc::new(AtomicU32::new(0));
        let mut handles = Vec::new();

        for i in 0u32..16 {
            let c = Arc::clone(&concurrent);
            let m = Arc::clone(&max_concurrent);
            let name: Arc<str> = Arc::from(format!("blocking-{i}").as_str());
            let h = sup.spawn_blocking(name, move || {
                let prev = c.fetch_add(1, Ordering::SeqCst);
                let mut cur_max = m.load(Ordering::SeqCst);
                // Lock-free "fetch_max": raise the recorded peak if this run exceeds it.
                while prev + 1 > cur_max {
                    match m.compare_exchange(cur_max, prev + 1, Ordering::SeqCst, Ordering::SeqCst)
                    {
                        Ok(_) => break,
                        Err(x) => cur_max = x,
                    }
                }
                std::thread::sleep(std::time::Duration::from_millis(20));
                c.fetch_sub(1, Ordering::SeqCst);
            });
            handles.push(h);
        }

        for h in handles {
            h.join().await.expect("blocking task should succeed");
        }

        let observed = max_concurrent.load(Ordering::SeqCst);
        assert!(
            observed <= 8,
            "observed {observed} concurrent blocking tasks; expected ≤ 8 (semaphore cap)"
        );
    }
}