scuffle_signal/
lib.rs

1//! A crate designed to provide a more user friendly interface to
2//! `tokio::signal`.
3//!
4//! ## Why do we need this?
5//!
6//! The `tokio::signal` module provides a way for us to wait for a signal to be
7//! received in a non-blocking way. This crate extends that with a more helpful
8//! interface allowing the ability to listen to multiple signals concurrently.
9//!
10//! ## Example
11//!
//! ```rust
//! use scuffle_signal::SignalHandler;
//! use tokio::signal::unix::SignalKind;
//!
//! # tokio_test::block_on(async {
//! let handler = SignalHandler::new()
//!     .with_signal(SignalKind::interrupt())
//!     .with_signal(SignalKind::terminate());
//!
//! # // Safety: This is a test, and we control the process.
//! # unsafe {
//! #    libc::raise(SignalKind::interrupt().as_raw_value());
//! # }
//! // Wait for a signal to be received
//! let signal = handler.await;
//!
//! // Handle the signal. `SignalKind` values are compared with `==`: a bare
//! // identifier in a `match` arm would bind (and therefore match) any signal.
//! if signal == SignalKind::interrupt() {
//!     // Handle SIGINT
//!     println!("received SIGINT");
//! } else if signal == SignalKind::terminate() {
//!     // Handle SIGTERM
//!     println!("received SIGTERM");
//! }
//! # });
//! ```
43//!
44//! ## Status
45//!
46//! This crate is currently under development and is not yet stable.
47//!
48//! Unit tests are not yet fully implemented. Use at your own risk.
49//!
50//! ## License
51//!
52//! This project is licensed under the [MIT](./LICENSE.MIT) or
53//! [Apache-2.0](./LICENSE.Apache-2.0) license. You can choose between one of
54//! them if you use this work.
55//!
56//! `SPDX-License-Identifier: MIT OR Apache-2.0`
57#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
58
59use std::pin::Pin;
60use std::task::{Context, Poll};
61
62use tokio::signal::unix::{Signal, SignalKind};
63
64#[cfg(feature = "bootstrap")]
65mod bootstrap;
66
67#[cfg(feature = "bootstrap")]
68pub use bootstrap::{SignalConfig, SignalSvc};
69
70/// A handler for listening to multiple Unix signals, and providing a future for
71/// receiving them.
72///
73/// This is useful for applications that need to listen for multiple signals,
74/// and want to react to them in a non-blocking way. Typically you would need to
75/// use a tokio::select{} to listen for multiple signals, but this provides a
76/// more ergonomic interface for doing so.
77///
78/// After a signal is received you can poll the handler again to wait for
79/// another signal. Dropping the handle will cancel the signal subscription
80///
81/// # Example
82///
83/// ```rust
84/// use scuffle_signal::SignalHandler;
85/// use tokio::signal::unix::SignalKind;
86///
87/// # tokio_test::block_on(async {
88/// let mut handler = SignalHandler::new()
89///     .with_signal(SignalKind::interrupt())
90///     .with_signal(SignalKind::terminate());
91///
92/// # // Safety: This is a test, and we control the process.
93/// # unsafe {
94/// #    libc::raise(SignalKind::interrupt().as_raw_value());
95/// # }
96/// // Wait for a signal to be received
97/// let signal = handler.await;
98///
99/// // Handle the signal
100/// let interrupt = SignalKind::interrupt();
101/// let terminate = SignalKind::terminate();
102/// match signal {
103///     interrupt => {
104///         // Handle SIGINT
105///         println!("received SIGINT");
106///     },
107///     terminate => {
108///         // Handle SIGTERM
109///         println!("received SIGTERM");
110///     },
111/// }
112/// # });
113/// ```
114#[derive(Debug)]
115#[must_use = "signal handlers must be used to wait for signals"]
116pub struct SignalHandler {
117    signals: Vec<(SignalKind, Signal)>,
118}
119
120impl Default for SignalHandler {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl SignalHandler {
127    /// Create a new `SignalHandler` with no signals.
128    pub const fn new() -> Self {
129        Self { signals: Vec::new() }
130    }
131
132    /// Create a new `SignalHandler` with the given signals.
133    pub fn with_signals(signals: impl IntoIterator<Item = SignalKind>) -> Self {
134        let mut handler = Self::new();
135
136        for signal in signals {
137            handler = handler.with_signal(signal);
138        }
139
140        handler
141    }
142
143    /// Add a signal to the handler.
144    ///
145    /// If the signal is already in the handler, it will not be added again.
146    pub fn with_signal(mut self, kind: SignalKind) -> Self {
147        if self.signals.iter().any(|(k, _)| k == &kind) {
148            return self;
149        }
150
151        let signal = tokio::signal::unix::signal(kind).expect("failed to create signal");
152
153        self.signals.push((kind, signal));
154
155        self
156    }
157
158    /// Add a signal to the handler.
159    ///
160    /// If the signal is already in the handler, it will not be added again.
161    pub fn add_signal(&mut self, kind: SignalKind) -> &mut Self {
162        if self.signals.iter().any(|(k, _)| k == &kind) {
163            return self;
164        }
165
166        let signal = tokio::signal::unix::signal(kind).expect("failed to create signal");
167
168        self.signals.push((kind, signal));
169
170        self
171    }
172
173    /// Wait for a signal to be received.
174    /// This is equivilant to calling (&mut handler).await, but is more
175    /// ergonomic if you want to not take ownership of the handler.
176    pub async fn recv(&mut self) -> SignalKind {
177        self.await
178    }
179
180    /// Poll for a signal to be received.
181    /// Does not require pinning the handler.
182    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<SignalKind> {
183        for (kind, signal) in self.signals.iter_mut() {
184            if signal.poll_recv(cx).is_ready() {
185                return Poll::Ready(*kind);
186            }
187        }
188
189        Poll::Pending
190    }
191}
192
193impl std::future::Future for SignalHandler {
194    type Output = SignalKind;
195
196    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
197        self.poll_recv(cx)
198    }
199}
200
#[cfg(test)]
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
mod tests {
    use std::time::Duration;

    use scuffle_future_ext::FutureExt;

    use super::*;

    /// Raises `kind` in the current process via `libc::raise`.
    pub fn raise_signal(kind: SignalKind) {
        let signum = kind.as_raw_value();

        // Safety: This is a test, and we control the process.
        unsafe {
            libc::raise(signum);
        }
    }

    /// Awaits the handler by mutable reference (through its `Future` impl).
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn signal_handler() {
        let usr1 = SignalKind::user_defined1();
        let usr2 = SignalKind::user_defined2();
        let wait = Duration::from_millis(5);

        // SIGUSR1 is passed twice; the second registration is expected to be ignored.
        let mut handler = SignalHandler::with_signals([usr1]).with_signal(usr2).with_signal(usr1);

        raise_signal(usr1);

        let first = (&mut handler).with_timeout(wait).await.unwrap();
        assert_eq!(first, usr1, "expected SIGUSR1");

        // We already received the signal, so polling again should return Poll::Pending
        let second = (&mut handler).with_timeout(wait).await;
        assert!(second.is_err(), "expected timeout");

        raise_signal(usr2);

        // We should be able to receive the signal again
        let third = (&mut handler).with_timeout(wait).await.unwrap();
        assert_eq!(third, usr2, "expected SIGUSR2");
    }

    /// Registers signals in place with `add_signal` and receives through `recv`.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn add_signal() {
        let usr1 = SignalKind::user_defined1();
        let usr2 = SignalKind::user_defined2();
        let wait = Duration::from_millis(5);

        let mut handler = SignalHandler::new();
        handler.add_signal(usr1).add_signal(usr2).add_signal(usr2);

        raise_signal(usr1);

        let first = handler.recv().with_timeout(wait).await.unwrap();
        assert_eq!(first, usr1, "expected SIGUSR1");

        raise_signal(usr2);

        let second = handler.recv().with_timeout(wait).await.unwrap();
        assert_eq!(second, usr2, "expected SIGUSR2");
    }

    /// A handler with no registered signals must stay pending.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn no_signals() {
        let mut handler = SignalHandler::default();

        // Expected to timeout
        let outcome = handler.recv().with_timeout(Duration::from_millis(50)).await;
        assert!(outcome.is_err());
    }
}