1use crate::error::{SageError, SageResult};
36use std::collections::VecDeque;
37use std::future::Future;
38use std::pin::Pin;
39use std::time::{Duration, Instant};
40use tokio::task::JoinHandle;
41
/// Decides which siblings are restarted together with a child that needs a restart.
///
/// The default is [`Strategy::OneForOne`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Strategy {
    /// Restart only the child that exited.
    #[default]
    OneForOne,
    /// Abort every child and start all of them again.
    OneForAll,
    /// Restart the child that exited and every child registered after it.
    RestForOne,
}
53
/// Decides whether a child's own exit asks the supervisor for a restart.
///
/// The default is [`RestartPolicy::Permanent`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RestartPolicy {
    /// Restart after every exit, whether the child returned `Ok` or `Err`.
    #[default]
    Permanent,
    /// Restart only when the child exits with an `Err`.
    Transient,
    /// Never request a restart when this child exits.
    Temporary,
}
65
/// Restart-intensity limit: how many restarts are tolerated within a time window.
#[derive(Clone, Debug)]
pub struct RestartConfig {
    /// Maximum number of restarts allowed inside one `within` window.
    pub max_restarts: u32,
    /// Length of the sliding window over which restarts are counted.
    pub within: Duration,
}

impl Default for RestartConfig {
    /// Allows 5 restarts per 60 seconds.
    fn default() -> Self {
        const DEFAULT_MAX_RESTARTS: u32 = 5;
        const DEFAULT_WINDOW: Duration = Duration::from_secs(60);

        RestartConfig {
            max_restarts: DEFAULT_MAX_RESTARTS,
            within: DEFAULT_WINDOW,
        }
    }
}
83
84struct RestartTracker {
86 timestamps: VecDeque<Instant>,
87 config: RestartConfig,
88}
89
90impl RestartTracker {
91 fn new(config: RestartConfig) -> Self {
92 Self {
93 timestamps: VecDeque::new(),
94 config,
95 }
96 }
97
98 fn record_restart(&mut self) -> bool {
101 let now = Instant::now();
102
103 while let Some(&oldest) = self.timestamps.front() {
105 if now.duration_since(oldest) > self.config.within {
106 self.timestamps.pop_front();
107 } else {
108 break;
109 }
110 }
111
112 if self.timestamps.len() >= self.config.max_restarts as usize {
114 return false; }
116
117 self.timestamps.push_back(now);
118 true
119 }
120}
121
122pub type SpawnFn = Box<dyn Fn() -> Pin<Box<dyn Future<Output = SageResult<()>> + Send>> + Send>;
124
125struct ChildHandle {
127 name: String,
128 restart_policy: RestartPolicy,
129 spawn_fn: SpawnFn,
130 handle: Option<JoinHandle<SageResult<()>>>,
131}
132
133impl ChildHandle {
134 fn new(name: String, restart_policy: RestartPolicy, spawn_fn: SpawnFn) -> Self {
135 Self {
136 name,
137 restart_policy,
138 spawn_fn,
139 handle: None,
140 }
141 }
142
143 fn spawn(&mut self) {
145 let future = (self.spawn_fn)();
146 self.handle = Some(tokio::spawn(async move { future.await }));
147 }
148
149 fn is_running(&self) -> bool {
151 self.handle
152 .as_ref()
153 .map(|h| !h.is_finished())
154 .unwrap_or(false)
155 }
156
157 fn take_handle(&mut self) -> Option<JoinHandle<SageResult<()>>> {
159 self.handle.take()
160 }
161}
162
163pub struct Supervisor {
165 strategy: Strategy,
166 children: Vec<ChildHandle>,
167 restart_tracker: RestartTracker,
168}
169
170impl Supervisor {
171 pub fn new(strategy: Strategy, config: RestartConfig) -> Self {
173 Self {
174 strategy,
175 children: Vec::new(),
176 restart_tracker: RestartTracker::new(config),
177 }
178 }
179
180 pub fn add_child<F, Fut>(&mut self, name: impl Into<String>, restart_policy: RestartPolicy, spawn_fn: F)
184 where
185 F: Fn() -> Fut + Send + 'static,
186 Fut: Future<Output = SageResult<()>> + Send + 'static,
187 {
188 let spawn_fn: SpawnFn = Box::new(move || Box::pin(spawn_fn()));
189 self.children.push(ChildHandle::new(name.into(), restart_policy, spawn_fn));
190 }
191
192 pub async fn run(&mut self) -> SageResult<()> {
197 for child in &mut self.children {
199 child.spawn();
200 }
201
202 loop {
204 let (index, result) = self.wait_for_child_exit().await;
206
207 if index.is_none() {
209 break;
211 }
212
213 let index = index.unwrap();
214 let child_name = self.children[index].name.clone();
215 let restart_policy = self.children[index].restart_policy;
216
217 let should_restart = match (restart_policy, &result) {
219 (RestartPolicy::Permanent, _) => true,
220 (RestartPolicy::Transient, Err(_)) => true,
221 (RestartPolicy::Transient, Ok(_)) => false,
222 (RestartPolicy::Temporary, _) => false,
223 };
224
225 if should_restart {
226 if !self.restart_tracker.record_restart() {
228 return Err(SageError::Supervisor(format!(
229 "Maximum restart intensity reached for supervisor (child '{}' failed too many times)",
230 child_name
231 )));
232 }
233
234 match self.strategy {
236 Strategy::OneForOne => {
237 self.restart_child(index);
238 }
239 Strategy::OneForAll => {
240 self.restart_all();
241 }
242 Strategy::RestForOne => {
243 self.restart_rest(index);
244 }
245 }
246 }
247
248 if !self.any_running() {
250 break;
251 }
252 }
253
254 Ok(())
255 }
256
257 async fn wait_for_child_exit(&mut self) -> (Option<usize>, SageResult<()>) {
259 use futures::future::select_all;
260
261 let handles_with_indices: Vec<(usize, JoinHandle<SageResult<()>>)> = self
263 .children
264 .iter_mut()
265 .enumerate()
266 .filter_map(|(i, c)| c.take_handle().map(|h| (i, h)))
267 .collect();
268
269 if handles_with_indices.is_empty() {
270 return (None, Ok(()));
271 }
272
273 let indices: Vec<usize> = handles_with_indices.iter().map(|(i, _)| *i).collect();
275 let handles: Vec<JoinHandle<SageResult<()>>> =
276 handles_with_indices.into_iter().map(|(_, h)| h).collect();
277
278 let (join_result, completed_idx, remaining_handles) = select_all(handles).await;
280
281 let child_index = indices[completed_idx];
283
284 let final_result = join_result.unwrap_or_else(|e| Err(SageError::Agent(e.to_string())));
286
287 let mut remaining_iter = remaining_handles.into_iter();
290 for (pos, &original_idx) in indices.iter().enumerate() {
291 if pos != completed_idx {
292 if let (Some(handle), Some(child)) =
293 (remaining_iter.next(), self.children.get_mut(original_idx))
294 {
295 child.handle = Some(handle);
296 }
297 }
298 }
299
300 (Some(child_index), final_result)
301 }
302
303 fn restart_child(&mut self, index: usize) {
305 if let Some(child) = self.children.get_mut(index) {
306 child.spawn();
307 }
308 }
309
310 fn restart_all(&mut self) {
312 for child in &mut self.children {
314 if let Some(handle) = child.take_handle() {
315 handle.abort();
316 }
317 }
318
319 for child in &mut self.children {
321 child.spawn();
322 }
323 }
324
325 fn restart_rest(&mut self, from_index: usize) {
327 for child in self.children.iter_mut().skip(from_index) {
329 if let Some(handle) = child.take_handle() {
330 handle.abort();
331 }
332 }
333
334 for child in self.children.iter_mut().skip(from_index) {
336 child.spawn();
337 }
338 }
339
340 fn any_running(&self) -> bool {
342 self.children.iter().any(|c| c.is_running())
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// A `Transient` worker that fails twice is restarted until its third run succeeds.
    #[tokio::test]
    async fn test_one_for_one_restart() {
        let runs = Arc::new(AtomicU32::new(0));
        let mut supervisor = Supervisor::new(Strategy::OneForOne, RestartConfig::default());

        let worker_runs = Arc::clone(&runs);
        supervisor.add_child("Worker", RestartPolicy::Transient, move || {
            let runs = Arc::clone(&worker_runs);
            async move {
                // `fetch_add` returns the previous count, so runs one and two fail.
                match runs.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err(SageError::Agent("Simulated failure".to_string())),
                    _ => Ok(()),
                }
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok(), "supervisor failed: {:?}", result);
        assert_eq!(runs.load(Ordering::SeqCst), 3);
    }

    /// A `Transient` worker that exits with `Ok` is not started again.
    #[tokio::test]
    async fn test_transient_no_restart_on_success() {
        let runs = Arc::new(AtomicU32::new(0));
        let mut supervisor = Supervisor::new(Strategy::OneForOne, RestartConfig::default());

        let worker_runs = Arc::clone(&runs);
        supervisor.add_child("Worker", RestartPolicy::Transient, move || {
            let runs = Arc::clone(&worker_runs);
            async move {
                runs.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok());
        assert_eq!(runs.load(Ordering::SeqCst), 1);
    }

    /// A `Temporary` worker is not restarted even when it fails.
    #[tokio::test]
    async fn test_temporary_never_restarts() {
        let runs = Arc::new(AtomicU32::new(0));
        let mut supervisor = Supervisor::new(Strategy::OneForOne, RestartConfig::default());

        let worker_runs = Arc::clone(&runs);
        supervisor.add_child("Worker", RestartPolicy::Temporary, move || {
            let runs = Arc::clone(&worker_runs);
            async move {
                runs.fetch_add(1, Ordering::SeqCst);
                Err(SageError::Agent("Simulated failure".to_string()))
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok());
        assert_eq!(runs.load(Ordering::SeqCst), 1);
    }

    /// A worker that always fails exhausts the restart budget and `run` returns an error.
    #[tokio::test]
    async fn test_circuit_breaker() {
        let runs = Arc::new(AtomicU32::new(0));
        let limit = RestartConfig {
            max_restarts: 3,
            within: Duration::from_secs(60),
        };
        let mut supervisor = Supervisor::new(Strategy::OneForOne, limit);

        let worker_runs = Arc::clone(&runs);
        supervisor.add_child("Worker", RestartPolicy::Permanent, move || {
            let runs = Arc::clone(&worker_runs);
            async move {
                runs.fetch_add(1, Ordering::SeqCst);
                Err(SageError::Agent("Always fails".to_string()))
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_err());
        // One initial run plus at most `max_restarts` restarts.
        assert!(runs.load(Ordering::SeqCst) <= 4);
    }

    /// A `Permanent` worker is restarted after a clean exit too, so the budget still trips.
    #[tokio::test]
    async fn test_permanent_restarts_on_success() {
        let runs = Arc::new(AtomicU32::new(0));
        let limit = RestartConfig {
            max_restarts: 3,
            within: Duration::from_secs(60),
        };
        let mut supervisor = Supervisor::new(Strategy::OneForOne, limit);

        let worker_runs = Arc::clone(&runs);
        supervisor.add_child("Worker", RestartPolicy::Permanent, move || {
            let runs = Arc::clone(&worker_runs);
            async move {
                runs.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_err());
        assert!(runs.load(Ordering::SeqCst) <= 4);
    }

    /// With `RestForOne`, a failing middle child restarts itself and the child
    /// after it, but not the child before it.
    #[tokio::test]
    async fn test_rest_for_one_restarts_downstream() {
        let first_runs = Arc::new(AtomicU32::new(0));
        let second_runs = Arc::new(AtomicU32::new(0));
        let third_runs = Arc::new(AtomicU32::new(0));

        let mut supervisor = Supervisor::new(Strategy::RestForOne, RestartConfig::default());

        let shared = Arc::clone(&first_runs);
        supervisor.add_child("Child1", RestartPolicy::Temporary, move || {
            let runs = Arc::clone(&shared);
            async move {
                runs.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(50)).await;
                Ok(())
            }
        });

        let shared = Arc::clone(&second_runs);
        supervisor.add_child("Child2", RestartPolicy::Transient, move || {
            let runs = Arc::clone(&shared);
            async move {
                match runs.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err(SageError::Agent("Simulated failure".to_string())),
                    _ => Ok(()),
                }
            }
        });

        let shared = Arc::clone(&third_runs);
        supervisor.add_child("Child3", RestartPolicy::Temporary, move || {
            let runs = Arc::clone(&shared);
            async move {
                runs.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(50)).await;
                Ok(())
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok(), "supervisor failed: {:?}", result);

        assert_eq!(
            first_runs.load(Ordering::SeqCst),
            1,
            "Child1 should run only once"
        );
        assert_eq!(
            second_runs.load(Ordering::SeqCst),
            3,
            "Child2 should run 3 times"
        );
        assert!(
            third_runs.load(Ordering::SeqCst) >= 2,
            "Child3 should be restarted at least once with RestForOne, got {}",
            third_runs.load(Ordering::SeqCst)
        );
    }

    /// With `OneForAll`, a failing child also restarts its sibling.
    #[tokio::test]
    async fn test_one_for_all_restarts_all() {
        let first_runs = Arc::new(AtomicU32::new(0));
        let second_runs = Arc::new(AtomicU32::new(0));

        let mut supervisor = Supervisor::new(Strategy::OneForAll, RestartConfig::default());

        let shared = Arc::clone(&first_runs);
        supervisor.add_child("Child1", RestartPolicy::Temporary, move || {
            let runs = Arc::clone(&shared);
            async move {
                runs.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(())
            }
        });

        let shared = Arc::clone(&second_runs);
        supervisor.add_child("Child2", RestartPolicy::Transient, move || {
            let runs = Arc::clone(&shared);
            async move {
                match runs.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err(SageError::Agent("Simulated failure".to_string())),
                    _ => {
                        tokio::time::sleep(Duration::from_millis(10)).await;
                        Ok(())
                    }
                }
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok(), "supervisor failed: {:?}", result);

        assert_eq!(
            second_runs.load(Ordering::SeqCst),
            3,
            "Child2 should run 3 times"
        );
        assert!(
            first_runs.load(Ordering::SeqCst) >= 2,
            "Child1 should be restarted at least once with OneForAll, got {}",
            first_runs.load(Ordering::SeqCst)
        );
    }
}