Skip to main content

zeph_common/
task_supervisor.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Supervised lifecycle task manager for long-running named services.
5//!
6//! [`TaskSupervisor`] manages named, long-lived background tasks (config watcher,
7//! scheduler loop, gateway, MCP connections, etc.) with restart policies, health
8//! snapshots, and graceful shutdown. Unlike `BackgroundSupervisor`
9//! (which is `&mut self`-only, lossy, and turn-scoped), `TaskSupervisor` is
10//! `Clone + Send + Sync` and designed for the full agent session lifetime.
11//!
12//! # Design rationale
13//!
14//! - **Shared handle**: `Arc<Inner>` interior allows passing the supervisor to bootstrap
15//!   code, TUI status display, and shutdown orchestration without lifetime coupling.
16//! - **Event-driven reap**: An internal mpsc channel delivers completion events to a
17//!   reap driver task; no polling interval required.
18//! - **No `JoinSet`**: Individual `JoinHandle`s per task enable per-name abort, status
19//!   tracking, and restart policies — `JoinSet` is better for homogeneous work.
20//! - **Mutex held briefly**: `parking_lot::Mutex` guards only bookkeeping operations
21//!   (insert/remove from `HashMap`). The lock is **never held across `.await`**.
22//!
23//! # Examples
24//!
25//! ```rust,no_run
26//! use std::time::Duration;
27//! use tokio_util::sync::CancellationToken;
28//! use zeph_common::task_supervisor::{RestartPolicy, TaskDescriptor, TaskSupervisor};
29//!
30//! # #[tokio::main]
31//! # async fn main() {
32//! let cancel = CancellationToken::new();
33//! let supervisor = TaskSupervisor::new(cancel.clone());
34//!
35//! supervisor.spawn(TaskDescriptor {
36//!     name: "my-service",
37//!     restart: RestartPolicy::Restart { max: 3, base_delay: Duration::from_secs(1) },
38//!     factory: || async { /* service loop */ },
39//! });
40//!
41//! // Graceful shutdown — waits up to 5 s for all tasks to stop.
42//! supervisor.shutdown_all(Duration::from_secs(5)).await;
43//! # }
44//! ```
45
46use std::collections::HashMap;
47use std::future::Future;
48use std::pin::Pin;
49use std::sync::Arc;
50use std::time::{Duration, Instant};
51
52use tokio::sync::{mpsc, oneshot};
53use tokio::task::AbortHandle;
54use tokio_util::sync::CancellationToken;
55use tracing::Instrument as _;
56
57use crate::BlockingSpawner;
58
59// ── Public types ─────────────────────────────────────────────────────────────
60
/// Decides whether a supervised task is started again after it ends.
///
/// Passed to the supervisor through [`TaskDescriptor::restart`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RestartPolicy {
    /// Start the task a single time; a normal exit drops it from the registry.
    RunOnce,
    /// Start the task again **only after a panic**, at most `max` times.
    ///
    /// A normal return (the future resolves to `()`) never triggers a restart;
    /// the entry is dropped from the registry instead.
    ///
    /// Setting `max` to `0` keeps the task under observation without ever
    /// restarting it: after a panic the entry stays in the registry as `Failed`.
    /// Prefer `RunOnce` if the entry should disappear on completion.
    ///
    /// The wait before attempt `n` grows **exponentially** as
    /// `base_delay * 2^(n-1)` and never exceeds [`MAX_RESTART_DELAY`].
    ///
    /// # Examples
    ///
    /// ```
    /// use std::time::Duration;
    /// use zeph_common::task_supervisor::RestartPolicy;
    ///
    /// // Up to three restarts; the first waits 1 s, later ones double the wait.
    /// let policy = RestartPolicy::Restart { max: 3, base_delay: Duration::from_secs(1) };
    /// ```
    Restart {
        /// Maximum number of restart attempts.
        max: u32,
        /// Delay before the first restart; doubled for each further attempt.
        base_delay: Duration,
    },
}
91
/// Upper bound on the wait between restart attempts: exponential backoff is
/// clamped to this value (one minute).
pub const MAX_RESTART_DELAY: Duration = Duration::from_secs(60);
94
/// Upper bound on how long the reap driver keeps draining completion events
/// after cancellation has fired.
///
/// INVARIANT: this must stay below the runner's shutdown grace period (currently
/// 10 s, defined in `runner.rs`). Reducing that grace period requires reducing
/// this value proportionally.
const SHUTDOWN_DRAIN_TIMEOUT: Duration = Duration::from_millis(5_000);
100
101/// Configuration passed to [`TaskSupervisor::spawn`] to describe a supervised task.
102///
103/// `F` must be `Fn` (not `FnOnce`) to support restarts: the factory is called once on
104/// initial spawn and once per restart attempt.
105pub struct TaskDescriptor<F> {
106    /// Unique name for this task (e.g., `"config-watcher"`, `"scheduler-loop"`).
107    ///
108    /// Names must be `'static` — they are typically compile-time string literals.
109    /// Spawning a task with a name that already exists aborts the prior instance.
110    pub name: &'static str,
111    /// Restart policy applied when the task exits unexpectedly.
112    pub restart: RestartPolicy,
113    /// Factory called to produce a new future. Must be `Fn` for restart support.
114    pub factory: F,
115}
116
117/// Opaque handle to a single supervised task.
118///
119/// Can be used to abort the task by name independently of the supervisor.
120#[derive(Debug, Clone)]
121pub struct TaskHandle {
122    name: &'static str,
123    abort: AbortHandle,
124}
125
126impl TaskHandle {
127    /// Abort the task immediately.
128    pub fn abort(&self) {
129        tracing::debug!(task.name = self.name, "task aborted via handle");
130        self.abort.abort();
131    }
132
133    /// Return the task's name.
134    #[must_use]
135    pub fn name(&self) -> &'static str {
136        self.name
137    }
138}
139
/// Failure reported by [`BlockingHandle::join`].
#[derive(Debug, PartialEq, Eq)]
pub enum BlockingError {
    /// The task panicked before it could produce a value.
    Panicked,
    /// The supervisor (or the task's abort handle) went away before a value was produced.
    SupervisorDropped,
}

impl std::fmt::Display for BlockingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Panicked => "supervised blocking task panicked",
            Self::SupervisorDropped => "supervisor dropped before task completed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for BlockingError {}
159
160/// Handle returned by [`TaskSupervisor::spawn_blocking`].
161///
162/// Awaiting [`BlockingHandle::join`] blocks until the OS-thread task produces a
163/// value. Dropping the handle without joining does **not** cancel the task — it
164/// continues to run on the blocking thread pool but the result is discarded.
165///
166/// A panic inside the closure is captured and returned as
167/// [`BlockingError::Panicked`] rather than propagating to the caller.
168pub struct BlockingHandle<R> {
169    rx: oneshot::Receiver<Result<R, BlockingError>>,
170    abort: AbortHandle,
171}
172
173impl<R> BlockingHandle<R> {
174    /// Await the task result.
175    ///
176    /// # Errors
177    ///
178    /// - [`BlockingError::Panicked`] — the task closure panicked.
179    /// - [`BlockingError::SupervisorDropped`] — the task was aborted or the
180    ///   supervisor was dropped before a value was produced.
181    pub async fn join(self) -> Result<R, BlockingError> {
182        self.rx
183            .await
184            .unwrap_or(Err(BlockingError::SupervisorDropped))
185    }
186
187    /// Non-blocking poll: return the result if the task has already finished, or `None`
188    /// if it is still running.
189    ///
190    /// This is the `BlockingHandle` equivalent of `FutureExt::now_or_never` on a
191    /// [`tokio::task::JoinHandle`]. Call this inside a synchronous context (e.g., between
192    /// agent turns) to apply a completed background result without blocking.
193    ///
194    /// The handle is consumed on success. If the task is not yet done, the handle
195    /// is returned as `Err(self)` so the caller can re-store it.
196    ///
197    /// # Examples
198    ///
199    /// ```rust,no_run
200    /// # use zeph_common::task_supervisor::{BlockingHandle, BlockingError};
201    /// async fn example(mut handle: BlockingHandle<u32>) {
202    ///     // Try to get the result without blocking.
203    ///     match handle.try_join() {
204    ///         Ok(result) => println!("done: {result:?}"),
205    ///         Err(handle) => {
206    ///             // Task still running — `handle` is returned for re-storage.
207    ///             drop(handle);
208    ///         }
209    ///     }
210    /// }
211    /// ```
212    ///
213    /// # Errors
214    ///
215    /// Returns `Err(self)` when the task has not yet produced a result (still running).
216    /// The inner `Ok(Err(BlockingError::...))` variants are returned when the task
217    /// panicked or the supervisor was dropped before the task completed.
218    pub fn try_join(mut self) -> Result<Result<R, BlockingError>, Self> {
219        match self.rx.try_recv() {
220            Ok(result) => Ok(result),
221            Err(tokio::sync::oneshot::error::TryRecvError::Empty) => Err(self),
222            Err(tokio::sync::oneshot::error::TryRecvError::Closed) => {
223                Ok(Err(BlockingError::SupervisorDropped))
224            }
225        }
226    }
227
228    /// Abort the underlying task immediately.
229    pub fn abort(&self) {
230        self.abort.abort();
231    }
232}
233
/// State of a supervised task at the moment it was observed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskStatus {
    /// The task is currently running.
    Running,
    /// The task is waiting out its restart delay before the next attempt.
    Restarting {
        /// Current restart attempt number.
        attempt: u32,
        /// Configured maximum number of restart attempts.
        max: u32,
    },
    /// The task finished normally.
    Completed,
    /// The task was forcibly aborted during shutdown.
    Aborted,
    /// The task used up every restart attempt and is permanently failed.
    Failed {
        /// Description of why the task failed.
        reason: String,
    },
}
248
249/// Point-in-time snapshot of a supervised task, returned by [`TaskSupervisor::snapshot`].
250#[derive(Debug, Clone)]
251/// Observability surface per field:
252///
253/// | Field | tokio-console | Jaeger / OTLP | TUI | `metrics` histogram |
254/// |-------|--------------|--------------|-----|---------------------|
255/// | `name` | span name | span name | task list | label `"task"` |
256/// | `task.wall_time_ms` | — | span field (`task-metrics`) | — | `zeph.task.wall_time_ms` |
257/// | `task.cpu_time_ms` | — | span field (`task-metrics`) | — | `zeph.task.cpu_time_ms` |
258/// | `status` | — | — | task list | — |
259/// | `restart_count` | — | — | task list | — |
260///
261/// The `task.wall_time_ms` and `task.cpu_time_ms` fields are only populated when
262/// the crate is compiled with the `task-metrics` feature.
263pub struct TaskSnapshot {
264    /// Task name.
265    pub name: Arc<str>,
266    /// Current status.
267    pub status: TaskStatus,
268    /// Instant the task was first spawned.
269    pub started_at: Instant,
270    /// Number of times the task has been restarted.
271    pub restart_count: u32,
272}
273
274// ── Internal types ───────────────────────────────────────────────────────────
275
/// Type-erased, heap-pinned unit future produced by a task factory.
type BoxFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
/// Type-erased task factory; invoked once per (re)start to build a fresh [`BoxFuture`].
type BoxFactory = Box<dyn Fn() -> BoxFuture + Send + Sync + 'static>;
278
279struct TaskEntry {
280    name: Arc<str>,
281    status: TaskStatus,
282    started_at: Instant,
283    restart_count: u32,
284    restart_policy: RestartPolicy,
285    abort_handle: AbortHandle,
286    /// `Some` only for `Restart` policy tasks.
287    factory: Option<BoxFactory>,
288}
289
/// Outcome of a supervised task, as reported to the reap driver.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CompletionKind {
    /// The task's wrapper returned normally.
    ///
    /// Tasks started with `TaskSupervisor::spawn` also report this when they are
    /// stopped by the cancellation token, because their wrapper treats both
    /// `select!` arms as a normal return.
    Normal,
    /// The future panicked.
    Panicked,
    /// The task was aborted, or otherwise stopped before producing its result
    /// (e.g. by the cancellation token in `spawn_oneshot`).
    Cancelled,
}

/// Completion event sent to the reap driver when a supervised task ends.
struct Completion {
    /// Registry key of the task that ended.
    name: Arc<str>,
    /// How the task ended.
    kind: CompletionKind,
}
305
306struct SupervisorState {
307    tasks: HashMap<Arc<str>, TaskEntry>,
308}
309
310struct Inner {
311    state: parking_lot::Mutex<SupervisorState>,
312    /// Completion events from spawned tasks → reap driver.
313    /// Lives in `Inner` (not `SupervisorState`) to avoid double mutex acquisition
314    /// — callers clone it once during spawn without re-locking state.
315    completion_tx: mpsc::UnboundedSender<Completion>,
316    cancel: CancellationToken,
317    /// Limits the number of concurrently running `spawn_blocking` tasks to prevent
318    /// runaway thread-pool growth under burst load.
319    blocking_semaphore: Arc<tokio::sync::Semaphore>,
320}
321
322// ── Main type ────────────────────────────────────────────────────────────────
323
324/// Shared, cloneable handle to the supervised lifecycle task registry.
325///
326/// `TaskSupervisor` manages named, long-lived background tasks with restart
327/// policies, health snapshots, and graceful shutdown. It is `Clone + Send + Sync`
328/// so it can be distributed to bootstrap code, TUI, and shutdown orchestration
329/// without any additional synchronisation.
330///
331/// # Thread safety
332///
333/// Interior state is guarded by a `parking_lot::Mutex`. The lock is **never**
334/// held across `.await` points.
335///
336/// # Examples
337///
338/// ```rust,no_run
339/// use std::time::Duration;
340/// use tokio_util::sync::CancellationToken;
341/// use zeph_common::task_supervisor::{RestartPolicy, TaskDescriptor, TaskSupervisor};
342///
343/// # #[tokio::main]
344/// # async fn main() {
345/// let cancel = CancellationToken::new();
346/// let sup = TaskSupervisor::new(cancel.clone());
347///
348/// let _handle = sup.spawn(TaskDescriptor {
349///     name: "watcher",
350///     restart: RestartPolicy::RunOnce,
351///     factory: || async { tokio::time::sleep(std::time::Duration::from_secs(1)).await },
352/// });
353///
354/// sup.shutdown_all(Duration::from_secs(5)).await;
355/// # }
356/// ```
357#[derive(Clone)]
358pub struct TaskSupervisor {
359    inner: Arc<Inner>,
360}
361
362impl TaskSupervisor {
363    /// Create a new supervisor and start its reap driver.
364    ///
365    /// The `cancel` token is propagated into every spawned task via `tokio::select!`.
366    /// When the token is cancelled, all tasks exit cooperatively on their next
367    /// cancellation check. Call [`shutdown_all`][Self::shutdown_all] to wait for
368    /// them to finish.
369    ///
370    /// When called outside a Tokio runtime context (e.g. in synchronous unit tests),
371    /// the reap driver is skipped. The supervisor still accepts task registrations but
372    /// completion callbacks are not processed — safe because no tasks can actually be
373    /// spawned without a runtime.
374    #[must_use]
375    pub fn new(cancel: CancellationToken) -> Self {
376        // NOTE: unbounded channel is acceptable here because supervised tasks are
377        // O(10–20) lifecycle services, not high-throughput work. Backpressure would
378        // complicate the spawn path without practical benefit.
379        let (completion_tx, completion_rx) = mpsc::unbounded_channel();
380        let inner = Arc::new(Inner {
381            state: parking_lot::Mutex::new(SupervisorState {
382                tasks: HashMap::new(),
383            }),
384            completion_tx,
385            cancel: cancel.clone(),
386            blocking_semaphore: Arc::new(tokio::sync::Semaphore::new(8)),
387        });
388
389        // Only start the reap driver when a Tokio runtime is available. In synchronous
390        // unit tests that construct Agent/LifecycleState directly there is no reactor,
391        // so we skip the spawn. Without a runtime no tasks can be spawned either, so
392        // the driver is not needed.
393        if tokio::runtime::Handle::try_current().is_ok() {
394            Self::start_reap_driver(Arc::clone(&inner), completion_rx, cancel);
395        }
396
397        Self { inner }
398    }
399
400    /// Spawn a named, supervised async task.
401    ///
402    /// If a task with the same `name` already exists, it is aborted before the
403    /// new one is started.
404    ///
405    /// # Examples
406    ///
407    /// ```rust,no_run
408    /// use std::time::Duration;
409    /// use tokio_util::sync::CancellationToken;
410    /// use zeph_common::task_supervisor::{RestartPolicy, TaskDescriptor, TaskHandle, TaskSupervisor};
411    ///
412    /// # #[tokio::main]
413    /// # async fn main() {
414    /// let cancel = CancellationToken::new();
415    /// let sup = TaskSupervisor::new(cancel.clone());
416    ///
417    /// let handle: TaskHandle = sup.spawn(TaskDescriptor {
418    ///     name: "config-watcher",
419    ///     restart: RestartPolicy::Restart { max: 3, base_delay: Duration::from_secs(1) },
420    ///     factory: || async { /* watch loop */ },
421    /// });
422    /// # }
423    /// ```
424    pub fn spawn<F, Fut>(&self, desc: TaskDescriptor<F>) -> TaskHandle
425    where
426        F: Fn() -> Fut + Send + Sync + 'static,
427        Fut: Future<Output = ()> + Send + 'static,
428    {
429        let factory: BoxFactory = Box::new(move || Box::pin((desc.factory)()));
430        let cancel = self.inner.cancel.clone();
431        let completion_tx = self.inner.completion_tx.clone();
432        let name: Arc<str> = Arc::from(desc.name);
433
434        let (abort_handle, jh) = Self::do_spawn(desc.name, &factory, cancel);
435        Self::wire_completion_reporter(Arc::clone(&name), jh, completion_tx);
436
437        let entry = TaskEntry {
438            name: Arc::clone(&name),
439            status: TaskStatus::Running,
440            started_at: Instant::now(),
441            restart_count: 0,
442            restart_policy: desc.restart,
443            abort_handle: abort_handle.clone(),
444            factory: match desc.restart {
445                RestartPolicy::RunOnce => None,
446                RestartPolicy::Restart { .. } => Some(factory),
447            },
448        };
449
450        {
451            let mut state = self.inner.state.lock();
452            if let Some(old) = state.tasks.remove(&name) {
453                old.abort_handle.abort();
454            }
455            state.tasks.insert(Arc::clone(&name), entry);
456        }
457
458        TaskHandle {
459            name: desc.name,
460            abort: abort_handle,
461        }
462    }
463
464    /// Spawn a CPU-bound closure on the OS blocking thread pool.
465    ///
466    /// The closure runs via [`tokio::task::spawn_blocking`] — it is never polled
467    /// on tokio worker threads and cannot block async I/O. The task is registered
468    /// in the supervisor registry and is visible to [`snapshot`][Self::snapshot]
469    /// and [`shutdown_all`][Self::shutdown_all].
470    ///
471    /// Dropping the returned [`BlockingHandle`] without calling `.join()` does
472    /// **not** cancel the task; it runs to completion but the result is discarded.
473    ///
474    /// A panic inside `f` is captured and returned as [`BlockingError::Panicked`]
475    /// rather than propagating to the caller.
476    ///
477    /// # Examples
478    ///
479    /// ```rust,no_run
480    /// use std::sync::Arc;
481    /// use tokio_util::sync::CancellationToken;
482    /// use zeph_common::task_supervisor::{BlockingHandle, TaskSupervisor};
483    ///
484    /// # #[tokio::main]
485    /// # async fn main() {
486    /// let cancel = CancellationToken::new();
487    /// let sup = TaskSupervisor::new(cancel);
488    ///
489    /// let handle: BlockingHandle<u32> = sup.spawn_blocking(Arc::from("compute"), || {
490    ///     // CPU-bound work — safe to block here
491    ///     42_u32
492    /// });
493    /// let result = handle.join().await.unwrap();
494    /// assert_eq!(result, 42);
495    /// # }
496    /// ```
497    ///
498    /// # Capacity limit
499    ///
500    /// At most 8 `spawn_blocking` tasks run concurrently. Additional tasks wait for a
501    /// semaphore permit, bounding thread-pool growth under burst load.
502    ///
503    /// # Panics
504    ///
505    /// Panics inside `f` are captured and returned as [`BlockingError::Panicked`] — they
506    /// do not propagate to the caller.
507    #[allow(clippy::needless_pass_by_value)] // `name` is cloned into async task and registry
508    pub fn spawn_blocking<F, R>(&self, name: Arc<str>, f: F) -> BlockingHandle<R>
509    where
510        F: FnOnce() -> R + Send + 'static,
511        R: Send + 'static,
512    {
513        let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
514        #[cfg(feature = "task-metrics")]
515        let span = tracing::info_span!(
516            "supervised_blocking_task",
517            task.name = %name,
518            task.wall_time_ms = tracing::field::Empty,
519            task.cpu_time_ms = tracing::field::Empty,
520        );
521        #[cfg(not(feature = "task-metrics"))]
522        let span = tracing::info_span!("supervised_blocking_task", task.name = %name);
523
524        let semaphore = Arc::clone(&self.inner.blocking_semaphore);
525        let inner = Arc::clone(&self.inner);
526        let name_clone = Arc::clone(&name);
527        let completion_tx = self.inner.completion_tx.clone();
528
529        // Wrap the blocking spawn in an async task that first acquires a semaphore
530        // permit, bounding the number of concurrently running blocking tasks to 8.
531        let outer = tokio::spawn(async move {
532            let _permit = semaphore
533                .acquire_owned()
534                .await
535                .expect("blocking semaphore closed");
536
537            let name_for_measure = Arc::clone(&name_clone);
538            let join_handle = tokio::task::spawn_blocking(move || {
539                let _enter = span.enter();
540                measure_blocking(&name_for_measure, f)
541            });
542            let abort = join_handle.abort_handle();
543
544            // Update registry with the real abort handle now that spawn_blocking is live.
545            {
546                let mut state = inner.state.lock();
547                if let Some(entry) = state.tasks.get_mut(&name_clone) {
548                    entry.abort_handle = abort;
549                }
550            }
551
552            let kind = match join_handle.await {
553                Ok(val) => {
554                    let _ = tx.send(Ok(val));
555                    CompletionKind::Normal
556                }
557                Err(e) if e.is_panic() => {
558                    let _ = tx.send(Err(BlockingError::Panicked));
559                    CompletionKind::Panicked
560                }
561                Err(_) => {
562                    // Aborted — drop tx so rx returns SupervisorDropped.
563                    CompletionKind::Cancelled
564                }
565            };
566            // _permit released here, freeing the semaphore slot.
567            let _ = completion_tx.send(Completion {
568                name: name_clone,
569                kind,
570            });
571        });
572        let abort = outer.abort_handle();
573
574        // Register in registry so snapshot/shutdown sees the task.
575        {
576            let mut state = self.inner.state.lock();
577            if let Some(old) = state.tasks.remove(&name) {
578                old.abort_handle.abort();
579            }
580            state.tasks.insert(
581                Arc::clone(&name),
582                TaskEntry {
583                    name: Arc::clone(&name),
584                    status: TaskStatus::Running,
585                    started_at: Instant::now(),
586                    restart_count: 0,
587                    restart_policy: RestartPolicy::RunOnce,
588                    abort_handle: abort.clone(),
589                    factory: None,
590                },
591            );
592        }
593
594        BlockingHandle { rx, abort }
595    }
596
597    /// Spawn an async task that produces a typed result value (runs on tokio worker thread).
598    ///
599    /// Unlike [`spawn`][Self::spawn], no restart policy is supported — the task
600    /// runs once. The task is registered in the supervisor registry under the
601    /// provided `name` and is visible to [`snapshot`][Self::snapshot] and
602    /// [`shutdown_all`][Self::shutdown_all].
603    ///
604    /// For CPU-bound work that must not block tokio workers, use
605    /// [`spawn_blocking`][Self::spawn_blocking] instead.
606    ///
607    /// # Examples
608    ///
609    /// ```rust,no_run
610    /// use std::sync::Arc;
611    /// use tokio_util::sync::CancellationToken;
612    /// use zeph_common::task_supervisor::{BlockingHandle, TaskSupervisor};
613    ///
614    /// # #[tokio::main]
615    /// # async fn main() {
616    /// let cancel = CancellationToken::new();
617    /// let sup = TaskSupervisor::new(cancel.clone());
618    ///
619    /// let handle: BlockingHandle<u32> = sup.spawn_oneshot(Arc::from("compute"), || async { 42_u32 });
620    /// let result = handle.join().await.unwrap();
621    /// assert_eq!(result, 42);
622    /// # }
623    /// ```
624    pub fn spawn_oneshot<F, Fut, R>(&self, name: Arc<str>, factory: F) -> BlockingHandle<R>
625    where
626        F: FnOnce() -> Fut + Send + 'static,
627        Fut: Future<Output = R> + Send + 'static,
628        R: Send + 'static,
629    {
630        let (tx, rx) = oneshot::channel::<Result<R, BlockingError>>();
631        let cancel = self.inner.cancel.clone();
632        let span = tracing::info_span!("supervised_task", task.name = %name);
633        let join_handle: tokio::task::JoinHandle<Option<R>> = tokio::spawn(
634            async move {
635                let fut = factory();
636                tokio::select! {
637                    result = fut => Some(result),
638                    () = cancel.cancelled() => None,
639                }
640            }
641            .instrument(span),
642        );
643        let abort = join_handle.abort_handle();
644
645        {
646            let mut state = self.inner.state.lock();
647            if let Some(old) = state.tasks.remove(&name) {
648                old.abort_handle.abort();
649            }
650            state.tasks.insert(
651                Arc::clone(&name),
652                TaskEntry {
653                    name: Arc::clone(&name),
654                    status: TaskStatus::Running,
655                    started_at: Instant::now(),
656                    restart_count: 0,
657                    restart_policy: RestartPolicy::RunOnce,
658                    abort_handle: abort.clone(),
659                    factory: None,
660                },
661            );
662        }
663
664        let completion_tx = self.inner.completion_tx.clone();
665        tokio::spawn(async move {
666            let kind = match join_handle.await {
667                Ok(Some(val)) => {
668                    let _ = tx.send(Ok(val));
669                    CompletionKind::Normal
670                }
671                Err(e) if e.is_panic() => {
672                    let _ = tx.send(Err(BlockingError::Panicked));
673                    CompletionKind::Panicked
674                }
675                _ => CompletionKind::Cancelled,
676            };
677            let _ = completion_tx.send(Completion { name, kind });
678        });
679        BlockingHandle { rx, abort }
680    }
681
682    /// Abort a task by name. No-op if no task with that name is registered.
683    pub fn abort(&self, name: &'static str) {
684        let state = self.inner.state.lock();
685        let key: Arc<str> = Arc::from(name);
686        if let Some(entry) = state.tasks.get(&key) {
687            entry.abort_handle.abort();
688            tracing::debug!(task.name = name, "task aborted via supervisor");
689        }
690    }
691
692    /// Gracefully shut down all supervised tasks.
693    ///
694    /// Cancels the supervisor's [`CancellationToken`] and waits up to `timeout`
695    /// for all tasks to exit. Tasks that do not exit within the timeout are
696    /// aborted forcefully and their registry entries updated to [`TaskStatus::Aborted`].
697    ///
698    /// # Note
699    ///
700    /// This cancels the token passed to [`TaskSupervisor::new`]. If you share
701    /// that token with other subsystems, they will be cancelled too. Use a child
702    /// token (`cancel.child_token()`) when the supervisor should not affect
703    /// unrelated components.
704    pub async fn shutdown_all(&self, timeout: Duration) {
705        self.inner.cancel.cancel();
706        let deadline = tokio::time::Instant::now() + timeout;
707        loop {
708            let active = self.active_count();
709            if active == 0 {
710                break;
711            }
712            if tokio::time::Instant::now() >= deadline {
713                let mut remaining_names: Vec<Arc<str>> = Vec::new();
714                {
715                    let mut state = self.inner.state.lock();
716                    for entry in state.tasks.values_mut() {
717                        if matches!(
718                            entry.status,
719                            TaskStatus::Running | TaskStatus::Restarting { .. }
720                        ) {
721                            remaining_names.push(Arc::clone(&entry.name));
722                            entry.abort_handle.abort();
723                            entry.status = TaskStatus::Aborted;
724                        }
725                    }
726                }
727                tracing::warn!(
728                    remaining = active,
729                    tasks = ?remaining_names,
730                    "shutdown timeout — aborting remaining tasks"
731                );
732                break;
733            }
734            tokio::time::sleep(Duration::from_millis(50)).await;
735        }
736    }
737
738    /// Return a point-in-time snapshot of all registered tasks.
739    ///
740    /// Suitable for TUI status panels and structured logging. The returned
741    /// list is sorted by `started_at` ascending.
742    #[must_use]
743    pub fn snapshot(&self) -> Vec<TaskSnapshot> {
744        let state = self.inner.state.lock();
745        let mut snaps: Vec<TaskSnapshot> = state
746            .tasks
747            .values()
748            .map(|e| TaskSnapshot {
749                name: Arc::clone(&e.name),
750                status: e.status.clone(),
751                started_at: e.started_at,
752                restart_count: e.restart_count,
753            })
754            .collect();
755        snaps.sort_by_key(|s| s.started_at);
756        snaps
757    }
758
759    /// Return the number of tasks currently in `Running` or `Restarting` state.
760    #[must_use]
761    pub fn active_count(&self) -> usize {
762        let state = self.inner.state.lock();
763        state
764            .tasks
765            .values()
766            .filter(|e| {
767                matches!(
768                    e.status,
769                    TaskStatus::Running | TaskStatus::Restarting { .. }
770                )
771            })
772            .count()
773    }
774
775    /// Return a clone of the supervisor's [`CancellationToken`].
776    ///
777    /// Callers can use this to check whether shutdown has been initiated.
778    #[must_use]
779    pub fn cancellation_token(&self) -> CancellationToken {
780        self.inner.cancel.clone()
781    }
782
783    // ── Internal helpers ──────────────────────────────────────────────────────
784
785    /// Spawn the actual tokio task. Returns `(AbortHandle, JoinHandle)`.
786    fn do_spawn(
787        name: &'static str,
788        factory: &BoxFactory,
789        cancel: CancellationToken,
790    ) -> (AbortHandle, tokio::task::JoinHandle<()>) {
791        let fut = factory();
792        let span = tracing::info_span!("supervised_task", task.name = name);
793        let jh = tokio::spawn(
794            async move {
795                tokio::select! {
796                    () = fut => {},
797                    () = cancel.cancelled() => {},
798                }
799            }
800            .instrument(span),
801        );
802        let abort = jh.abort_handle();
803        (abort, jh)
804    }
805
806    /// Wire a completion reporter: drives `jh` and sends the result to `completion_tx`.
807    fn wire_completion_reporter(
808        name: Arc<str>,
809        jh: tokio::task::JoinHandle<()>,
810        completion_tx: mpsc::UnboundedSender<Completion>,
811    ) {
812        tokio::spawn(async move {
813            let kind = match jh.await {
814                Ok(()) => CompletionKind::Normal,
815                Err(e) if e.is_panic() => CompletionKind::Panicked,
816                Err(_) => CompletionKind::Cancelled,
817            };
818            let _ = completion_tx.send(Completion { name, kind });
819        });
820    }
821
822    /// Spawn the reap driver. The driver processes completion events from the mpsc channel.
823    ///
824    /// After the cancellation token fires, the driver continues draining the channel
825    /// until it is empty — this ensures that tasks which completed just before cancel
826    /// have their registry entries updated, allowing `shutdown_all` to observe
827    /// `active_count() == 0` correctly.
828    fn start_reap_driver(
829        inner: Arc<Inner>,
830        mut completion_rx: mpsc::UnboundedReceiver<Completion>,
831        cancel: CancellationToken,
832    ) {
833        tokio::spawn(async move {
834            // Phase 1: normal operation — process completions until cancel fires.
835            loop {
836                tokio::select! {
837                    biased;
838                    Some(completion) = completion_rx.recv() => {
839                        Self::handle_completion(&inner, completion).await;
840                    }
841                    () = cancel.cancelled() => break,
842                }
843            }
844
845            // Phase 2: post-cancel drain — keep receiving completions until the
846            // registry reports no active tasks, or the channel closes, or the safety
847            // deadline expires. This prevents losing completions that arrive after
848            // tasks observe cancellation (#3161).
849            let drain_deadline = tokio::time::Instant::now() + SHUTDOWN_DRAIN_TIMEOUT;
850            let active = Self::has_active_tasks(&inner);
851            tracing::debug!(active, "reap driver entered post-cancel drain phase");
852            loop {
853                if !Self::has_active_tasks(&inner) {
854                    break;
855                }
856                let remaining =
857                    drain_deadline.saturating_duration_since(tokio::time::Instant::now());
858                if remaining.is_zero() {
859                    break;
860                }
861                match tokio::time::timeout(remaining, completion_rx.recv()).await {
862                    Ok(Some(completion)) => Self::handle_completion(&inner, completion).await,
863                    // channel closed (unreachable in practice — senders live in Inner), or deadline elapsed
864                    Ok(None) | Err(_) => break,
865                }
866            }
867            tracing::debug!(
868                active = Self::has_active_tasks(&inner),
869                "reap driver drain phase complete"
870            );
871        });
872    }
873
874    /// Returns `true` if any task is in `Running` or `Restarting` state.
875    fn has_active_tasks(inner: &Arc<Inner>) -> bool {
876        let state = inner.state.lock();
877        state.tasks.values().any(|e| {
878            matches!(
879                e.status,
880                TaskStatus::Running | TaskStatus::Restarting { .. }
881            )
882        })
883    }
884
885    /// Process a single task completion event.
886    ///
887    /// Lock is never held across `.await`. Phase 1 classifies the completion
888    /// under lock; Phase 2 sleeps with exponential backoff without a lock;
889    /// Phase 3 spawns the next instance and updates the registry.
890    async fn handle_completion(inner: &Arc<Inner>, completion: Completion) {
891        // Short-circuit: once cancellation has fired, never schedule restarts.
892        // Without this, Restart-policy tasks re-register as Running, causing
893        // has_active_tasks() to stay true and the drain loop to spin until timeout.
894        if inner.cancel.is_cancelled() {
895            let mut state = inner.state.lock();
896            state.tasks.remove(&completion.name);
897            return;
898        }
899
900        let Some((attempt, max, delay)) = Self::classify_completion(inner, &completion) else {
901            return;
902        };
903
904        tracing::warn!(
905            task.name = %completion.name,
906            attempt,
907            max,
908            delay_ms = delay.as_millis(),
909            "restarting supervised task"
910        );
911
912        if !delay.is_zero() {
913            tokio::time::sleep(delay).await;
914        }
915
916        Self::do_restart(inner, &completion.name, attempt);
917    }
918
919    /// Phase 1: classify the completion under lock and return restart parameters if needed.
920    ///
921    /// Returns `Some((attempt, max, backoff_delay))` when a restart should be scheduled.
922    fn classify_completion(
923        inner: &Arc<Inner>,
924        completion: &Completion,
925    ) -> Option<(u32, u32, Duration)> {
926        let mut state = inner.state.lock();
927        let entry = state.tasks.get_mut(&completion.name)?;
928
929        match completion.kind {
930            CompletionKind::Panicked => {
931                tracing::warn!(task.name = %completion.name, "supervised task panicked");
932            }
933            CompletionKind::Normal => {
934                tracing::info!(task.name = %completion.name, "supervised task completed");
935            }
936            CompletionKind::Cancelled => {
937                tracing::debug!(task.name = %completion.name, "supervised task cancelled");
938            }
939        }
940
941        match entry.restart_policy {
942            RestartPolicy::RunOnce => {
943                entry.status = TaskStatus::Completed;
944                state.tasks.remove(&completion.name);
945                None
946            }
947            RestartPolicy::Restart { max, base_delay } => {
948                // Only restart on panic — normal exit and cancellation are not errors.
949                if completion.kind != CompletionKind::Panicked {
950                    entry.status = TaskStatus::Completed;
951                    state.tasks.remove(&completion.name);
952                    return None;
953                }
954                if entry.restart_count >= max {
955                    let reason = format!("panicked after {max} restart(s)");
956                    tracing::error!(
957                        task.name = %completion.name,
958                        attempts = max,
959                        "task failed permanently"
960                    );
961                    entry.status = TaskStatus::Failed { reason };
962                    None
963                } else {
964                    let attempt = entry.restart_count + 1;
965                    entry.status = TaskStatus::Restarting { attempt, max };
966                    // Exponential backoff: base_delay * 2^(attempt-1), capped at MAX_RESTART_DELAY.
967                    let multiplier = 1_u32
968                        .checked_shl(attempt.saturating_sub(1))
969                        .unwrap_or(u32::MAX);
970                    let delay = base_delay.saturating_mul(multiplier).min(MAX_RESTART_DELAY);
971                    Some((attempt, max, delay))
972                }
973            }
974        }
975        // lock released here
976    }
977
978    /// Phase 3: TOCTOU check, collect spawn params under lock, then spawn outside.
979    fn do_restart(inner: &Arc<Inner>, name: &Arc<str>, attempt: u32) {
980        let spawn_params = {
981            let mut state = inner.state.lock();
982            let Some(entry) = state.tasks.get_mut(name.as_ref()) else {
983                tracing::debug!(
984                    task.name = %name,
985                    "task removed during restart delay — skipping"
986                );
987                return;
988            };
989            if !matches!(entry.status, TaskStatus::Restarting { .. }) {
990                return;
991            }
992            let Some(factory) = &entry.factory else {
993                return;
994            };
995            // Wrap factory() in catch_unwind to prevent a factory panic from crashing
996            // the reap driver and orphaning the registry.
997            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(factory)) {
998                Err(_) => {
999                    let reason = format!("factory panicked on restart attempt {attempt}");
1000                    tracing::error!(task.name = %name, attempt, "factory panicked during restart");
1001                    entry.status = TaskStatus::Failed { reason };
1002                    None
1003                }
1004                Ok(fut) => Some((
1005                    fut,
1006                    inner.cancel.clone(),
1007                    inner.completion_tx.clone(),
1008                    name.clone(),
1009                )),
1010            }
1011            // lock released here
1012        };
1013
1014        let Some((fut, cancel, completion_tx, name)) = spawn_params else {
1015            return;
1016        };
1017
1018        let span = tracing::info_span!("supervised_task", task.name = %name);
1019        let jh = tokio::spawn(
1020            async move {
1021                tokio::select! {
1022                    () = fut => {},
1023                    () = cancel.cancelled() => {},
1024                }
1025            }
1026            .instrument(span),
1027        );
1028        let new_abort = jh.abort_handle();
1029
1030        {
1031            let mut state = inner.state.lock();
1032            if let Some(entry) = state.tasks.get_mut(name.as_ref()) {
1033                entry.restart_count = attempt;
1034                entry.status = TaskStatus::Running;
1035                entry.abort_handle = new_abort;
1036            }
1037        }
1038
1039        Self::wire_completion_reporter(name.clone(), jh, completion_tx);
1040    }
1041}
1042
1043// ── Task metrics helpers ──────────────────────────────────────────────────────
1044
/// Execute `f` and record how long it took (`task-metrics` build).
///
/// Two durations are measured around the call:
/// - wall-clock time (`std::time::Instant`);
/// - the calling thread's CPU time (`cpu_time::ThreadTime`).
///
/// Both are emitted in milliseconds:
/// - as `zeph.task.wall_time_ms` / `zeph.task.cpu_time_ms` histograms labelled
///   with `task = name`;
/// - as `task.wall_time_ms` / `task.cpu_time_ms` fields on the current tracing span.
///
/// If `f` panics, nothing is recorded. The non-feature variant is a plain
/// pass-through that links neither `cpu-time` nor `metrics`.
#[cfg(feature = "task-metrics")]
#[inline]
fn measure_blocking<F, R>(name: &str, f: F) -> R
where
    F: FnOnce() -> R,
{
    use cpu_time::ThreadTime;

    let (wall_clock, cpu_clock) = (std::time::Instant::now(), ThreadTime::now());
    let output = f();

    let to_ms = |elapsed: std::time::Duration| elapsed.as_secs_f64() * 1000.0;
    let wall_ms = to_ms(wall_clock.elapsed());
    let cpu_ms = to_ms(cpu_clock.elapsed());

    metrics::histogram!("zeph.task.wall_time_ms", "task" => name.to_owned()).record(wall_ms);
    metrics::histogram!("zeph.task.cpu_time_ms", "task" => name.to_owned()).record(cpu_ms);

    let span = tracing::Span::current();
    span.record("task.wall_time_ms", wall_ms);
    span.record("task.cpu_time_ms", cpu_ms);

    output
}
1067
/// Pass-through used when the `task-metrics` feature is off.
///
/// Simply invokes `f` and returns its result; `_name` is ignored and no
/// measurement takes place.
#[cfg(not(feature = "task-metrics"))]
#[inline]
fn measure_blocking<F: FnOnce() -> R, R>(_name: &str, f: F) -> R {
    f()
}
1079
1080// ── BlockingSpawner impl ──────────────────────────────────────────────────────
1081
1082impl BlockingSpawner for TaskSupervisor {
1083    /// Spawn a named blocking closure through the supervisor.
1084    ///
1085    /// The task is registered in the supervisor registry (visible in
1086    /// [`snapshot`][Self::snapshot] and subject to graceful shutdown) before
1087    /// the closure begins executing.
1088    fn spawn_blocking_named(
1089        &self,
1090        name: Arc<str>,
1091        f: Box<dyn FnOnce() + Send + 'static>,
1092    ) -> tokio::task::JoinHandle<()> {
1093        let handle = self.spawn_blocking(Arc::clone(&name), f);
1094        tokio::spawn(async move {
1095            if let Err(e) = handle.join().await {
1096                tracing::error!(task.name = %name, error = %e, "supervised blocking task failed");
1097            }
1098        })
1099    }
1100}
1101
1102// ── Unit tests ────────────────────────────────────────────────────────────────
1103
1104#[cfg(test)]
1105mod tests {
1106    use std::sync::Arc;
1107    use std::sync::atomic::{AtomicU32, Ordering};
1108    use std::time::Duration;
1109
1110    use tokio_util::sync::CancellationToken;
1111
1112    use super::*;
1113
1114    fn make_supervisor() -> (TaskSupervisor, CancellationToken) {
1115        let cancel = CancellationToken::new();
1116        let sup = TaskSupervisor::new(cancel.clone());
1117        (sup, cancel)
1118    }
1119
1120    #[tokio::test]
1121    async fn test_spawn_and_complete() {
1122        let (sup, _cancel) = make_supervisor();
1123
1124        let done = Arc::new(tokio::sync::Notify::new());
1125        let done2 = Arc::clone(&done);
1126
1127        sup.spawn(TaskDescriptor {
1128            name: "simple",
1129            restart: RestartPolicy::RunOnce,
1130            factory: move || {
1131                let d = Arc::clone(&done2);
1132                async move {
1133                    d.notify_one();
1134                }
1135            },
1136        });
1137
1138        tokio::time::timeout(Duration::from_secs(2), done.notified())
1139            .await
1140            .expect("task should complete");
1141
1142        tokio::time::sleep(Duration::from_millis(50)).await;
1143        assert_eq!(
1144            sup.active_count(),
1145            0,
1146            "RunOnce task should be removed after completion"
1147        );
1148    }
1149
1150    #[tokio::test]
1151    async fn test_panic_capture() {
1152        let (sup, _cancel) = make_supervisor();
1153
1154        sup.spawn(TaskDescriptor {
1155            name: "panicking",
1156            restart: RestartPolicy::RunOnce,
1157            factory: || async { panic!("intentional test panic") },
1158        });
1159
1160        tokio::time::sleep(Duration::from_millis(200)).await;
1161
1162        let snaps = sup.snapshot();
1163        assert!(
1164            snaps.iter().all(|s| s.name.as_ref() != "panicking"),
1165            "entry should be reaped"
1166        );
1167        assert_eq!(
1168            sup.active_count(),
1169            0,
1170            "active count must be 0 after RunOnce panic"
1171        );
1172    }
1173
1174    /// Regression test for S2: Restart-policy tasks must only restart on panic,
1175    /// not on normal completion.
1176    #[tokio::test]
1177    async fn test_restart_only_on_panic() {
1178        let (sup, _cancel) = make_supervisor();
1179
1180        // Part 1: normal completion — must NOT restart.
1181        let normal_counter = Arc::new(AtomicU32::new(0));
1182        let nc = Arc::clone(&normal_counter);
1183        sup.spawn(TaskDescriptor {
1184            name: "normal-exit",
1185            restart: RestartPolicy::Restart {
1186                max: 3,
1187                base_delay: Duration::from_millis(10),
1188            },
1189            factory: move || {
1190                let c = Arc::clone(&nc);
1191                async move {
1192                    c.fetch_add(1, Ordering::SeqCst);
1193                    // Returns normally — no panic.
1194                }
1195            },
1196        });
1197
1198        tokio::time::sleep(Duration::from_millis(300)).await;
1199        assert_eq!(
1200            normal_counter.load(Ordering::SeqCst),
1201            1,
1202            "normal exit must not restart"
1203        );
1204        assert!(
1205            sup.snapshot()
1206                .iter()
1207                .all(|s| s.name.as_ref() != "normal-exit"),
1208            "entry removed after normal exit"
1209        );
1210
1211        // Part 2: panic — MUST restart up to max times.
1212        let panic_counter = Arc::new(AtomicU32::new(0));
1213        let pc = Arc::clone(&panic_counter);
1214        sup.spawn(TaskDescriptor {
1215            name: "panic-exit",
1216            restart: RestartPolicy::Restart {
1217                max: 2,
1218                base_delay: Duration::from_millis(10),
1219            },
1220            factory: move || {
1221                let c = Arc::clone(&pc);
1222                async move {
1223                    c.fetch_add(1, Ordering::SeqCst);
1224                    panic!("test panic");
1225                }
1226            },
1227        });
1228
1229        // initial + 2 restarts = 3 total
1230        tokio::time::sleep(Duration::from_millis(500)).await;
1231        assert!(
1232            panic_counter.load(Ordering::SeqCst) >= 3,
1233            "panicking task must restart max times"
1234        );
1235        let snap = sup
1236            .snapshot()
1237            .into_iter()
1238            .find(|s| s.name.as_ref() == "panic-exit");
1239        assert!(
1240            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
1241            "task must be Failed after exhausting restarts"
1242        );
1243    }
1244
1245    #[tokio::test]
1246    async fn test_restart_policy() {
1247        let (sup, _cancel) = make_supervisor();
1248
1249        let counter = Arc::new(AtomicU32::new(0));
1250        let counter2 = Arc::clone(&counter);
1251
1252        sup.spawn(TaskDescriptor {
1253            name: "restartable",
1254            restart: RestartPolicy::Restart {
1255                max: 2,
1256                base_delay: Duration::from_millis(10),
1257            },
1258            factory: move || {
1259                let c = Arc::clone(&counter2);
1260                async move {
1261                    c.fetch_add(1, Ordering::SeqCst);
1262                    panic!("always panic");
1263                }
1264            },
1265        });
1266
1267        tokio::time::sleep(Duration::from_millis(500)).await;
1268
1269        let runs = counter.load(Ordering::SeqCst);
1270        assert!(
1271            runs >= 3,
1272            "expected at least 3 invocations (initial + 2 restarts), got {runs}"
1273        );
1274
1275        let snaps = sup.snapshot();
1276        let snap = snaps.iter().find(|s| s.name.as_ref() == "restartable");
1277        assert!(snap.is_some(), "failed task should remain in registry");
1278        assert!(
1279            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
1280            "task should be Failed after exhausting retries"
1281        );
1282    }
1283
1284    /// Verify exponential backoff: delay doubles on each restart attempt.
1285    #[tokio::test]
1286    async fn test_exponential_backoff() {
1287        let (sup, _cancel) = make_supervisor();
1288
1289        let timestamps = Arc::new(parking_lot::Mutex::new(Vec::<std::time::Instant>::new()));
1290        let ts = Arc::clone(&timestamps);
1291
1292        sup.spawn(TaskDescriptor {
1293            name: "backoff-task",
1294            restart: RestartPolicy::Restart {
1295                max: 3,
1296                base_delay: Duration::from_millis(50),
1297            },
1298            factory: move || {
1299                let t = Arc::clone(&ts);
1300                async move {
1301                    t.lock().push(std::time::Instant::now());
1302                    panic!("always panic");
1303                }
1304            },
1305        });
1306
1307        // Wait long enough for all restarts: 50 + 100 + 200 ms = 350 ms + overhead
1308        tokio::time::sleep(Duration::from_millis(800)).await;
1309
1310        let ts = timestamps.lock();
1311        assert!(
1312            ts.len() >= 3,
1313            "expected at least 3 invocations, got {}",
1314            ts.len()
1315        );
1316
1317        // Verify delays are roughly doubling (within 2x tolerance for CI jitter).
1318        if ts.len() >= 3 {
1319            let d1 = ts[1].duration_since(ts[0]);
1320            let d2 = ts[2].duration_since(ts[1]);
1321            // d2 should be at least 1.5x d1 (allowing for jitter).
1322            assert!(
1323                d2 >= d1.mul_f64(1.5),
1324                "expected exponential backoff: d1={d1:?} d2={d2:?}"
1325            );
1326        }
1327    }
1328
1329    #[tokio::test]
1330    async fn test_graceful_shutdown() {
1331        let (sup, _cancel) = make_supervisor();
1332
1333        for name in ["svc-a", "svc-b", "svc-c"] {
1334            sup.spawn(TaskDescriptor {
1335                name,
1336                restart: RestartPolicy::RunOnce,
1337                factory: || async {
1338                    tokio::time::sleep(Duration::from_mins(1)).await;
1339                },
1340            });
1341        }
1342
1343        assert_eq!(sup.active_count(), 3);
1344
1345        tokio::time::timeout(
1346            Duration::from_secs(2),
1347            sup.shutdown_all(Duration::from_secs(1)),
1348        )
1349        .await
1350        .expect("shutdown should complete within timeout");
1351    }
1352
1353    /// Verify that force-aborted tasks get `TaskStatus::Aborted` in the registry (A2 fix).
1354    #[tokio::test]
1355    async fn test_force_abort_marks_aborted() {
1356        let cancel = CancellationToken::new();
1357        let sup = TaskSupervisor::new(cancel.clone());
1358
1359        sup.spawn(TaskDescriptor {
1360            name: "stubborn-for-abort",
1361            restart: RestartPolicy::RunOnce,
1362            factory: || async {
1363                // Does not cooperate with cancellation.
1364                std::future::pending::<()>().await;
1365            },
1366        });
1367
1368        // Use a very short timeout to trigger force-abort.
1369        sup.shutdown_all(Duration::from_millis(1)).await;
1370
1371        // Entry should be Aborted, not Running.
1372        let snaps = sup.snapshot();
1373        if let Some(snap) = snaps
1374            .iter()
1375            .find(|s| s.name.as_ref() == "stubborn-for-abort")
1376        {
1377            assert_eq!(
1378                snap.status,
1379                TaskStatus::Aborted,
1380                "force-aborted task must have Aborted status"
1381            );
1382        }
1383        // If entry was already reaped (cooperative cancel won), that's also acceptable.
1384    }
1385
1386    #[tokio::test]
1387    async fn test_registry_snapshot() {
1388        let (sup, _cancel) = make_supervisor();
1389
1390        for name in ["alpha", "beta"] {
1391            sup.spawn(TaskDescriptor {
1392                name,
1393                restart: RestartPolicy::RunOnce,
1394                factory: || async {
1395                    tokio::time::sleep(Duration::from_secs(10)).await;
1396                },
1397            });
1398        }
1399
1400        let snaps = sup.snapshot();
1401        assert_eq!(snaps.len(), 2);
1402        let names: Vec<&str> = snaps.iter().map(|s| s.name.as_ref()).collect();
1403        assert!(names.contains(&"alpha"));
1404        assert!(names.contains(&"beta"));
1405        assert!(snaps.iter().all(|s| s.status == TaskStatus::Running));
1406    }
1407
1408    #[tokio::test]
1409    async fn test_blocking_returns_value() {
1410        let (sup, cancel) = make_supervisor();
1411
1412        let handle: BlockingHandle<u32> = sup.spawn_blocking(Arc::from("compute"), || 42_u32);
1413        let result = handle.join().await.expect("should return value");
1414        assert_eq!(result, 42);
1415        cancel.cancel();
1416    }
1417
1418    #[tokio::test]
1419    async fn test_blocking_panic() {
1420        let (sup, _cancel) = make_supervisor();
1421
1422        let handle: BlockingHandle<u32> =
1423            sup.spawn_blocking(Arc::from("panicking-compute"), || panic!("intentional"));
1424        let err = handle
1425            .join()
1426            .await
1427            .expect_err("should return error on panic");
1428        assert_eq!(err, BlockingError::Panicked);
1429    }
1430
1431    /// Verify `spawn_blocking` tasks appear in registry (M3 fix).
1432    #[tokio::test]
1433    async fn test_blocking_registered_in_registry() {
1434        let (sup, cancel) = make_supervisor();
1435
1436        let (tx, rx) = std::sync::mpsc::channel::<()>();
1437        let _handle: BlockingHandle<()> =
1438            sup.spawn_blocking(Arc::from("blocking-task"), move || {
1439                // Block until signalled.
1440                let _ = rx.recv();
1441            });
1442
1443        tokio::time::sleep(Duration::from_millis(10)).await;
1444        assert_eq!(
1445            sup.active_count(),
1446            1,
1447            "blocking task must appear in active_count"
1448        );
1449
1450        let _ = tx.send(());
1451        tokio::time::sleep(Duration::from_millis(100)).await;
1452        assert_eq!(
1453            sup.active_count(),
1454            0,
1455            "blocking task must be removed after completion"
1456        );
1457
1458        cancel.cancel();
1459    }
1460
1461    /// Verify `spawn_oneshot` tasks appear in registry (M3 fix).
1462    #[tokio::test]
1463    async fn test_oneshot_registered_in_registry() {
1464        let (sup, cancel) = make_supervisor();
1465
1466        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
1467        let _handle: BlockingHandle<()> =
1468            sup.spawn_oneshot(Arc::from("oneshot-task"), move || async move {
1469                let _ = rx.await;
1470            });
1471
1472        tokio::time::sleep(Duration::from_millis(10)).await;
1473        assert_eq!(
1474            sup.active_count(),
1475            1,
1476            "oneshot task must appear in active_count"
1477        );
1478
1479        let _ = tx.send(());
1480        tokio::time::sleep(Duration::from_millis(50)).await;
1481        assert_eq!(
1482            sup.active_count(),
1483            0,
1484            "oneshot task must be removed after completion"
1485        );
1486
1487        cancel.cancel();
1488    }
1489
1490    #[tokio::test]
1491    async fn test_restart_max_zero() {
1492        let (sup, _cancel) = make_supervisor();
1493
1494        let counter = Arc::new(AtomicU32::new(0));
1495        let counter2 = Arc::clone(&counter);
1496
1497        sup.spawn(TaskDescriptor {
1498            name: "zero-max",
1499            restart: RestartPolicy::Restart {
1500                max: 0,
1501                base_delay: Duration::from_millis(10),
1502            },
1503            factory: move || {
1504                let c = Arc::clone(&counter2);
1505                async move {
1506                    c.fetch_add(1, Ordering::SeqCst);
1507                    panic!("always panic");
1508                }
1509            },
1510        });
1511
1512        tokio::time::sleep(Duration::from_millis(200)).await;
1513
1514        assert_eq!(
1515            counter.load(Ordering::SeqCst),
1516            1,
1517            "max=0 should not restart"
1518        );
1519
1520        let snaps = sup.snapshot();
1521        let snap = snaps.iter().find(|s| s.name.as_ref() == "zero-max");
1522        assert!(snap.is_some(), "entry should remain as Failed");
1523        assert!(
1524            matches!(snap.unwrap().status, TaskStatus::Failed { .. }),
1525            "status should be Failed"
1526        );
1527    }
1528
1529    /// Stress test: spawn 50 tasks concurrently, all must complete and registry must be accurate.
1530    #[tokio::test]
1531    async fn test_concurrent_spawns() {
1532        // All task names must be 'static — pre-defined before any let statements.
1533        static NAMES: [&str; 50] = [
1534            "t00", "t01", "t02", "t03", "t04", "t05", "t06", "t07", "t08", "t09", "t10", "t11",
1535            "t12", "t13", "t14", "t15", "t16", "t17", "t18", "t19", "t20", "t21", "t22", "t23",
1536            "t24", "t25", "t26", "t27", "t28", "t29", "t30", "t31", "t32", "t33", "t34", "t35",
1537            "t36", "t37", "t38", "t39", "t40", "t41", "t42", "t43", "t44", "t45", "t46", "t47",
1538            "t48", "t49",
1539        ];
1540        let (sup, cancel) = make_supervisor();
1541
1542        let completed = Arc::new(AtomicU32::new(0));
1543        for name in &NAMES {
1544            let c = Arc::clone(&completed);
1545            sup.spawn(TaskDescriptor {
1546                name,
1547                restart: RestartPolicy::RunOnce,
1548                factory: move || {
1549                    let c = Arc::clone(&c);
1550                    async move {
1551                        c.fetch_add(1, Ordering::SeqCst);
1552                    }
1553                },
1554            });
1555        }
1556
1557        // Wait for all tasks to complete.
1558        tokio::time::timeout(Duration::from_secs(5), async {
1559            loop {
1560                if completed.load(Ordering::SeqCst) == 50 {
1561                    break;
1562                }
1563                tokio::time::sleep(Duration::from_millis(10)).await;
1564            }
1565        })
1566        .await
1567        .expect("all 50 tasks should complete");
1568
1569        // Give reap driver time to process all completions.
1570        tokio::time::sleep(Duration::from_millis(100)).await;
1571        assert_eq!(sup.active_count(), 0, "all tasks must be reaped");
1572
1573        cancel.cancel();
1574    }
1575
1576    #[tokio::test]
1577    async fn test_shutdown_timeout_expiry() {
1578        let cancel = CancellationToken::new();
1579        let sup = TaskSupervisor::new(cancel.clone());
1580
1581        sup.spawn(TaskDescriptor {
1582            name: "stubborn",
1583            restart: RestartPolicy::RunOnce,
1584            factory: || async {
1585                tokio::time::sleep(Duration::from_mins(1)).await;
1586            },
1587        });
1588
1589        assert_eq!(sup.active_count(), 1);
1590
1591        tokio::time::timeout(
1592            Duration::from_secs(2),
1593            sup.shutdown_all(Duration::from_millis(50)),
1594        )
1595        .await
1596        .expect("shutdown_all should return even on timeout expiry");
1597
1598        assert!(
1599            cancel.is_cancelled(),
1600            "cancel token must be cancelled after shutdown"
1601        );
1602    }
1603
1604    #[tokio::test]
1605    async fn test_cancellation_token() {
1606        let cancel = CancellationToken::new();
1607        let sup = TaskSupervisor::new(cancel.clone());
1608
1609        assert!(!sup.cancellation_token().is_cancelled());
1610
1611        sup.shutdown_all(Duration::from_millis(100)).await;
1612
1613        assert!(
1614            sup.cancellation_token().is_cancelled(),
1615            "token must be cancelled after shutdown"
1616        );
1617    }
1618
1619    /// Regression test for #3161: after `shutdown_all`, all tasks must be reaped
1620    /// even when they complete *after* the cancel signal.
1621    ///
1622    /// The yield loop forces the reap driver to observe cancel and exit phase-1
1623    /// before the tasks send their completions — reliably reproducing the race.
1624    #[tokio::test]
1625    async fn test_shutdown_drains_post_cancel_completions() {
1626        let cancel = CancellationToken::new();
1627        let sup = TaskSupervisor::new(cancel.clone());
1628
1629        for name in [
1630            "loop-1", "loop-2", "loop-3", "loop-4", "loop-5", "loop-6", "loop-7",
1631        ] {
1632            let cancel_inner = cancel.clone();
1633            sup.spawn(TaskDescriptor {
1634                name,
1635                restart: RestartPolicy::RunOnce,
1636                factory: move || {
1637                    let c = cancel_inner.clone();
1638                    async move {
1639                        c.cancelled().await;
1640                        // Yield multiple times so the reap driver observes cancel first.
1641                        for _ in 0..64 {
1642                            tokio::task::yield_now().await;
1643                        }
1644                    }
1645                },
1646            });
1647        }
1648        assert_eq!(sup.active_count(), 7);
1649
1650        sup.shutdown_all(Duration::from_secs(2)).await;
1651
1652        assert_eq!(
1653            sup.active_count(),
1654            0,
1655            "all tasks must be reaped after shutdown (#3161)"
1656        );
1657    }
1658
1659    #[tokio::test]
1660    async fn test_blocking_spawner_task_appears_in_snapshot() {
1661        // Verify that tasks spawned via BlockingSpawner appear in supervisor.snapshot().
1662        use crate::BlockingSpawner;
1663
1664        let cancel = CancellationToken::new();
1665        let sup = TaskSupervisor::new(cancel);
1666
1667        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<()>();
1668        let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>();
1669
1670        let handle = sup.spawn_blocking_named(
1671            Arc::from("chunk_file"),
1672            Box::new(move || {
1673                // Signal that the task has started.
1674                let _ = ready_tx.send(());
1675                // Block until test signals release.
1676                let _ = release_rx.blocking_recv();
1677            }),
1678        );
1679
1680        // Wait until the blocking task has actually started.
1681        ready_rx.await.expect("task should start");
1682
1683        let snapshot = sup.snapshot();
1684        assert!(
1685            snapshot.iter().any(|t| t.name.as_ref() == "chunk_file"),
1686            "chunk_file task must appear in supervisor snapshot"
1687        );
1688
1689        // Release the blocking task and await completion.
1690        let _ = release_tx.send(());
1691        handle.await.expect("task should complete");
1692    }
1693
1694    /// Verify that `measure_blocking` emits wall-time and CPU-time histograms when
1695    /// the `task-metrics` feature is enabled.
1696    ///
1697    /// `measure_blocking` calls `metrics::histogram!` on the current thread.
1698    /// We test it directly using a `DebuggingRecorder` installed as the thread-local
1699    /// recorder via `metrics::with_local_recorder`.
1700    #[cfg(feature = "task-metrics")]
1701    #[test]
1702    fn test_measure_blocking_emits_metrics() {
1703        use metrics_util::debugging::DebuggingRecorder;
1704
1705        let recorder = DebuggingRecorder::new();
1706        let snapshotter = recorder.snapshotter();
1707
1708        // Call measure_blocking inside the local recorder scope so histogram! calls
1709        // are captured. The closure runs synchronously on this thread.
1710        metrics::with_local_recorder(&recorder, || {
1711            measure_blocking("test_task", || std::hint::black_box(42_u64));
1712        });
1713
1714        let snapshot = snapshotter.snapshot();
1715        let metric_names: Vec<String> = snapshot
1716            .into_vec()
1717            .into_iter()
1718            .map(|(k, _, _, _)| k.key().name().to_owned())
1719            .collect();
1720
1721        assert!(
1722            metric_names.iter().any(|n| n == "zeph.task.wall_time_ms"),
1723            "expected zeph.task.wall_time_ms histogram; got: {metric_names:?}"
1724        );
1725        assert!(
1726            metric_names.iter().any(|n| n == "zeph.task.cpu_time_ms"),
1727            "expected zeph.task.cpu_time_ms histogram; got: {metric_names:?}"
1728        );
1729    }
1730
1731    /// Verify that `spawn_blocking` semaphore limits concurrent OS-thread tasks to 8.
1732    ///
1733    /// Spawns 16 tasks. Each holds a barrier until 8 are waiting; then releases in order.
1734    /// If more than 8 run concurrently the test would either deadlock (waiting for 9+ to reach
1735    /// the barrier) or the counter would exceed 8 — both are caught.
1736    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1737    async fn test_spawn_blocking_semaphore_cap() {
1738        let (sup, _cancel) = make_supervisor();
1739        let concurrent = Arc::new(AtomicU32::new(0));
1740        let max_concurrent = Arc::new(AtomicU32::new(0));
1741        let barrier = Arc::new(std::sync::Barrier::new(1)); // just a sync point
1742
1743        let mut handles = Vec::new();
1744        for i in 0u32..16 {
1745            let c = Arc::clone(&concurrent);
1746            let m = Arc::clone(&max_concurrent);
1747            let name: Arc<str> = Arc::from(format!("blocking-{i}").as_str());
1748            let h = sup.spawn_blocking(name, move || {
1749                let prev = c.fetch_add(1, Ordering::SeqCst);
1750                // Update observed maximum.
1751                let mut cur_max = m.load(Ordering::SeqCst);
1752                while prev + 1 > cur_max {
1753                    match m.compare_exchange(cur_max, prev + 1, Ordering::SeqCst, Ordering::SeqCst)
1754                    {
1755                        Ok(_) => break,
1756                        Err(x) => cur_max = x,
1757                    }
1758                }
1759                // Simulate work.
1760                std::thread::sleep(std::time::Duration::from_millis(20));
1761                c.fetch_sub(1, Ordering::SeqCst);
1762            });
1763            handles.push(h);
1764        }
1765
1766        for h in handles {
1767            h.join().await.expect("blocking task should succeed");
1768        }
1769        drop(barrier);
1770
1771        let observed = max_concurrent.load(Ordering::SeqCst);
1772        assert!(
1773            observed <= 8,
1774            "observed {observed} concurrent blocking tasks; expected ≤ 8 (semaphore cap)"
1775        );
1776    }
1777}