// sage_runtime/supervisor.rs
1//! Supervision tree implementation for Sage v2.
2//!
3//! This module provides Erlang/OTP-style supervision trees for managing
4//! agent lifecycles with automatic restart capabilities.
5//!
6//! # Supervision Strategies
7//!
8//! - **OneForOne**: Restart only the failed child
9//! - **OneForAll**: Restart all children if one fails
10//! - **RestForOne**: Restart the failed child and all children started after it
11//!
12//! # Restart Policies
13//!
14//! - **Permanent**: Always restart, regardless of exit reason
15//! - **Transient**: Restart only on abnormal termination (error)
16//! - **Temporary**: Never restart
17
18use crate::error::{SageError, SageResult};
19use std::time::Duration;
20
/// Strategy governing which children a supervisor restarts after a failure
/// (modeled on Erlang/OTP supervisor strategies).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Strategy {
    /// Only the child that failed is restarted.
    #[default]
    OneForOne,
    /// Every child is restarted when any one of them fails.
    OneForAll,
    /// The failed child, plus every child added after it, is restarted.
    RestForOne,
}
32
/// Policy deciding whether a supervised child is respawned after it exits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RestartPolicy {
    /// Restart unconditionally, whatever the exit reason.
    #[default]
    Permanent,
    /// Restart only after an abnormal (error) termination.
    Transient,
    /// Never restart once the child has exited.
    Temporary,
}
44
/// Configuration for restart intensity limiting (circuit breaker).
///
/// A supervisor refuses further restarts once `max_restarts` restarts have
/// already happened inside the trailing `within` window.
#[derive(Debug, Clone)]
pub struct RestartConfig {
    /// Maximum number of restarts allowed within the time window.
    pub max_restarts: u32,
    /// Length of the sliding window over which restarts are counted.
    pub within: Duration,
}

impl Default for RestartConfig {
    /// Defaults to at most 5 restarts within any 60-second window.
    fn default() -> Self {
        RestartConfig {
            max_restarts: 5,
            within: Duration::from_secs(60),
        }
    }
}
62
63// ===========================================================================
64// Native implementation — uses tokio::spawn and JoinHandle
65// ===========================================================================
66
67#[cfg(not(target_arch = "wasm32"))]
68mod native {
69    use super::*;
70    use std::collections::VecDeque;
71    use std::future::Future;
72    use std::pin::Pin;
73    use std::time::Instant;
74    use tokio::task::JoinHandle;
75
76    /// Tracks restart history for circuit breaker functionality.
77    struct RestartTracker {
78        timestamps: VecDeque<Instant>,
79        config: RestartConfig,
80    }
81
82    impl RestartTracker {
83        fn new(config: RestartConfig) -> Self {
84            Self {
85                timestamps: VecDeque::new(),
86                config,
87            }
88        }
89
90        /// Record a restart and check if we've exceeded the limit.
91        /// Returns true if we should allow the restart, false if circuit breaker trips.
92        fn record_restart(&mut self) -> bool {
93            let now = Instant::now();
94
95            // Remove old timestamps outside the window
96            while let Some(&oldest) = self.timestamps.front() {
97                if now.duration_since(oldest) > self.config.within {
98                    self.timestamps.pop_front();
99                } else {
100                    break;
101                }
102            }
103
104            // Check if we're at the limit
105            if self.timestamps.len() >= self.config.max_restarts as usize {
106                return false; // Circuit breaker trips
107            }
108
109            self.timestamps.push_back(now);
110            true
111        }
112    }
113
114    /// A spawn function that creates an agent and returns its join handle.
115    pub type SpawnFn = Box<dyn Fn() -> Pin<Box<dyn Future<Output = SageResult<()>> + Send>> + Send>;
116
117    /// Handle to a supervised child.
118    struct ChildHandle {
119        name: String,
120        restart_policy: RestartPolicy,
121        spawn_fn: SpawnFn,
122        handle: Option<JoinHandle<SageResult<()>>>,
123    }
124
125    impl ChildHandle {
126        fn new(name: String, restart_policy: RestartPolicy, spawn_fn: SpawnFn) -> Self {
127            Self {
128                name,
129                restart_policy,
130                spawn_fn,
131                handle: None,
132            }
133        }
134
135        /// Spawn (or respawn) this child.
136        fn spawn(&mut self) {
137            let future = (self.spawn_fn)();
138            self.handle = Some(tokio::spawn(future));
139        }
140
141        /// Check if the child is running.
142        fn is_running(&self) -> bool {
143            self.handle
144                .as_ref()
145                .map(|h| !h.is_finished())
146                .unwrap_or(false)
147        }
148
149        /// Take the join handle (for awaiting).
150        fn take_handle(&mut self) -> Option<JoinHandle<SageResult<()>>> {
151            self.handle.take()
152        }
153    }
154
155    /// A supervisor that manages child agents with restart strategies.
156    pub struct Supervisor {
157        strategy: Strategy,
158        children: Vec<ChildHandle>,
159        restart_tracker: RestartTracker,
160    }
161
162    impl Supervisor {
163        /// Create a new supervisor with the given strategy and restart configuration.
164        pub fn new(strategy: Strategy, config: RestartConfig) -> Self {
165            Self {
166                strategy,
167                children: Vec::new(),
168                restart_tracker: RestartTracker::new(config),
169            }
170        }
171
172        /// Add a child to the supervisor.
173        ///
174        /// The spawn function should create the agent and return its future.
175        pub fn add_child<F, Fut>(
176            &mut self,
177            name: impl Into<String>,
178            restart_policy: RestartPolicy,
179            spawn_fn: F,
180        ) where
181            F: Fn() -> Fut + Send + 'static,
182            Fut: Future<Output = SageResult<()>> + Send + 'static,
183        {
184            let spawn_fn: SpawnFn = Box::new(move || Box::pin(spawn_fn()));
185            self.children
186                .push(ChildHandle::new(name.into(), restart_policy, spawn_fn));
187        }
188
189        /// Start all children and begin supervision.
190        ///
191        /// This method runs until all children have terminated (according to their
192        /// restart policies) or the circuit breaker trips.
193        pub async fn run(&mut self) -> SageResult<()> {
194            // Start all children
195            for child in &mut self.children {
196                child.spawn();
197            }
198
199            // Monitor loop
200            loop {
201                // Wait for any child to complete
202                let (index, result) = self.wait_for_child_exit().await;
203
204                // Check if all children are done
205                if index.is_none() {
206                    // All children have finished
207                    break;
208                }
209
210                let index = index.unwrap();
211                let child_name = self.children[index].name.clone();
212                let restart_policy = self.children[index].restart_policy;
213
214                // Determine if we should restart
215                let should_restart = match (restart_policy, &result) {
216                    (RestartPolicy::Permanent, _) => true,
217                    (RestartPolicy::Transient, Err(_)) => true,
218                    (RestartPolicy::Transient, Ok(_)) => false,
219                    (RestartPolicy::Temporary, _) => false,
220                };
221
222                if should_restart {
223                    // Check circuit breaker
224                    if !self.restart_tracker.record_restart() {
225                        return Err(SageError::Supervisor(format!(
226                            "Maximum restart intensity reached for supervisor (child '{}' failed too many times)",
227                            child_name
228                        )));
229                    }
230
231                    // Apply restart strategy
232                    match self.strategy {
233                        Strategy::OneForOne => {
234                            self.restart_child(index);
235                        }
236                        Strategy::OneForAll => {
237                            self.restart_all();
238                        }
239                        Strategy::RestForOne => {
240                            self.restart_rest(index);
241                        }
242                    }
243                }
244
245                // Check if any children are still running
246                if !self.any_running() {
247                    break;
248                }
249            }
250
251            Ok(())
252        }
253
254        /// Wait for any child to exit, returning the index and result.
255        async fn wait_for_child_exit(&mut self) -> (Option<usize>, SageResult<()>) {
256            use futures::future::select_all;
257
258            // Collect all running children's handles with their indices
259            let handles_with_indices: Vec<(usize, JoinHandle<SageResult<()>>)> = self
260                .children
261                .iter_mut()
262                .enumerate()
263                .filter_map(|(i, c)| c.take_handle().map(|h| (i, h)))
264                .collect();
265
266            if handles_with_indices.is_empty() {
267                return (None, Ok(()));
268            }
269
270            // We need to track indices separately since select_all works on the handles
271            let indices: Vec<usize> = handles_with_indices.iter().map(|(i, _)| *i).collect();
272            let handles: Vec<JoinHandle<SageResult<()>>> =
273                handles_with_indices.into_iter().map(|(_, h)| h).collect();
274
275            // Wait for any handle to complete
276            let (join_result, completed_idx, remaining_handles) = select_all(handles).await;
277
278            // Get the original child index
279            let child_index = indices[completed_idx];
280
281            // Convert JoinError to SageError
282            let final_result =
283                join_result.unwrap_or_else(|e| Err(SageError::Agent(e.to_string())));
284
285            // Put back the remaining handles to their respective children
286            let mut remaining_iter = remaining_handles.into_iter();
287            for (pos, &original_idx) in indices.iter().enumerate() {
288                if pos != completed_idx {
289                    if let (Some(handle), Some(child)) =
290                        (remaining_iter.next(), self.children.get_mut(original_idx))
291                    {
292                        child.handle = Some(handle);
293                    }
294                }
295            }
296
297            (Some(child_index), final_result)
298        }
299
300        /// Restart a single child.
301        fn restart_child(&mut self, index: usize) {
302            if let Some(child) = self.children.get_mut(index) {
303                child.spawn();
304            }
305        }
306
307        /// Restart all children (stop all first, then start all).
308        fn restart_all(&mut self) {
309            // Abort all running children
310            for child in &mut self.children {
311                if let Some(handle) = child.take_handle() {
312                    handle.abort();
313                }
314            }
315
316            // Start all children
317            for child in &mut self.children {
318                child.spawn();
319            }
320        }
321
322        /// Restart the failed child and all children started after it.
323        fn restart_rest(&mut self, from_index: usize) {
324            // Abort children from index onwards
325            for child in self.children.iter_mut().skip(from_index) {
326                if let Some(handle) = child.take_handle() {
327                    handle.abort();
328                }
329            }
330
331            // Restart children from index onwards
332            for child in self.children.iter_mut().skip(from_index) {
333                child.spawn();
334            }
335        }
336
337        /// Check if any children are still running.
338        fn any_running(&self) -> bool {
339            self.children.iter().any(|c| c.is_running())
340        }
341    }
342}
343
344#[cfg(not(target_arch = "wasm32"))]
345pub use native::{SpawnFn, Supervisor};
346
347// ===========================================================================
348// WASM stub — supervision not yet supported in the browser
349// ===========================================================================
350
#[cfg(target_arch = "wasm32")]
mod wasm_stub {
    use super::*;
    use std::future::Future;

    /// A supervisor that manages child agents with restart strategies.
    ///
    /// **WASM note:** Supervision is not yet supported in the browser target.
    /// This stub accepts children but `run()` always returns an error.
    pub struct Supervisor {
        _strategy: Strategy,
    }

    impl Supervisor {
        /// Create a new supervisor.
        pub fn new(strategy: Strategy, _config: RestartConfig) -> Self {
            Supervisor {
                _strategy: strategy,
            }
        }

        /// Add a child to the supervisor (no-op on WASM).
        pub fn add_child<F, Fut>(
            &mut self,
            _name: impl Into<String>,
            _restart_policy: RestartPolicy,
            _spawn_fn: F,
        ) where
            F: Fn() -> Fut + 'static,
            Fut: Future<Output = SageResult<()>> + 'static,
        {
            // Intentionally a no-op — children are not tracked on WASM.
        }

        /// Run the supervisor.
        ///
        /// Returns an error on WASM as supervision is not yet supported.
        pub async fn run(&mut self) -> SageResult<()> {
            Err(SageError::Supervisor(
                "Supervision trees are not yet supported in the WASM target".to_string(),
            ))
        }
    }
}
395
396#[cfg(target_arch = "wasm32")]
397pub use wasm_stub::Supervisor;
398
#[cfg(test)]
mod tests {
    // Integration-style tests that drive a real Supervisor on the tokio
    // runtime. Counters record how many times each child body ran, which is
    // how restart behavior is observed.
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    // Transient child fails twice then succeeds: the supervisor should
    // restart it after each failure and stop cleanly after the success.
    #[tokio::test]
    async fn test_one_for_one_restart() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let mut supervisor = Supervisor::new(Strategy::OneForOne, RestartConfig::default());

        // Use Transient policy - restart on error, stop on success
        supervisor.add_child("Worker", RestartPolicy::Transient, move || {
            let counter = counter_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(SageError::Agent("Simulated failure".to_string()))
                } else {
                    Ok(())
                }
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok(), "supervisor failed: {:?}", result);
        // Two failing runs + one successful run.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    // A Transient child that exits Ok must not be restarted.
    #[tokio::test]
    async fn test_transient_no_restart_on_success() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let mut supervisor = Supervisor::new(Strategy::OneForOne, RestartConfig::default());

        supervisor.add_child("Worker", RestartPolicy::Transient, move || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 1); // Only ran once
    }

    // A Temporary child is never restarted, even when it fails.
    #[tokio::test]
    async fn test_temporary_never_restarts() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let mut supervisor = Supervisor::new(Strategy::OneForOne, RestartConfig::default());

        supervisor.add_child("Worker", RestartPolicy::Temporary, move || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err(SageError::Agent("Simulated failure".to_string()))
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok()); // Supervisor should succeed even if child fails
        assert_eq!(counter.load(Ordering::SeqCst), 1); // Only ran once
    }

    // An always-failing Permanent child must trip the circuit breaker after
    // max_restarts restarts rather than restarting forever.
    #[tokio::test]
    async fn test_circuit_breaker() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let config = RestartConfig {
            max_restarts: 3,
            within: Duration::from_secs(60),
        };

        let mut supervisor = Supervisor::new(Strategy::OneForOne, config);

        supervisor.add_child("Worker", RestartPolicy::Permanent, move || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err(SageError::Agent("Always fails".to_string()))
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_err()); // Circuit breaker should trip
        assert!(counter.load(Ordering::SeqCst) <= 4); // At most 4 attempts (1 + 3 restarts)
    }

    #[tokio::test]
    async fn test_permanent_restarts_on_success() {
        // Permanent policy restarts even when child exits normally.
        // This test verifies the circuit breaker eventually stops it.
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let config = RestartConfig {
            max_restarts: 3,
            within: Duration::from_secs(60),
        };

        let mut supervisor = Supervisor::new(Strategy::OneForOne, config);

        supervisor.add_child("Worker", RestartPolicy::Permanent, move || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(()) // Exits successfully each time
            }
        });

        let result = supervisor.run().await;
        // Circuit breaker trips because Permanent keeps restarting even on success
        assert!(result.is_err());
        assert!(counter.load(Ordering::SeqCst) <= 4);
    }

    #[tokio::test]
    async fn test_rest_for_one_restarts_downstream() {
        // RestForOne: when child fails, it and all children added after it restart.
        let counter1 = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::new(AtomicU32::new(0));
        let counter3 = Arc::new(AtomicU32::new(0));
        let counter1_clone = counter1.clone();
        let counter2_clone = counter2.clone();
        let counter3_clone = counter3.clone();

        let mut supervisor = Supervisor::new(Strategy::RestForOne, RestartConfig::default());

        // Child 1: Always succeeds
        supervisor.add_child("Child1", RestartPolicy::Temporary, move || {
            let counter = counter1_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                // Wait a bit so it doesn't exit before child 2 fails
                tokio::time::sleep(Duration::from_millis(50)).await;
                Ok(())
            }
        });

        // Child 2: Fails twice then succeeds (this triggers RestForOne)
        supervisor.add_child("Child2", RestartPolicy::Transient, move || {
            let counter = counter2_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(SageError::Agent("Simulated failure".to_string()))
                } else {
                    Ok(())
                }
            }
        });

        // Child 3: Succeeds but should be restarted when Child2 fails
        supervisor.add_child("Child3", RestartPolicy::Temporary, move || {
            let counter = counter3_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                // Wait a bit so it doesn't exit before child 2 fails
                tokio::time::sleep(Duration::from_millis(50)).await;
                Ok(())
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok(), "supervisor failed: {:?}", result);

        // Child1 should only run once (it's before the failing child)
        assert_eq!(
            counter1.load(Ordering::SeqCst),
            1,
            "Child1 should run only once"
        );

        // Child2 runs 3 times (2 failures + 1 success)
        assert_eq!(
            counter2.load(Ordering::SeqCst),
            3,
            "Child2 should run 3 times"
        );

        // Child3 should be restarted when Child2 fails (2 restarts + initial)
        // NOTE(review): timing-dependent — exact count depends on whether
        // Child3 is still running at each failure, hence the >= assertion.
        assert!(
            counter3.load(Ordering::SeqCst) >= 2,
            "Child3 should be restarted at least once with RestForOne, got {}",
            counter3.load(Ordering::SeqCst)
        );
    }

    #[tokio::test]
    async fn test_one_for_all_restarts_all() {
        // OneForAll: when any child fails, all children restart.
        let counter1 = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::new(AtomicU32::new(0));
        let counter1_clone = counter1.clone();
        let counter2_clone = counter2.clone();

        let mut supervisor = Supervisor::new(Strategy::OneForAll, RestartConfig::default());

        // Child 1: Always succeeds but runs longer
        supervisor.add_child("Child1", RestartPolicy::Temporary, move || {
            let counter = counter1_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(())
            }
        });

        // Child 2: Fails twice then succeeds (this triggers OneForAll)
        supervisor.add_child("Child2", RestartPolicy::Transient, move || {
            let counter = counter2_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(SageError::Agent("Simulated failure".to_string()))
                } else {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    Ok(())
                }
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok(), "supervisor failed: {:?}", result);

        // Child2 runs 3 times (2 failures + 1 success)
        assert_eq!(
            counter2.load(Ordering::SeqCst),
            3,
            "Child2 should run 3 times"
        );

        // Child1 should be restarted when Child2 fails (OneForAll restarts all)
        assert!(
            counter1.load(Ordering::SeqCst) >= 2,
            "Child1 should be restarted at least once with OneForAll, got {}",
            counter1.load(Ordering::SeqCst)
        );
    }
}