use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};

use tokio::sync::{mpsc, oneshot};
use tokio::task::AbortHandle;
use tokio_util::sync::CancellationToken;
use tracing::Instrument as _;
use zeph_common::BlockingSpawner;

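/// Restart behavior for a supervised task.
///
/// A sketch of choosing a policy when building a [`TaskDescriptor`]
/// (the `"heartbeat"` name and the empty body are illustrative):
///
/// ```ignore
/// let desc = TaskDescriptor {
///     name: "heartbeat",
///     restart: RestartPolicy::Restart {
///         max: 3,
///         base_delay: Duration::from_millis(100),
///     },
///     factory: || async { /* task body */ },
/// };
/// ```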
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RestartPolicy {
    /// Run to completion once; never restart.
    RunOnce,
    /// Restart after a panic, up to `max` times, with exponential backoff
    /// starting at `base_delay` and capped at [`MAX_RESTART_DELAY`].
    Restart { max: u32, base_delay: Duration },
}

/// Ceiling for the exponential restart backoff delay.
pub const MAX_RESTART_DELAY: Duration = Duration::from_secs(60);

/// How long the reap driver keeps draining completions after cancellation.
const SHUTDOWN_DRAIN_TIMEOUT: Duration = Duration::from_secs(5);

pub struct TaskDescriptor<F> {
    /// Unique name; also the registry key.
    pub name: &'static str,
    /// Restart behavior applied when the task exits.
    pub restart: RestartPolicy,
    /// Factory producing a fresh future for each (re)start.
    pub factory: F,
}

/// Handle for force-aborting a named supervised task.
#[derive(Debug, Clone)]
pub struct TaskHandle {
    name: &'static str,
    abort: AbortHandle,
}

impl TaskHandle {
    /// Forcefully aborts the task.
    pub fn abort(&self) {
        tracing::debug!(task.name = self.name, "task aborted via handle");
        self.abort.abort();
    }

    /// Returns the task's registered name.
    #[must_use]
    pub fn name(&self) -> &'static str {
        self.name
    }
}

/// Failure modes for a supervised blocking or oneshot task.
#[derive(Debug, PartialEq, Eq)]
pub enum BlockingError {
    /// The task panicked.
    Panicked,
    /// The supervisor (or its result channel) was dropped before completion.
    SupervisorDropped,
}

impl std::fmt::Display for BlockingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Panicked => write!(f, "supervised blocking task panicked"),
            Self::SupervisorDropped => write!(f, "supervisor dropped before task completed"),
        }
    }
}

impl std::error::Error for BlockingError {}

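/// Handle to a supervised blocking or oneshot task.
///
/// A minimal sketch of awaiting a result; `expensive_hash` is a
/// hypothetical stand-in for caller code:
///
/// ```ignore
/// let handle = supervisor.spawn_blocking(Arc::from("hash"), || expensive_hash());
/// match handle.join().await {
///     Ok(digest) => tracing::info!(?digest, "done"),
///     Err(e) => tracing::error!(error = %e, "blocking task failed"),
/// }
/// ```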
pub struct BlockingHandle<R> {
    rx: oneshot::Receiver<Result<R, BlockingError>>,
    abort: AbortHandle,
}

impl<R> BlockingHandle<R> {
    /// Waits for the task's result.
    pub async fn join(self) -> Result<R, BlockingError> {
        self.rx
            .await
            .unwrap_or(Err(BlockingError::SupervisorDropped))
    }

    /// Forcefully aborts the task.
    pub fn abort(&self) {
        self.abort.abort();
    }
}

/// Lifecycle status of a supervised task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskStatus {
    Running,
    Restarting { attempt: u32, max: u32 },
    Completed,
    Aborted,
    Failed { reason: String },
}

/// Point-in-time view of one registry entry.
#[derive(Debug, Clone)]
pub struct TaskSnapshot {
    pub name: Arc<str>,
    pub status: TaskStatus,
    pub started_at: Instant,
    pub restart_count: u32,
}

type BoxFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
type BoxFactory = Box<dyn Fn() -> BoxFuture + Send + Sync>;

struct TaskEntry {
    name: Arc<str>,
    status: TaskStatus,
    started_at: Instant,
    restart_count: u32,
    restart_policy: RestartPolicy,
    abort_handle: AbortHandle,
    /// Kept only for restartable tasks; `None` under `RunOnce`.
    factory: Option<BoxFactory>,
}

/// How a supervised task finished.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CompletionKind {
    Normal,
    Panicked,
    Cancelled,
}

struct Completion {
    name: Arc<str>,
    kind: CompletionKind,
}

struct SupervisorState {
    tasks: HashMap<Arc<str>, TaskEntry>,
}

struct Inner {
    state: parking_lot::Mutex<SupervisorState>,
    completion_tx: mpsc::UnboundedSender<Completion>,
    cancel: CancellationToken,
    /// Caps concurrent supervised blocking tasks (see `spawn_blocking`).
    blocking_semaphore: Arc<tokio::sync::Semaphore>,
}

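/// Supervises spawned tasks: tracks them in a registry, restarts panicked
/// tasks according to their [`RestartPolicy`], and reaps completions.
///
/// A usage sketch (the `"poller"` task is illustrative):
///
/// ```ignore
/// let cancel = CancellationToken::new();
/// let supervisor = TaskSupervisor::new(cancel.clone());
///
/// supervisor.spawn(TaskDescriptor {
///     name: "poller",
///     restart: RestartPolicy::Restart {
///         max: 5,
///         base_delay: Duration::from_secs(1),
///     },
///     factory: || async { /* poll loop */ },
/// });
///
/// // Later, during shutdown:
/// supervisor.shutdown_all(Duration::from_secs(5)).await;
/// ```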
#[derive(Clone)]
pub struct TaskSupervisor {
    inner: Arc<Inner>,
}

impl TaskSupervisor {
    /// Creates a supervisor tied to `cancel` and starts its reap driver.
    #[must_use]
    pub fn new(cancel: CancellationToken) -> Self {
        let (completion_tx, completion_rx) = mpsc::unbounded_channel();
        let inner = Arc::new(Inner {
            state: parking_lot::Mutex::new(SupervisorState {
                tasks: HashMap::new(),
            }),
            completion_tx,
            cancel: cancel.clone(),
            // At most 8 supervised blocking tasks run concurrently.
            blocking_semaphore: Arc::new(tokio::sync::Semaphore::new(8)),
        });

        Self::start_reap_driver(Arc::clone(&inner), completion_rx, cancel);

        Self { inner }
    }

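    /// Spawns a named supervised task described by `desc`.
    ///
    /// Spawning under an already-registered name aborts and replaces the
    /// old task. A sketch with a hypothetical `run_sync` future:
    ///
    /// ```ignore
    /// let handle = supervisor.spawn(TaskDescriptor {
    ///     name: "sync",
    ///     restart: RestartPolicy::RunOnce,
    ///     factory: || run_sync(),
    /// });
    /// // Forceful stop; cooperative shutdown goes through the cancel token.
    /// handle.abort();
    /// ```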
    pub fn spawn<F, Fut>(&self, desc: TaskDescriptor<F>) -> TaskHandle
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let factory: BoxFactory = Box::new(move || Box::pin((desc.factory)()));
        let cancel = self.inner.cancel.clone();
        let completion_tx = self.inner.completion_tx.clone();
        let name: Arc<str> = Arc::from(desc.name);

        let (abort_handle, jh) = Self::do_spawn(desc.name, &factory, cancel);
        Self::wire_completion_reporter(Arc::clone(&name), jh, completion_tx);

        let entry = TaskEntry {
            name: Arc::clone(&name),
            status: TaskStatus::Running,
            started_at: Instant::now(),
            restart_count: 0,
            restart_policy: desc.restart,
            abort_handle: abort_handle.clone(),
            factory: match desc.restart {
                RestartPolicy::RunOnce => None,
                RestartPolicy::Restart { .. } => Some(factory),
            },
        };

        {
            let mut state = self.inner.state.lock();
            // Spawning under an existing name replaces (and aborts) the old task.
            if let Some(old) = state.tasks.remove(&name) {
                old.abort_handle.abort();
            }
            state.tasks.insert(Arc::clone(&name), entry);
        }

        TaskHandle {
            name: desc.name,
            abort: abort_handle,
        }
    }

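    /// Runs `f` on the blocking thread pool under supervision, with
    /// concurrency capped by the internal semaphore (8 permits).
    ///
    /// An illustrative call; `expensive_checksum` is assumed caller code:
    ///
    /// ```ignore
    /// let handle = supervisor.spawn_blocking(Arc::from("checksum"), || {
    ///     expensive_checksum("data.bin")
    /// });
    /// let sum = handle.join().await?;
    /// ```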
    #[allow(clippy::needless_pass_by_value)]
    pub fn spawn_blocking<F, R>(&self, name: Arc<str>, f: F) -> BlockingHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
        #[cfg(feature = "task-metrics")]
        let span = tracing::info_span!(
            "supervised_blocking_task",
            task.name = %name,
            task.wall_time_ms = tracing::field::Empty,
            task.cpu_time_ms = tracing::field::Empty,
        );
        #[cfg(not(feature = "task-metrics"))]
        let span = tracing::info_span!("supervised_blocking_task", task.name = %name);

        let semaphore = Arc::clone(&self.inner.blocking_semaphore);
        let inner = Arc::clone(&self.inner);
        let name_clone = Arc::clone(&name);
        let completion_tx = self.inner.completion_tx.clone();

        let outer = tokio::spawn(async move {
            // Wait for a permit before hitting the blocking pool.
            let _permit = semaphore
                .acquire_owned()
                .await
                .expect("blocking semaphore closed");

            let name_for_measure = Arc::clone(&name_clone);
            let join_handle = tokio::task::spawn_blocking(move || {
                let _enter = span.enter();
                measure_blocking(&name_for_measure, f)
            });
            let abort = join_handle.abort_handle();

            {
                // Point the registry at the real blocking task, not the wrapper.
                let mut state = inner.state.lock();
                if let Some(entry) = state.tasks.get_mut(&name_clone) {
                    entry.abort_handle = abort;
                }
            }

            let kind = match join_handle.await {
                Ok(val) => {
                    let _ = tx.send(Ok(val));
                    CompletionKind::Normal
                }
                Err(e) if e.is_panic() => {
                    let _ = tx.send(Err(BlockingError::Panicked));
                    CompletionKind::Panicked
                }
                Err(_) => {
                    // Aborted: dropping `tx` surfaces `SupervisorDropped` to `join()`.
                    CompletionKind::Cancelled
                }
            };
            let _ = completion_tx.send(Completion {
                name: name_clone,
                kind,
            });
        });
        let abort = outer.abort_handle();

        {
            let mut state = self.inner.state.lock();
            if let Some(old) = state.tasks.remove(&name) {
                old.abort_handle.abort();
            }
            state.tasks.insert(
                Arc::clone(&name),
                TaskEntry {
                    name: Arc::clone(&name),
                    status: TaskStatus::Running,
                    started_at: Instant::now(),
                    restart_count: 0,
                    restart_policy: RestartPolicy::RunOnce,
                    abort_handle: abort.clone(),
                    factory: None,
                },
            );
        }

        BlockingHandle { rx, abort }
    }

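    /// Spawns a one-shot async task whose single result is delivered via
    /// the returned [`BlockingHandle`]. The future is raced against the
    /// supervisor's cancellation token.
    ///
    /// A sketch, assuming some `fetch_config` future:
    ///
    /// ```ignore
    /// let handle = supervisor.spawn_oneshot(Arc::from("config"), || fetch_config());
    /// let config = handle.join().await?;
    /// ```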
    pub fn spawn_oneshot<F, Fut, R>(&self, name: Arc<str>, factory: F) -> BlockingHandle<R>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: Send + 'static,
    {
        let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
        let cancel = self.inner.cancel.clone();
        let span = tracing::info_span!("supervised_task", task.name = %name);
        let join_handle: tokio::task::JoinHandle<Option<R>> = tokio::spawn(
            async move {
                let fut = factory();
                tokio::select! {
                    result = fut => Some(result),
                    () = cancel.cancelled() => None,
                }
            }
            .instrument(span),
        );
        let abort = join_handle.abort_handle();

        {
            let mut state = self.inner.state.lock();
            if let Some(old) = state.tasks.remove(&name) {
                old.abort_handle.abort();
            }
            state.tasks.insert(
                Arc::clone(&name),
                TaskEntry {
                    name: Arc::clone(&name),
                    status: TaskStatus::Running,
                    started_at: Instant::now(),
                    restart_count: 0,
                    restart_policy: RestartPolicy::RunOnce,
                    abort_handle: abort.clone(),
                    factory: None,
                },
            );
        }

        let completion_tx = self.inner.completion_tx.clone();
        tokio::spawn(async move {
            let kind = match join_handle.await {
                Ok(Some(val)) => {
                    let _ = tx.send(Ok(val));
                    CompletionKind::Normal
                }
                Err(e) if e.is_panic() => {
                    let _ = tx.send(Err(BlockingError::Panicked));
                    CompletionKind::Panicked
                }
                _ => CompletionKind::Cancelled,
            };
            let _ = completion_tx.send(Completion { name, kind });
        });
        BlockingHandle { rx, abort }
    }

    /// Forcefully aborts the named task, if registered.
    pub fn abort(&self, name: &'static str) {
        let state = self.inner.state.lock();
        let key: Arc<str> = Arc::from(name);
        if let Some(entry) = state.tasks.get(&key) {
            entry.abort_handle.abort();
            tracing::debug!(task.name = name, "task aborted via supervisor");
        }
    }

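    /// Cancels the shared token, waits up to `timeout` for active tasks to
    /// drain, then force-aborts whatever is still running.
    ///
    /// Typical shutdown sequence (timeout value is illustrative):
    ///
    /// ```ignore
    /// supervisor.shutdown_all(Duration::from_secs(10)).await;
    /// assert_eq!(supervisor.active_count(), 0);
    /// ```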
    pub async fn shutdown_all(&self, timeout: Duration) {
        self.inner.cancel.cancel();
        let deadline = tokio::time::Instant::now() + timeout;
        loop {
            let active = self.active_count();
            if active == 0 {
                break;
            }
            if tokio::time::Instant::now() >= deadline {
                let mut remaining_names: Vec<Arc<str>> = Vec::new();
                {
                    let mut state = self.inner.state.lock();
                    for entry in state.tasks.values_mut() {
                        if matches!(
                            entry.status,
                            TaskStatus::Running | TaskStatus::Restarting { .. }
                        ) {
                            remaining_names.push(Arc::clone(&entry.name));
                            entry.abort_handle.abort();
                            entry.status = TaskStatus::Aborted;
                        }
                    }
                }
                tracing::warn!(
                    remaining = active,
                    tasks = ?remaining_names,
                    "shutdown timeout; aborting remaining tasks"
                );
                break;
            }
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
    }

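    /// Returns a point-in-time view of the registry, sorted by start time.
    ///
    /// Useful for health endpoints or debug dumps, e.g.:
    ///
    /// ```ignore
    /// for task in supervisor.snapshot() {
    ///     println!("{}: {:?} ({} restarts)", task.name, task.status, task.restart_count);
    /// }
    /// ```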
    #[must_use]
    pub fn snapshot(&self) -> Vec<TaskSnapshot> {
        let state = self.inner.state.lock();
        let mut snaps: Vec<TaskSnapshot> = state
            .tasks
            .values()
            .map(|e| TaskSnapshot {
                name: Arc::clone(&e.name),
                status: e.status.clone(),
                started_at: e.started_at,
                restart_count: e.restart_count,
            })
            .collect();
        snaps.sort_by_key(|s| s.started_at);
        snaps
    }

    /// Number of tasks currently `Running` or `Restarting`.
    #[must_use]
    pub fn active_count(&self) -> usize {
        let state = self.inner.state.lock();
        state
            .tasks
            .values()
            .filter(|e| {
                matches!(
                    e.status,
                    TaskStatus::Running | TaskStatus::Restarting { .. }
                )
            })
            .count()
    }

    /// Clone of the supervisor's shared cancellation token.
    #[must_use]
    pub fn cancellation_token(&self) -> CancellationToken {
        self.inner.cancel.clone()
    }

    fn do_spawn(
        name: &'static str,
        factory: &BoxFactory,
        cancel: CancellationToken,
    ) -> (AbortHandle, tokio::task::JoinHandle<()>) {
        let fut = factory();
        let span = tracing::info_span!("supervised_task", task.name = name);
        let jh = tokio::spawn(
            async move {
                tokio::select! {
                    () = fut => {},
                    () = cancel.cancelled() => {},
                }
            }
            .instrument(span),
        );
        let abort = jh.abort_handle();
        (abort, jh)
    }

    /// Forwards the task's exit kind to the reap driver once it finishes.
    fn wire_completion_reporter(
        name: Arc<str>,
        jh: tokio::task::JoinHandle<()>,
        completion_tx: mpsc::UnboundedSender<Completion>,
    ) {
        tokio::spawn(async move {
            let kind = match jh.await {
                Ok(()) => CompletionKind::Normal,
                Err(e) if e.is_panic() => CompletionKind::Panicked,
                Err(_) => CompletionKind::Cancelled,
            };
            let _ = completion_tx.send(Completion { name, kind });
        });
    }

    /// Drives restarts and registry cleanup from task completions; after
    /// cancellation it keeps draining for up to `SHUTDOWN_DRAIN_TIMEOUT`.
    fn start_reap_driver(
        inner: Arc<Inner>,
        mut completion_rx: mpsc::UnboundedReceiver<Completion>,
        cancel: CancellationToken,
    ) {
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    biased;
                    Some(completion) = completion_rx.recv() => {
                        Self::handle_completion(&inner, completion).await;
                    }
                    () = cancel.cancelled() => break,
                }
            }

            // Post-cancel drain: keep reaping completions so entries do not
            // linger as Running after shutdown.
            let drain_deadline = tokio::time::Instant::now() + SHUTDOWN_DRAIN_TIMEOUT;
            let active = Self::has_active_tasks(&inner);
            tracing::debug!(active, "reap driver entered post-cancel drain phase");
            loop {
                if !Self::has_active_tasks(&inner) {
                    break;
                }
                let remaining =
                    drain_deadline.saturating_duration_since(tokio::time::Instant::now());
                if remaining.is_zero() {
                    break;
                }
                match tokio::time::timeout(remaining, completion_rx.recv()).await {
                    Ok(Some(completion)) => Self::handle_completion(&inner, completion).await,
                    Ok(None) | Err(_) => break,
                }
            }
            tracing::debug!(
                active = Self::has_active_tasks(&inner),
                "reap driver drain phase complete"
            );
        });
    }

    fn has_active_tasks(inner: &Arc<Inner>) -> bool {
        let state = inner.state.lock();
        state.tasks.values().any(|e| {
            matches!(
                e.status,
                TaskStatus::Running | TaskStatus::Restarting { .. }
            )
        })
    }

    async fn handle_completion(inner: &Arc<Inner>, completion: Completion) {
        // During shutdown nothing restarts; just drop the entry.
        if inner.cancel.is_cancelled() {
            let mut state = inner.state.lock();
            state.tasks.remove(&completion.name);
            return;
        }

        let Some((attempt, max, delay)) = Self::classify_completion(inner, &completion) else {
            return;
        };

        tracing::warn!(
            task.name = %completion.name,
            attempt,
            max,
            delay_ms = delay.as_millis(),
            "restarting supervised task"
        );

        if !delay.is_zero() {
            tokio::time::sleep(delay).await;
        }

        Self::do_restart(inner, &completion.name, attempt);
    }

    /// Decides whether the completed task should restart. Returns
    /// `Some((attempt, max, delay))` for a restart, `None` otherwise.
    fn classify_completion(
        inner: &Arc<Inner>,
        completion: &Completion,
    ) -> Option<(u32, u32, Duration)> {
        let mut state = inner.state.lock();
        let entry = state.tasks.get_mut(&completion.name)?;

        match completion.kind {
            CompletionKind::Panicked => {
                tracing::warn!(task.name = %completion.name, "supervised task panicked");
            }
            CompletionKind::Normal => {
                tracing::info!(task.name = %completion.name, "supervised task completed");
            }
            CompletionKind::Cancelled => {
                tracing::debug!(task.name = %completion.name, "supervised task cancelled");
            }
        }

        match entry.restart_policy {
            RestartPolicy::RunOnce => {
                entry.status = TaskStatus::Completed;
                state.tasks.remove(&completion.name);
                None
            }
            RestartPolicy::Restart { max, base_delay } => {
                // Only a panic triggers a restart; normal exit and
                // cancellation both retire the entry.
                if completion.kind != CompletionKind::Panicked {
                    entry.status = TaskStatus::Completed;
                    state.tasks.remove(&completion.name);
                    return None;
                }
                if entry.restart_count >= max {
                    let reason = format!("panicked after {max} restart(s)");
                    tracing::error!(
                        task.name = %completion.name,
                        attempts = max,
                        "task failed permanently"
                    );
                    entry.status = TaskStatus::Failed { reason };
                    None
                } else {
                    let attempt = entry.restart_count + 1;
                    entry.status = TaskStatus::Restarting { attempt, max };
                    // delay = base_delay * 2^(attempt - 1), saturating, and
                    // capped at MAX_RESTART_DELAY.
                    let multiplier = 1_u32
                        .checked_shl(attempt.saturating_sub(1))
                        .unwrap_or(u32::MAX);
                    let delay = base_delay.saturating_mul(multiplier).min(MAX_RESTART_DELAY);
                    Some((attempt, max, delay))
                }
            }
        }
    }

    fn do_restart(inner: &Arc<Inner>, name: &Arc<str>, attempt: u32) {
        let spawn_params = {
            let mut state = inner.state.lock();
            let Some(entry) = state.tasks.get_mut(name.as_ref()) else {
                tracing::debug!(
                    task.name = %name,
                    "task removed during restart delay; skipping"
                );
                return;
            };
            if !matches!(entry.status, TaskStatus::Restarting { .. }) {
                return;
            }
            let Some(factory) = &entry.factory else {
                return;
            };
            // The factory is user code: guard against it panicking here.
            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(factory)) {
                Err(_) => {
                    let reason = format!("factory panicked on restart attempt {attempt}");
                    tracing::error!(task.name = %name, attempt, "factory panicked during restart");
                    entry.status = TaskStatus::Failed { reason };
                    None
                }
                Ok(fut) => Some((
                    fut,
                    inner.cancel.clone(),
                    inner.completion_tx.clone(),
                    name.clone(),
                )),
            }
        };

        let Some((fut, cancel, completion_tx, name)) = spawn_params else {
            return;
        };

        let span = tracing::info_span!("supervised_task", task.name = %name);
        let jh = tokio::spawn(
            async move {
                tokio::select! {
                    () = fut => {},
                    () = cancel.cancelled() => {},
                }
            }
            .instrument(span),
        );
        let new_abort = jh.abort_handle();

        {
            let mut state = inner.state.lock();
            if let Some(entry) = state.tasks.get_mut(name.as_ref()) {
                entry.restart_count = attempt;
                entry.status = TaskStatus::Running;
                entry.abort_handle = new_abort;
            }
        }

        Self::wire_completion_reporter(name.clone(), jh, completion_tx);
    }
}

/// Runs `f`, recording wall-clock and per-thread CPU time as histograms
/// and on the current tracing span.
#[cfg(feature = "task-metrics")]
#[inline]
fn measure_blocking<F, R>(name: &str, f: F) -> R
where
    F: FnOnce() -> R,
{
    use cpu_time::ThreadTime;
    let wall_start = std::time::Instant::now();
    let cpu_start = ThreadTime::now();
    let result = f();
    let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
    let cpu_ms = cpu_start.elapsed().as_secs_f64() * 1000.0;
    metrics::histogram!("zeph.task.wall_time_ms", "task" => name.to_owned()).record(wall_ms);
    metrics::histogram!("zeph.task.cpu_time_ms", "task" => name.to_owned()).record(cpu_ms);
    tracing::Span::current().record("task.wall_time_ms", wall_ms);
    tracing::Span::current().record("task.cpu_time_ms", cpu_ms);
    result
}

#[cfg(not(feature = "task-metrics"))]
#[inline]
fn measure_blocking<F, R>(_name: &str, f: F) -> R
where
    F: FnOnce() -> R,
{
    f()
}

impl BlockingSpawner for TaskSupervisor {
    fn spawn_blocking_named(
        &self,
        name: Arc<str>,
        f: Box<dyn FnOnce() + Send + 'static>,
    ) -> tokio::task::JoinHandle<()> {
        let handle = self.spawn_blocking(Arc::clone(&name), f);
        tokio::spawn(async move {
            if let Err(e) = handle.join().await {
                tracing::error!(task.name = %name, error = %e, "supervised blocking task failed");
            }
        })
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    use tokio_util::sync::CancellationToken;

    use super::*;

    fn make_supervisor() -> (TaskSupervisor, CancellationToken) {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());
        (sup, cancel)
    }

    #[tokio::test]
    async fn test_spawn_and_complete() {
        let (sup, _cancel) = make_supervisor();

        let done = Arc::new(tokio::sync::Notify::new());
        let done2 = Arc::clone(&done);

        sup.spawn(TaskDescriptor {
            name: "simple",
            restart: RestartPolicy::RunOnce,
            factory: move || {
                let d = Arc::clone(&done2);
                async move {
                    d.notify_one();
                }
            },
        });

        tokio::time::timeout(Duration::from_secs(2), done.notified())
            .await
            .expect("task should complete");

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "RunOnce task should be removed after completion"
        );
    }

    #[tokio::test]
    async fn test_panic_capture() {
        let (sup, _cancel) = make_supervisor();

        sup.spawn(TaskDescriptor {
            name: "panicking",
            restart: RestartPolicy::RunOnce,
            factory: || async { panic!("intentional test panic") },
        });

        tokio::time::sleep(Duration::from_millis(200)).await;

        let snaps = sup.snapshot();
        assert!(
            snaps.iter().all(|s| s.name.as_ref() != "panicking"),
            "entry should be reaped"
        );
        assert_eq!(
            sup.active_count(),
            0,
            "active count must be 0 after RunOnce panic"
        );
    }

    #[tokio::test]
    async fn test_restart_only_on_panic() {
        let (sup, _cancel) = make_supervisor();

        let normal_counter = Arc::new(AtomicU32::new(0));
        let nc = Arc::clone(&normal_counter);
        sup.spawn(TaskDescriptor {
            name: "normal-exit",
            restart: RestartPolicy::Restart {
                max: 3,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&nc);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(300)).await;
        assert_eq!(
            normal_counter.load(Ordering::SeqCst),
            1,
            "normal exit must not restart"
        );
        assert!(
            sup.snapshot()
                .iter()
                .all(|s| s.name.as_ref() != "normal-exit"),
            "entry removed after normal exit"
        );

        let panic_counter = Arc::new(AtomicU32::new(0));
        let pc = Arc::clone(&panic_counter);
        sup.spawn(TaskDescriptor {
            name: "panic-exit",
            restart: RestartPolicy::Restart {
                max: 2,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&pc);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("test panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        assert!(
            panic_counter.load(Ordering::SeqCst) >= 3,
            "panicking task must restart max times"
        );
        let snap = sup
            .snapshot()
            .into_iter()
            .find(|s| s.name.as_ref() == "panic-exit");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "task must be Failed after exhausting restarts"
        );
    }

    #[tokio::test]
    async fn test_restart_policy() {
        let (sup, _cancel) = make_supervisor();

        let counter = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::clone(&counter);

        sup.spawn(TaskDescriptor {
            name: "restartable",
            restart: RestartPolicy::Restart {
                max: 2,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&counter2);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(500)).await;

        let runs = counter.load(Ordering::SeqCst);
        assert!(
            runs >= 3,
            "expected at least 3 invocations (initial + 2 restarts), got {runs}"
        );

        let snaps = sup.snapshot();
        let snap = snaps.iter().find(|s| s.name.as_ref() == "restartable");
        assert!(snap.is_some(), "failed task should remain in registry");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "task should be Failed after exhausting retries"
        );
    }

    #[tokio::test]
    async fn test_exponential_backoff() {
        let (sup, _cancel) = make_supervisor();

        let timestamps = Arc::new(parking_lot::Mutex::new(Vec::<std::time::Instant>::new()));
        let ts = Arc::clone(&timestamps);

        sup.spawn(TaskDescriptor {
            name: "backoff-task",
            restart: RestartPolicy::Restart {
                max: 3,
                base_delay: Duration::from_millis(50),
            },
            factory: move || {
                let t = Arc::clone(&ts);
                async move {
                    t.lock().push(std::time::Instant::now());
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(800)).await;

        let ts = timestamps.lock();
        assert!(
            ts.len() >= 3,
            "expected at least 3 invocations, got {}",
            ts.len()
        );

        if ts.len() >= 3 {
            let d1 = ts[1].duration_since(ts[0]);
            let d2 = ts[2].duration_since(ts[1]);
            assert!(
                d2 >= d1.mul_f64(1.5),
                "expected exponential backoff: d1={d1:?} d2={d2:?}"
            );
        }
    }

    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (sup, _cancel) = make_supervisor();

        for name in ["svc-a", "svc-b", "svc-c"] {
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: || async {
                    tokio::time::sleep(Duration::from_secs(60)).await;
                },
            });
        }

        assert_eq!(sup.active_count(), 3);

        tokio::time::timeout(
            Duration::from_secs(2),
            sup.shutdown_all(Duration::from_secs(1)),
        )
        .await
        .expect("shutdown should complete within timeout");
    }

    #[tokio::test]
    async fn test_force_abort_marks_aborted() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        sup.spawn(TaskDescriptor {
            name: "stubborn-for-abort",
            restart: RestartPolicy::RunOnce,
            factory: || async {
                std::future::pending::<()>().await;
            },
        });

        sup.shutdown_all(Duration::from_millis(1)).await;

        let snaps = sup.snapshot();
        if let Some(snap) = snaps
            .iter()
            .find(|s| s.name.as_ref() == "stubborn-for-abort")
        {
            assert_eq!(
                snap.status,
                TaskStatus::Aborted,
                "force-aborted task must have Aborted status"
            );
        }
    }

    #[tokio::test]
    async fn test_registry_snapshot() {
        let (sup, _cancel) = make_supervisor();

        for name in ["alpha", "beta"] {
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: || async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                },
            });
        }

        let snaps = sup.snapshot();
        assert_eq!(snaps.len(), 2);
        let names: Vec<&str> = snaps.iter().map(|s| s.name.as_ref()).collect();
        assert!(names.contains(&"alpha"));
        assert!(names.contains(&"beta"));
        assert!(snaps.iter().all(|s| s.status == TaskStatus::Running));
    }

    #[tokio::test]
    async fn test_blocking_returns_value() {
        let (sup, cancel) = make_supervisor();

        let handle: BlockingHandle<u32> = sup.spawn_blocking(Arc::from("compute"), || 42_u32);
        let result = handle.join().await.expect("should return value");
        assert_eq!(result, 42);
        cancel.cancel();
    }

    #[tokio::test]
    async fn test_blocking_panic() {
        let (sup, _cancel) = make_supervisor();

        let handle: BlockingHandle<u32> =
            sup.spawn_blocking(Arc::from("panicking-compute"), || panic!("intentional"));
        let err = handle
            .join()
            .await
            .expect_err("should return error on panic");
        assert_eq!(err, BlockingError::Panicked);
    }

    #[tokio::test]
    async fn test_blocking_registered_in_registry() {
        let (sup, cancel) = make_supervisor();

        let (tx, rx) = std::sync::mpsc::channel::<()>();
        let _handle: BlockingHandle<()> =
            sup.spawn_blocking(Arc::from("blocking-task"), move || {
                let _ = rx.recv();
            });

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            sup.active_count(),
            1,
            "blocking task must appear in active_count"
        );

        let _ = tx.send(());
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "blocking task must be removed after completion"
        );

        cancel.cancel();
    }

    #[tokio::test]
    async fn test_oneshot_registered_in_registry() {
        let (sup, cancel) = make_supervisor();

        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        let _handle: BlockingHandle<()> =
            sup.spawn_oneshot(Arc::from("oneshot-task"), move || async move {
                let _ = rx.await;
            });

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            sup.active_count(),
            1,
            "oneshot task must appear in active_count"
        );

        let _ = tx.send(());
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            sup.active_count(),
            0,
            "oneshot task must be removed after completion"
        );

        cancel.cancel();
    }

    #[tokio::test]
    async fn test_restart_max_zero() {
        let (sup, _cancel) = make_supervisor();

        let counter = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::clone(&counter);

        sup.spawn(TaskDescriptor {
            name: "zero-max",
            restart: RestartPolicy::Restart {
                max: 0,
                base_delay: Duration::from_millis(10),
            },
            factory: move || {
                let c = Arc::clone(&counter2);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    panic!("always panic");
                }
            },
        });

        tokio::time::sleep(Duration::from_millis(200)).await;

        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "max=0 should not restart"
        );

        let snaps = sup.snapshot();
        let snap = snaps.iter().find(|s| s.name.as_ref() == "zero-max");
        assert!(snap.is_some(), "entry should remain as Failed");
        assert!(
            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
            "status should be Failed"
        );
    }

    #[tokio::test]
    async fn test_concurrent_spawns() {
        static NAMES: [&str; 50] = [
            "t00", "t01", "t02", "t03", "t04", "t05", "t06", "t07", "t08", "t09", "t10", "t11",
            "t12", "t13", "t14", "t15", "t16", "t17", "t18", "t19", "t20", "t21", "t22", "t23",
            "t24", "t25", "t26", "t27", "t28", "t29", "t30", "t31", "t32", "t33", "t34", "t35",
            "t36", "t37", "t38", "t39", "t40", "t41", "t42", "t43", "t44", "t45", "t46", "t47",
            "t48", "t49",
        ];
        let (sup, cancel) = make_supervisor();

        let completed = Arc::new(AtomicU32::new(0));
        for name in &NAMES {
            let c = Arc::clone(&completed);
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: move || {
                    let c = Arc::clone(&c);
                    async move {
                        c.fetch_add(1, Ordering::SeqCst);
                    }
                },
            });
        }

        tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                if completed.load(Ordering::SeqCst) == 50 {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("all 50 tasks should complete");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(sup.active_count(), 0, "all tasks must be reaped");

        cancel.cancel();
    }

    #[tokio::test]
    async fn test_shutdown_timeout_expiry() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        sup.spawn(TaskDescriptor {
            name: "stubborn",
            restart: RestartPolicy::RunOnce,
            factory: || async {
                tokio::time::sleep(Duration::from_secs(60)).await;
            },
        });

        assert_eq!(sup.active_count(), 1);

        tokio::time::timeout(
            Duration::from_secs(2),
            sup.shutdown_all(Duration::from_millis(50)),
        )
        .await
        .expect("shutdown_all should return even on timeout expiry");

        assert!(
            cancel.is_cancelled(),
            "cancel token must be cancelled after shutdown"
        );
    }

    #[tokio::test]
    async fn test_cancellation_token() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        assert!(!sup.cancellation_token().is_cancelled());

        sup.shutdown_all(Duration::from_millis(100)).await;

        assert!(
            sup.cancellation_token().is_cancelled(),
            "token must be cancelled after shutdown"
        );
    }

    #[tokio::test]
    async fn test_shutdown_drains_post_cancel_completions() {
        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel.clone());

        for name in [
            "loop-1", "loop-2", "loop-3", "loop-4", "loop-5", "loop-6", "loop-7",
        ] {
            let cancel_inner = cancel.clone();
            sup.spawn(TaskDescriptor {
                name,
                restart: RestartPolicy::RunOnce,
                factory: move || {
                    let c = cancel_inner.clone();
                    async move {
                        c.cancelled().await;
                        for _ in 0..64 {
                            tokio::task::yield_now().await;
                        }
                    }
                },
            });
        }
        assert_eq!(sup.active_count(), 7);

        sup.shutdown_all(Duration::from_secs(2)).await;

        assert_eq!(
            sup.active_count(),
            0,
            "all tasks must be reaped after shutdown (#3161)"
        );
    }

    #[tokio::test]
    async fn test_blocking_spawner_task_appears_in_snapshot() {
        use zeph_common::BlockingSpawner;

        let cancel = CancellationToken::new();
        let sup = TaskSupervisor::new(cancel);

        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<()>();
        let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>();

        let handle = sup.spawn_blocking_named(
            Arc::from("chunk_file"),
            Box::new(move || {
                let _ = ready_tx.send(());
                let _ = release_rx.blocking_recv();
            }),
        );

        ready_rx.await.expect("task should start");

        let snapshot = sup.snapshot();
        assert!(
            snapshot.iter().any(|t| t.name.as_ref() == "chunk_file"),
            "chunk_file task must appear in supervisor snapshot"
        );

        let _ = release_tx.send(());
        handle.await.expect("task should complete");
    }

    #[cfg(feature = "task-metrics")]
    #[test]
    fn test_measure_blocking_emits_metrics() {
        use metrics_util::debugging::DebuggingRecorder;

        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();

        metrics::with_local_recorder(&recorder, || {
            measure_blocking("test_task", || std::hint::black_box(42_u64));
        });

        let snapshot = snapshotter.snapshot();
        let metric_names: Vec<String> = snapshot
            .into_vec()
            .into_iter()
            .map(|(k, _, _, _)| k.key().name().to_owned())
            .collect();

        assert!(
            metric_names.iter().any(|n| n == "zeph.task.wall_time_ms"),
            "expected zeph.task.wall_time_ms histogram; got: {metric_names:?}"
        );
        assert!(
            metric_names.iter().any(|n| n == "zeph.task.cpu_time_ms"),
            "expected zeph.task.cpu_time_ms histogram; got: {metric_names:?}"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_spawn_blocking_semaphore_cap() {
        let (sup, _cancel) = make_supervisor();
        let concurrent = Arc::new(AtomicU32::new(0));
        let max_concurrent = Arc::new(AtomicU32::new(0));

        let mut handles = Vec::new();
        for i in 0u32..16 {
            let c = Arc::clone(&concurrent);
            let m = Arc::clone(&max_concurrent);
            let name: Arc<str> = Arc::from(format!("blocking-{i}").as_str());
            let h = sup.spawn_blocking(name, move || {
                let now = c.fetch_add(1, Ordering::SeqCst) + 1;
                // Track the high-water mark of concurrent executions.
                m.fetch_max(now, Ordering::SeqCst);
                std::thread::sleep(std::time::Duration::from_millis(20));
                c.fetch_sub(1, Ordering::SeqCst);
            });
            handles.push(h);
        }

        for h in handles {
            h.join().await.expect("blocking task should succeed");
        }

        let observed = max_concurrent.load(Ordering::SeqCst);
        assert!(
            observed <= 8,
            "observed {observed} concurrent blocking tasks; expected ≤ 8 (semaphore cap)"
        );
    }
}