1use std::collections::HashMap;
18use std::sync::{Arc, Mutex, mpsc};
19use std::thread;
20use std::time::{Duration, Instant};
21use std::sync::atomic::{AtomicBool, Ordering};
22
/// Policy deciding which children are restarted when one of them terminates.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RestartStrategy {
    /// Restart only the child that terminated.
    OneForOne,
    /// Restart every registered child.
    OneForAll,
    /// Restart the terminated child together with the children registered as
    /// depending on it.
    RestForOne,
}
33
/// Lifecycle state the supervisor records for a child.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProcessState {
    /// The child's thread has been spawned and has not been seen to finish.
    Running,
    /// The monitor observed that the child's thread finished.
    Failed,
    /// The child is in the middle of being (re)spawned.
    Restarting,
    /// The child was stopped, or will not be restarted again.
    Stopped,
    /// The child is registered but has never been started.
    Unstarted,
}
48
/// Restart class of a child.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChildType {
    /// Eligible for restart whenever it terminates.
    Permanent,
    /// Never restarted.
    Temporary,
    /// Transient child; how its termination is treated is decided by the
    /// supervisor's monitor.
    Transient,
}
59
/// How a child is treated when it is asked to stop.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShutdownStrategy {
    /// Do not wait for the child's thread at all.
    BrutalKill,
    /// Wait up to the given duration for the child's thread to finish.
    Shutdown(Duration),
}
68
69#[derive(Debug, Clone)]
71pub struct SupervisorConfig {
72 pub max_restarts: usize,
74 pub max_time: Duration,
76 pub restart_strategy: RestartStrategy,
78 pub shutdown_strategy: ShutdownStrategy,
80}
81
82impl Default for SupervisorConfig {
83 fn default() -> Self {
85 SupervisorConfig {
86 max_restarts: 3,
87 max_time: Duration::from_secs(5),
88 restart_strategy: RestartStrategy::OneForOne,
89 shutdown_strategy: ShutdownStrategy::Shutdown(Duration::from_secs(5)),
90 }
91 }
92}
93
/// Observer hooks invoked by the supervisor on child lifecycle events.
///
/// Every hook has an empty default body, so implementors override only the
/// events they care about.
pub trait EventCallback: Send + Sync {
    /// Invoked after a child has been started for the first time.
    fn on_process_started(&self, _process_name: &str) {}

    /// Invoked when the monitor notices that a child's thread has finished.
    fn on_process_failed(&self, _process_name: &str) {}

    /// Invoked after a child has been respawned; `_restart_count` is the
    /// child's restart counter after the respawn.
    fn on_process_restarted(&self, _process_name: &str, _restart_count: usize) {}

    /// Invoked after a child has been stopped.
    fn on_process_stopped(&self, _process_name: &str) {}
}
105
106pub struct NoOpCallback;
108
109impl EventCallback for NoOpCallback {}
110
111struct ChildSpec {
113 child_type: ChildType,
115 factory: Box<dyn Fn() -> thread::JoinHandle<()> + Send + 'static>,
117 shutdown_strategy: ShutdownStrategy,
119 shutdown_signal: Arc<AtomicBool>,
121}
122
123struct ProcessInfo {
125 handle: Option<thread::JoinHandle<()>>,
127 restart_times: Vec<Instant>,
129 state: ProcessState,
131 restart_count: usize,
133 spec: ChildSpec,
135}
136
137pub struct Supervisor {
139 processes: Arc<Mutex<HashMap<String, ProcessInfo>>>,
141 config: SupervisorConfig,
143 dependencies: Arc<Mutex<HashMap<String, Vec<String>>>>,
145 event_callback: Arc<dyn EventCallback>,
147 monitor_handle: Arc<Mutex<Option<thread::JoinHandle<()>>>>,
149 shutdown_flag: Arc<AtomicBool>,
151 signal_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
153}
154
155impl Supervisor {
156 pub fn new(config: SupervisorConfig) -> Self {
169 Supervisor::with_callback(config, Arc::new(NoOpCallback))
170 }
171
172 pub fn with_callback(config: SupervisorConfig, callback: Arc<dyn EventCallback>) -> Self {
174 Supervisor {
175 processes: Arc::new(Mutex::new(HashMap::new())),
176 config,
177 dependencies: Arc::new(Mutex::new(HashMap::new())),
178 event_callback: callback,
179 monitor_handle: Arc::new(Mutex::new(None)),
180 shutdown_flag: Arc::new(AtomicBool::new(false)),
181 signal_tx: Arc::new(Mutex::new(None)),
182 }
183 }
184
185 pub fn add_process<F>(&mut self, name: &str, child_type: ChildType, factory: F)
207 where
208 F: Fn() -> thread::JoinHandle<()> + Send + 'static,
209 {
210 self.add_process_with_shutdown(
211 name,
212 child_type,
213 factory,
214 self.config.shutdown_strategy.clone(),
215 );
216 }
217
218 pub fn add_process_with_shutdown<F>(
220 &mut self,
221 name: &str,
222 child_type: ChildType,
223 factory: F,
224 shutdown_strategy: ShutdownStrategy,
225 )
226 where
227 F: Fn() -> thread::JoinHandle<()> + Send + 'static,
228 {
229 let factory_box = Box::new(factory);
230 let shutdown_signal = Arc::new(AtomicBool::new(false));
231
232 let spec = ChildSpec {
233 child_type,
234 factory: factory_box,
235 shutdown_strategy,
236 shutdown_signal,
237 };
238
239 let mut processes = self.processes.lock().unwrap();
240 processes.insert(
241 name.to_string(),
242 ProcessInfo {
243 handle: None,
244 restart_times: Vec::new(),
245 state: ProcessState::Unstarted,
246 restart_count: 0,
247 spec,
248 },
249 );
250 }
251
252 pub fn add_dependency(&self, process: &str, depends_on: &str) {
259 let mut dependencies = self.dependencies.lock().unwrap();
260 dependencies
261 .entry(process.to_string())
262 .or_insert_with(Vec::new)
263 .push(depends_on.to_string());
264 }
265
266 pub fn start_monitoring(self) -> Arc<Self>
273 where
274 Self: Sized,
275 {
276 let supervisor = Arc::new(self);
277
278 let should_start = {
280 let handle = supervisor.monitor_handle.lock().unwrap();
281 handle.is_none()
282 };
283
284 if !should_start {
285 return supervisor;
286 }
287
288 {
290 let (tx, _rx) = mpsc::channel();
291 *supervisor.signal_tx.lock().unwrap() = Some(tx);
292
293 let supervisor_clone = Arc::clone(&supervisor);
294 let monitor_thread = thread::spawn(move || {
295 supervisor_clone.monitor_loop();
296 });
297
298 let mut handle = supervisor.monitor_handle.lock().unwrap();
299 *handle = Some(monitor_thread);
300 }
301
302 {
304 let mut processes = supervisor.processes.lock().unwrap();
305 for (name, info) in processes.iter_mut() {
306 info.state = ProcessState::Restarting;
307 info.handle = Some((info.spec.factory)());
308 info.state = ProcessState::Running;
309 info.restart_times.push(Instant::now());
310 supervisor.event_callback.on_process_started(name);
311 }
312 }
313
314 supervisor
315 }
316
317 fn monitor_loop(&self) {
319 loop {
320 if self.shutdown_flag.load(Ordering::Relaxed) {
321 break;
322 }
323
324 thread::sleep(Duration::from_millis(100));
325
326 let mut failed_processes = Vec::new();
328 {
329 let mut processes = self.processes.lock().unwrap();
330 for (name, info) in processes.iter_mut() {
331 if info.state == ProcessState::Unstarted {
332 continue;
333 }
334
335 if let Some(handle) = &info.handle {
336 if handle.is_finished() {
337 info.state = ProcessState::Failed;
338 info.handle = None;
339 self.event_callback.on_process_failed(name);
340
341 let should_check_restart = match info.spec.child_type {
343 ChildType::Permanent => true,
344 ChildType::Temporary => false,
345 ChildType::Transient => {
346 true
349 }
350 };
351
352 if should_check_restart {
353 let now = Instant::now();
354 info.restart_times
355 .retain(|time| now.duration_since(*time) < self.config.max_time);
356
357 if info.restart_times.len() < self.config.max_restarts {
358 failed_processes.push(name.clone());
359 } else {
360 info.state = ProcessState::Stopped;
361 }
362 } else {
363 info.state = ProcessState::Stopped;
364 }
365 }
366 }
367 }
368 }
369
370 for failed_process in failed_processes {
372 let processes_to_restart = {
373 let processes = self.processes.lock().unwrap();
374 let dependencies = self.dependencies.lock().unwrap();
375
376 match self.config.restart_strategy {
377 RestartStrategy::OneForOne => vec![failed_process.clone()],
378 RestartStrategy::OneForAll => processes.keys().cloned().collect(),
379 RestartStrategy::RestForOne => {
380 let mut to_restart = vec![failed_process.clone()];
381 for (proc_name, deps) in dependencies.iter() {
382 if deps.contains(&failed_process) {
383 to_restart.push(proc_name.clone());
384 }
385 }
386 to_restart
387 }
388 }
389 };
390
391 let now = Instant::now();
392 for proc_name in processes_to_restart {
393 let mut processes = self.processes.lock().unwrap();
394 if let Some(proc_info) = processes.get_mut(&proc_name) {
395 if matches!(proc_info.spec.child_type, ChildType::Temporary)
397 || proc_info.state == ProcessState::Stopped
398 {
399 continue;
400 }
401
402 proc_info.state = ProcessState::Restarting;
403 proc_info.restart_count += 1;
404 proc_info.handle = Some((proc_info.spec.factory)());
405 proc_info.restart_times.push(now);
406 proc_info.state = ProcessState::Running;
407
408 self.event_callback
409 .on_process_restarted(&proc_name, proc_info.restart_count);
410 }
411 }
412 }
413 }
414 }
415
416 pub fn stop_process(&self, name: &str) -> bool {
426 let mut processes = self.processes.lock().unwrap();
427 if let Some(info) = processes.get_mut(name) {
428 if let Some(handle) = info.handle.take() {
429 info.spec.shutdown_signal.store(true, Ordering::Relaxed);
431
432 match &info.spec.shutdown_strategy {
433 ShutdownStrategy::BrutalKill => {
434 drop(handle);
435 }
436 ShutdownStrategy::Shutdown(timeout) => {
437 let start = Instant::now();
439 while !handle.is_finished() && start.elapsed() < *timeout {
440 thread::sleep(Duration::from_millis(10));
441 }
442 drop(handle);
443 }
444 }
445
446 info.state = ProcessState::Stopped;
447 self.event_callback.on_process_stopped(name);
448 return true;
449 }
450 }
451 false
452 }
453
454 pub fn shutdown(&self) {
456 self.shutdown_flag.store(true, Ordering::Relaxed);
457
458 let process_names: Vec<String> = {
459 let processes = self.processes.lock().unwrap();
460 processes.keys().cloned().collect()
461 };
462
463 for name in process_names {
464 self.stop_process(&name);
465 }
466
467 if let Ok(mut handle) = self.monitor_handle.lock() {
469 if let Some(thread) = handle.take() {
470 let _ = thread.join();
471 }
472 }
473 }
474
475 pub fn get_process_state(&self, name: &str) -> Option<ProcessState> {
485 let processes = self.processes.lock().unwrap();
486 processes.get(name).map(|info| info.state.clone())
487 }
488
489 pub fn get_restart_count(&self, name: &str) -> Option<usize> {
491 let processes = self.processes.lock().unwrap();
492 processes.get(name).map(|info| info.restart_count)
493 }
494
495 pub fn get_all_states(&self) -> HashMap<String, (ProcessState, usize)> {
497 let processes = self.processes.lock().unwrap();
498 processes
499 .iter()
500 .map(|(name, info)| {
501 (
502 name.clone(),
503 (info.state.clone(), info.restart_count),
504 )
505 })
506 .collect()
507 }
508
509 pub fn get_shutdown_signal(&self, name: &str) -> Option<Arc<AtomicBool>> {
511 let processes = self.processes.lock().unwrap();
512 processes.get(name).map(|info| Arc::clone(&info.spec.shutdown_signal))
513 }
514}
515
516#[cfg(test)]
517mod tests {
518 use super::*;
519 use std::sync::atomic::{AtomicUsize, Ordering};
520
/// A freshly created supervisor tracks no children.
#[test]
fn test_supervisor_creation() {
    let config = SupervisorConfig::default();
    let states = Supervisor::new(config).get_all_states();
    assert_eq!(states.len(), 0);
}
527
/// Adding a child registers it in the `Unstarted` state.
#[test]
fn test_add_process() {
    let spawn_sleeper = || thread::spawn(|| thread::sleep(Duration::from_secs(10)));

    let mut sup = Supervisor::new(SupervisorConfig::default());
    sup.add_process("worker1", ChildType::Permanent, spawn_sleeper);

    let tracked = sup.get_all_states().len();
    assert_eq!(tracked, 1);

    let state = sup.get_process_state("worker1");
    assert_eq!(state, Some(ProcessState::Unstarted));
}
544
/// Starting the monitor also starts the registered children.
#[test]
fn test_process_starts_on_monitoring() {
    let spawn_sleeper = || thread::spawn(|| thread::sleep(Duration::from_secs(10)));

    let mut builder = Supervisor::new(SupervisorConfig::default());
    builder.add_process("worker1", ChildType::Permanent, spawn_sleeper);

    let supervisor = builder.start_monitoring();
    thread::sleep(Duration::from_millis(200));

    let state = supervisor.get_process_state("worker1");
    assert_eq!(state, Some(ProcessState::Running));

    supervisor.shutdown();
}
565
/// A permanent child that keeps panicking is started more than once.
#[test]
fn test_permanent_process_restart() {
    let starts = Arc::new(AtomicUsize::new(0));
    let starts_for_factory = Arc::clone(&starts);

    let mut builder = Supervisor::new(SupervisorConfig::default());
    builder.add_process("failing_worker", ChildType::Permanent, move || {
        let starts = Arc::clone(&starts_for_factory);
        thread::spawn(move || {
            starts.fetch_add(1, Ordering::Relaxed);
            panic!("Intentional failure");
        })
    });

    let supervisor = builder.start_monitoring();
    thread::sleep(Duration::from_millis(500));

    let observed = starts.load(Ordering::Relaxed);
    assert!(observed > 1);

    supervisor.shutdown();
}
589
/// A temporary child that panics runs exactly once and ends up `Stopped`.
#[test]
fn test_temporary_process_no_restart() {
    let starts = Arc::new(AtomicUsize::new(0));
    let starts_for_factory = Arc::clone(&starts);

    let mut builder = Supervisor::new(SupervisorConfig::default());
    builder.add_process("temp_worker", ChildType::Temporary, move || {
        let starts = Arc::clone(&starts_for_factory);
        thread::spawn(move || {
            starts.fetch_add(1, Ordering::Relaxed);
            panic!("Intentional failure");
        })
    });

    let supervisor = builder.start_monitoring();
    thread::sleep(Duration::from_millis(500));

    let observed = starts.load(Ordering::Relaxed);
    assert_eq!(observed, 1);

    let state = supervisor.get_process_state("temp_worker");
    assert_eq!(state, Some(ProcessState::Stopped));

    supervisor.shutdown();
}
617
/// Stopping a running child reports success and leaves it `Stopped`.
#[test]
fn test_stop_process() {
    let spawn_sleeper = || thread::spawn(|| thread::sleep(Duration::from_secs(10)));

    let mut builder = Supervisor::new(SupervisorConfig::default());
    builder.add_process("worker1", ChildType::Permanent, spawn_sleeper);

    let supervisor = builder.start_monitoring();
    thread::sleep(Duration::from_millis(200));

    let stopped = supervisor.stop_process("worker1");
    assert!(stopped);

    thread::sleep(Duration::from_millis(100));
    let state = supervisor.get_process_state("worker1");
    assert_eq!(state, Some(ProcessState::Stopped));

    supervisor.shutdown();
}
640
/// The restart counter of a panicking permanent child becomes non-zero.
#[test]
fn test_restart_count() {
    let spawn_panicker = || thread::spawn(|| panic!("Intentional failure"));

    let mut builder = Supervisor::new(SupervisorConfig::default());
    builder.add_process("failing_worker", ChildType::Permanent, spawn_panicker);

    let supervisor = builder.start_monitoring();
    thread::sleep(Duration::from_millis(500));

    let restarts = match supervisor.get_restart_count("failing_worker") {
        Some(count) => count,
        None => 0,
    };
    assert!(restarts > 0);

    supervisor.shutdown();
}
659
/// Under one-for-one, the failing child is started more often than its
/// stable sibling.
#[test]
fn test_restart_strategy_one_for_one() {
    let config = SupervisorConfig {
        restart_strategy: RestartStrategy::OneForOne,
        ..SupervisorConfig::default()
    };

    let failing_starts = Arc::new(AtomicUsize::new(0));
    let stable_starts = Arc::new(AtomicUsize::new(0));
    let failing_for_factory = Arc::clone(&failing_starts);
    let stable_for_factory = Arc::clone(&stable_starts);

    let mut builder = Supervisor::new(config);

    builder.add_process("failing_worker", ChildType::Permanent, move || {
        let starts = Arc::clone(&failing_for_factory);
        thread::spawn(move || {
            starts.fetch_add(1, Ordering::Relaxed);
            panic!("Intentional failure");
        })
    });

    builder.add_process("stable_worker", ChildType::Permanent, move || {
        let starts = Arc::clone(&stable_for_factory);
        thread::spawn(move || {
            starts.fetch_add(1, Ordering::Relaxed);
            thread::sleep(Duration::from_secs(10));
        })
    });

    let supervisor = builder.start_monitoring();
    thread::sleep(Duration::from_millis(500));

    let failing_count = failing_starts.load(Ordering::Relaxed);
    let stable_count = stable_starts.load(Ordering::Relaxed);
    assert!(failing_count > stable_count);

    supervisor.shutdown();
}
700
/// Children linked by a dependency are both running after start-up.
#[test]
fn test_process_dependencies() {
    let mut builder = Supervisor::new(SupervisorConfig::default());

    for name in ["base_worker", "dependent_worker"] {
        builder.add_process(name, ChildType::Permanent, || {
            thread::spawn(|| thread::sleep(Duration::from_secs(10)))
        });
    }

    builder.add_dependency("dependent_worker", "base_worker");

    let supervisor = builder.start_monitoring();
    thread::sleep(Duration::from_millis(200));

    let base_state = supervisor.get_process_state("base_worker");
    assert_eq!(base_state, Some(ProcessState::Running));

    let dependent_state = supervisor.get_process_state("dependent_worker");
    assert_eq!(dependent_state, Some(ProcessState::Running));

    supervisor.shutdown();
}
734
/// A child that keeps panicking is given up on once the restart limit is
/// reached, and is never started more than `max_restarts + 1` times.
#[test]
fn test_max_restarts_limit() {
    let max_restarts = 2;
    let config = SupervisorConfig {
        max_restarts,
        max_time: Duration::from_secs(5),
        ..SupervisorConfig::default()
    };

    let starts = Arc::new(AtomicUsize::new(0));
    let starts_for_factory = Arc::clone(&starts);

    let mut builder = Supervisor::new(config);
    builder.add_process("failing_worker", ChildType::Permanent, move || {
        let starts = Arc::clone(&starts_for_factory);
        thread::spawn(move || {
            starts.fetch_add(1, Ordering::Relaxed);
            panic!("Intentional failure");
        })
    });

    let supervisor = builder.start_monitoring();
    thread::sleep(Duration::from_millis(1000));

    let state = supervisor.get_process_state("failing_worker");
    assert_eq!(state, Some(ProcessState::Stopped));

    let observed = starts.load(Ordering::Relaxed);
    assert!(observed <= max_restarts + 1);

    supervisor.shutdown();
}
769
/// Shutting the supervisor down leaves every child `Stopped`.
#[test]
fn test_supervisor_shutdown() {
    let mut builder = Supervisor::new(SupervisorConfig::default());

    for name in ["worker1", "worker2"] {
        builder.add_process(name, ChildType::Permanent, || {
            thread::spawn(|| thread::sleep(Duration::from_secs(10)))
        });
    }

    let supervisor = builder.start_monitoring();
    thread::sleep(Duration::from_millis(200));

    supervisor.shutdown();
    thread::sleep(Duration::from_millis(200));

    let first = supervisor.get_process_state("worker1");
    assert_eq!(first, Some(ProcessState::Stopped));

    let second = supervisor.get_process_state("worker2");
    assert_eq!(second, Some(ProcessState::Stopped));
}
802
/// The started, failed and restarted hooks all fire for a panicking
/// permanent child.
#[test]
fn test_event_callback() {
    #[derive(Default)]
    struct TestCallback {
        started: AtomicUsize,
        failed: AtomicUsize,
        restarted: AtomicUsize,
    }

    impl EventCallback for TestCallback {
        fn on_process_started(&self, _process_name: &str) {
            self.started.fetch_add(1, Ordering::Relaxed);
        }

        fn on_process_failed(&self, _process_name: &str) {
            self.failed.fetch_add(1, Ordering::Relaxed);
        }

        fn on_process_restarted(&self, _process_name: &str, _restart_count: usize) {
            self.restarted.fetch_add(1, Ordering::Relaxed);
        }
    }

    // Keep a concretely typed handle for the assertions and hand the
    // supervisor a clone coerced to the trait object; no unsafe downcast
    // is needed to read the counters back.
    let callback = Arc::new(TestCallback::default());
    let as_dyn: Arc<dyn EventCallback> = callback.clone();

    let mut supervisor = Supervisor::with_callback(SupervisorConfig::default(), as_dyn);

    supervisor.add_process("failing_worker", ChildType::Permanent, || {
        thread::spawn(|| {
            panic!("Intentional failure");
        })
    });

    let supervisor = supervisor.start_monitoring();
    thread::sleep(Duration::from_millis(500));

    assert!(callback.started.load(Ordering::Relaxed) > 0);
    assert!(callback.failed.load(Ordering::Relaxed) > 0);
    assert!(callback.restarted.load(Ordering::Relaxed) > 0);

    supervisor.shutdown();
}
853}
854