Skip to main content

scoped_panic_hook/
hook.rs

1use std::cell::Cell;
2use std::mem;
3use std::panic::{PanicHookInfo, UnwindSafe, catch_unwind};
4use std::sync::Once;
5use std::thread;
6
7use defer::defer;
8
/// Verdict returned by a scoped hook: either panic handling stops at that hook,
/// or it continues with the hook that was installed before the scoped ones
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum NextHook {
    /// Stop here: nothing else is called after the hook which returned this value
    Break,
    /// Continue with the hook which precedes scoped hooks; usually that's the default hook
    PrevInstalledHook,
}
18/// Initializes scoped hook infrastructure
19///
20/// Usually you don't need to call it explicitly, yet it's left accessible
21/// in case crate user wants to perform initialization at well-defined point
22pub fn init_scoped_hooks() {
23    INIT_SCOPED_HOOKS.call_once(install_hook);
24}
25
/// Wraps the currently installed panic hook with the scoped hook dispatcher
///
/// Flavour compiled when the `nightly` cfg is set; it relies on
/// `std::panic::update_hook`, which hands the previous hook to the new one.
#[cfg(nightly)]
fn install_hook() {
    std::panic::update_hook(|previous, panic_info| {
        // The scoped hook asked to stop: the previous hook stays silent
        if !scoped_hook_fn(panic_info) {
            return;
        }
        previous(panic_info);
    });
}
34
35#[cfg(not(nightly))]
36fn install_hook() {
37    let old_handler = std::panic::take_hook();
38    std::panic::set_hook(Box::new(move |info| {
39        if scoped_hook_fn(info) {
40            old_handler(info);
41        }
42    }));
43}
44/// Executes `body` function wrapped in `catch_unwind` with specified scoped panic hook installed
45///
46/// # Parameters
47/// * `hook` - panic hook function which receives panic info during `body`'s execution
48/// * `body` - Actual payload closure.
49///   The closure has same requirements as the parameter to [`std::panic::catch_unwind`].
50///   In particular, if you want to assert that closure is unwind safe when type system
51///   can't deduce it, you can use [`std::panic::AssertUnwindSafe`].
52///   See [`std::panic::catch_unwind`] for details.
53///
54/// # Returns
55/// Usual `catch_unwind` result with `Ok(...)` being return value from `body`
56pub fn catch_unwind_with_scoped_hook<R>(
57    mut hook: impl FnMut(&PanicHookInfo<'_>) -> NextHook,
58    body: impl FnOnce() -> R + UnwindSafe,
59) -> thread::Result<R> {
60    init_scoped_hooks();
61
62    let new_info = HookInfo::from_hook(&mut hook);
63    let old_info = CURRENT_SCOPED_HOOK.replace(Some(new_info));
64    defer! { CURRENT_SCOPED_HOOK.set(old_info) }
65
66    catch_unwind(body)
67}
68
69thread_local! {
70    static CURRENT_SCOPED_HOOK: Cell<Option<HookInfo>> = const { Cell::new(None) };
71}
72
73static INIT_SCOPED_HOOKS: Once = Once::new();
74
75fn scoped_hook_fn(info: &PanicHookInfo<'_>) -> bool {
76    CURRENT_SCOPED_HOOK
77        .get()
78        .is_none_or(|hook| unsafe { hook.call_handler(info) } == NextHook::PrevInstalledHook)
79}
80
81type DynHookPtr<'a> = *mut (dyn FnMut(&PanicHookInfo<'_>) -> NextHook + 'a);
82type StaticHookPtr = DynHookPtr<'static>;
83
84#[derive(Copy, Clone)]
85struct HookInfo {
86    hook: StaticHookPtr,
87}
88
89impl HookInfo {
90    fn from_hook(hook: &mut impl FnMut(&PanicHookInfo<'_>) -> NextHook) -> Self {
91        Self {
92            // transmute is required to erase lifetimes from actual closure type
93            hook: unsafe {
94                mem::transmute::<DynHookPtr<'_>, StaticHookPtr>(hook as DynHookPtr<'_>)
95            },
96        }
97    }
98
99    unsafe fn call_handler(&self, info: &PanicHookInfo<'_>) -> NextHook {
100        unsafe { (*self.hook)(info) }
101    }
102}
103
#[cfg(test)]
mod tests {
    use subprocess_test::subprocess_test;

    use super::*;

    /// The scoped hook must stay untouched when `body` doesn't panic
    #[test]
    fn simple_ok() {
        let mut hook_calls = 0;

        let outcome = catch_unwind_with_scoped_hook(
            |_| {
                hook_calls += 1;
                NextHook::Break
            },
            || (),
        );
        outcome.unwrap();

        assert_eq!(hook_calls, 0);
    }

    /// The scoped hook must run exactly once for a single panic in `body`
    #[test]
    fn simple_panic() {
        let mut hook_calls = 0;

        let outcome = catch_unwind_with_scoped_hook(
            |_| {
                hook_calls += 1;
                NextHook::Break
            },
            || panic!("Oops!"),
        );
        outcome.unwrap_err();

        assert_eq!(hook_calls, 1);
    }

    subprocess_test! {
        // When a scoped hook returns [`NextHook::PrevInstalledHook`],
        // the hook installed before the scoped ones has to be invoked for real.
        //
        // Lives in a separate binary because the default hook gets overridden
        // before the scoped one takes effect.
        #[test]
        fn hook_forwarding() {
            use std::panic::set_hook;
            use std::sync::atomic::{AtomicUsize, Ordering};

            static PREV_HOOK_CALLS: AtomicUsize = AtomicUsize::new(0);

            fn counting_hook(_info: &PanicHookInfo<'_>) {
                PREV_HOOK_CALLS.fetch_add(1, Ordering::Relaxed);
            }

            // Becomes the "previous" hook once the scoped infrastructure is installed
            set_hook(Box::new(counting_hook));

            let mut scoped_calls = 0;

            let outcome = catch_unwind_with_scoped_hook(
                |_| {
                    scoped_calls += 1;
                    NextHook::PrevInstalledHook
                },
                || panic!("Oops!"),
            );
            assert!(outcome.is_err());

            assert_eq!(scoped_calls, 1);
            assert_eq!(PREV_HOOK_CALLS.load(Ordering::Relaxed), 1);
        }
    }
}
171            assert_eq!(counter, 1);
172            assert_eq!(GLOB_COUNTER.load(Ordering::Relaxed), 1);
173        }
174    }
175}