signal_stack/
stack.rs

1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::sync::Arc;
4
5use libc::c_int;
6
7use super::backend::{PlatformSigData, PlatformSigHandler, SigHandler};
8use super::signal_safe::RwLock;
9
/// Signature every signal handler closure must satisfy.
///
/// A handler is called with the number of the signal being delivered and
/// reports whether it dealt with it:
///
/// * `true` — the signal counts as handled and nothing further runs.
/// * `false` — the next handler down the stack is tried; once the stack is
///   exhausted, the signal is delegated to the handler that was in place
///   before ours (which may be the system default action).
///
/// The trait is blanket-implemented, so any `Fn(c_int) -> bool` closure that
/// is also `Send + Sync` qualifies automatically.
pub trait Handler: Fn(c_int) -> bool + Send + Sync {}

impl<F> Handler for F where F: Fn(c_int) -> bool + Send + Sync {}
20
21#[derive(Clone)]
22struct Slot {
23    stack: Vec<Arc<dyn Handler>>,
24    prev: PlatformSigHandler,
25}
26
27impl Slot {
28    pub fn new(signum: c_int) -> Self {
29        Self {
30            stack: Vec::new(),
31            prev: PlatformSigHandler::detect(signum),
32        }
33    }
34}
35
36type Handlers = HashMap<c_int, Slot>;
37
38#[derive(Clone)]
39pub struct HandlerId(Arc<dyn Handler>);
40
41impl Debug for HandlerId {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.write_str("HandlerId { ... }")
44    }
45}
46
47impl Eq for HandlerId {}
48impl PartialEq for HandlerId {
49    fn eq(&self, other: &Self) -> bool {
50        // Comparing wide pointers has unpredictable results, so compare thin pointers.
51        std::ptr::eq(
52            Arc::as_ptr(&self.0) as *const (),
53            Arc::as_ptr(&other.0) as *const (),
54        )
55    }
56}
57
58static HANDLERS: RwLock<Option<Handlers>> = RwLock::const_new(None, None);
59
60pub(crate) fn our_handler(signum: c_int, data: PlatformSigData) {
61    if let Some(handlers) = &*HANDLERS.read() {
62        if let Some(slot) = handlers.get(&signum) {
63            for item in slot.stack.iter().rev() {
64                if item(signum) {
65                    return;
66                }
67            }
68            unsafe {
69                slot.prev.delegate(signum, data);
70            }
71        }
72    }
73}
74
75pub(crate) unsafe fn add_handler(signums: &[c_int], handler: Arc<dyn Handler>) -> HandlerId {
76    let handler_id = HandlerId(handler.clone());
77
78    if !signums.is_empty() {
79        let mut install_c_handlers = Vec::new();
80        {
81            let mut guard = HANDLERS.write();
82            let handlers = guard.get_or_insert_with(Default::default);
83            for &signum in signums {
84                handlers
85                    .entry(signum)
86                    .or_insert_with(|| {
87                        install_c_handlers.push(signum);
88                        Slot::new(signum)
89                    })
90                    .stack
91                    .push(handler.clone());
92            }
93        }
94
95        if !install_c_handlers.is_empty() {
96            let prevs: Vec<_> = install_c_handlers
97                .into_iter()
98                .map(|signum| (signum, PlatformSigHandler::ours().install(signum)))
99                .collect();
100
101            let mut guard = HANDLERS.write();
102            let handlers = guard.as_mut().unwrap();
103            for (signum, prev) in prevs {
104                handlers.get_mut(&signum).unwrap().prev = prev;
105            }
106        }
107    }
108
109    handler_id
110}
111
112pub(crate) unsafe fn remove_handler(signums: &[c_int], handler_id: &HandlerId) {
113    if signums.is_empty() {
114        return;
115    }
116    let ptr = Arc::as_ptr(&handler_id.0) as *const ();
117    if let Some(handlers) = HANDLERS.write().as_mut() {
118        for &signum in signums {
119            if let Some(slot) = handlers.get_mut(&signum) {
120                if let Some((index, _)) = slot
121                    .stack
122                    .iter()
123                    .enumerate()
124                    .rev()
125                    .find(|&(_, item)| Arc::as_ptr(item) as *const () == ptr)
126                {
127                    slot.stack.remove(index);
128                }
129            }
130        }
131    }
132}