scuffle_signal/lib.rs
1//! A crate designed to provide a more user friendly interface to
2//! `tokio::signal`.
3//!
4//! ## Why do we need this?
5//!
//! The `tokio::signal` module lets us wait for a signal to be received
//! without blocking, one signal kind per stream. This crate builds on it with
//! a friendlier interface that listens for multiple signals concurrently
//! through a single future.
9//!
10//! ## Example
11//!
12//! ```rust
13//! use scuffle_signal::SignalHandler;
14//! use tokio::signal::unix::SignalKind;
15//!
16//! # tokio_test::block_on(async {
//! let handler = SignalHandler::new()
//!     .with_signal(SignalKind::interrupt())
//!     .with_signal(SignalKind::terminate());
20//!
21//! # // Safety: This is a test, and we control the process.
22//! # unsafe {
23//! # libc::raise(SignalKind::interrupt().as_raw_value());
24//! # }
25//! // Wait for a signal to be received
26//! let signal = handler.await;
27//!
//! // Handle the signal. `SignalKind` values are built by functions, so compare
//! // them with `==`: a bare name in a `match` arm would bind, not compare.
//! if signal == SignalKind::interrupt() {
//!     // Handle SIGINT
//!     println!("received SIGINT");
//! } else if signal == SignalKind::terminate() {
//!     // Handle SIGTERM
//!     println!("received SIGTERM");
//! }
41//! # });
42//! ```
43//!
44//! ## Status
45//!
46//! This crate is currently under development and is not yet stable.
47//!
48//! Unit tests are not yet fully implemented. Use at your own risk.
49//!
50//! ## License
51//!
52//! This project is licensed under the [MIT](./LICENSE.MIT) or
53//! [Apache-2.0](./LICENSE.Apache-2.0) license. You can choose between one of
54//! them if you use this work.
55//!
56//! `SPDX-License-Identifier: MIT OR Apache-2.0`
57#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
58
59use std::pin::Pin;
60use std::task::{Context, Poll};
61
62use tokio::signal::unix::{Signal, SignalKind};
63
64#[cfg(feature = "bootstrap")]
65mod bootstrap;
66
67#[cfg(feature = "bootstrap")]
68pub use bootstrap::{SignalConfig, SignalSvc};
69
70/// A handler for listening to multiple Unix signals, and providing a future for
71/// receiving them.
72///
73/// This is useful for applications that need to listen for multiple signals,
74/// and want to react to them in a non-blocking way. Typically you would need to
75/// use a tokio::select{} to listen for multiple signals, but this provides a
76/// more ergonomic interface for doing so.
77///
78/// After a signal is received you can poll the handler again to wait for
79/// another signal. Dropping the handle will cancel the signal subscription
80///
81/// # Example
82///
83/// ```rust
84/// use scuffle_signal::SignalHandler;
85/// use tokio::signal::unix::SignalKind;
86///
87/// # tokio_test::block_on(async {
88/// let mut handler = SignalHandler::new()
89/// .with_signal(SignalKind::interrupt())
90/// .with_signal(SignalKind::terminate());
91///
92/// # // Safety: This is a test, and we control the process.
93/// # unsafe {
94/// # libc::raise(SignalKind::interrupt().as_raw_value());
95/// # }
96/// // Wait for a signal to be received
97/// let signal = handler.await;
98///
99/// // Handle the signal
100/// let interrupt = SignalKind::interrupt();
101/// let terminate = SignalKind::terminate();
102/// match signal {
103/// interrupt => {
104/// // Handle SIGINT
105/// println!("received SIGINT");
106/// },
107/// terminate => {
108/// // Handle SIGTERM
109/// println!("received SIGTERM");
110/// },
111/// }
112/// # });
113/// ```
114#[derive(Debug)]
115#[must_use = "signal handlers must be used to wait for signals"]
116pub struct SignalHandler {
117 signals: Vec<(SignalKind, Signal)>,
118}
119
120impl Default for SignalHandler {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl SignalHandler {
127 /// Create a new `SignalHandler` with no signals.
128 pub const fn new() -> Self {
129 Self { signals: Vec::new() }
130 }
131
132 /// Create a new `SignalHandler` with the given signals.
133 pub fn with_signals(signals: impl IntoIterator<Item = SignalKind>) -> Self {
134 let mut handler = Self::new();
135
136 for signal in signals {
137 handler = handler.with_signal(signal);
138 }
139
140 handler
141 }
142
143 /// Add a signal to the handler.
144 ///
145 /// If the signal is already in the handler, it will not be added again.
146 pub fn with_signal(mut self, kind: SignalKind) -> Self {
147 if self.signals.iter().any(|(k, _)| k == &kind) {
148 return self;
149 }
150
151 let signal = tokio::signal::unix::signal(kind).expect("failed to create signal");
152
153 self.signals.push((kind, signal));
154
155 self
156 }
157
158 /// Add a signal to the handler.
159 ///
160 /// If the signal is already in the handler, it will not be added again.
161 pub fn add_signal(&mut self, kind: SignalKind) -> &mut Self {
162 if self.signals.iter().any(|(k, _)| k == &kind) {
163 return self;
164 }
165
166 let signal = tokio::signal::unix::signal(kind).expect("failed to create signal");
167
168 self.signals.push((kind, signal));
169
170 self
171 }
172
173 /// Wait for a signal to be received.
174 /// This is equivilant to calling (&mut handler).await, but is more
175 /// ergonomic if you want to not take ownership of the handler.
176 pub async fn recv(&mut self) -> SignalKind {
177 self.await
178 }
179
180 /// Poll for a signal to be received.
181 /// Does not require pinning the handler.
182 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<SignalKind> {
183 for (kind, signal) in self.signals.iter_mut() {
184 if signal.poll_recv(cx).is_ready() {
185 return Poll::Ready(*kind);
186 }
187 }
188
189 Poll::Pending
190 }
191}
192
193impl std::future::Future for SignalHandler {
194 type Output = SignalKind;
195
196 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
197 self.poll_recv(cx)
198 }
199}
200
#[cfg(test)]
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
mod tests {
    use std::time::Duration;

    use scuffle_future_ext::FutureExt;

    use super::*;

    /// Upper bound on the wait for a signal that has already been raised.
    ///
    /// This only limits how long a *failing* test takes: a passing test
    /// returns as soon as the signal is observed, so a generous bound avoids
    /// spurious failures on a loaded machine at no cost.
    const RECV_TIMEOUT: Duration = Duration::from_millis(100);

    /// How long to wait when asserting that no signal is pending.
    const PENDING_TIMEOUT: Duration = Duration::from_millis(5);

    /// Raise `kind` in the current process, failing the test if `raise`
    /// reports an error.
    pub fn raise_signal(kind: SignalKind) {
        // SAFETY: This is a test, and we control the process. Callers register
        // a handler for `kind` first; otherwise the signal's default action
        // (e.g. termination for SIGUSR1) would apply.
        let ret = unsafe { libc::raise(kind.as_raw_value()) };
        assert_eq!(ret, 0, "libc::raise failed for signal {}", kind.as_raw_value());
    }

    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn signal_handler() {
        let mut handler = SignalHandler::with_signals([SignalKind::user_defined1()])
            .with_signal(SignalKind::user_defined2())
            .with_signal(SignalKind::user_defined1());

        raise_signal(SignalKind::user_defined1());

        let recv = (&mut handler).with_timeout(RECV_TIMEOUT).await.unwrap();

        assert_eq!(recv, SignalKind::user_defined1(), "expected SIGUSR1");

        // We already received the signal, so polling again should return Poll::Pending
        let recv = (&mut handler).with_timeout(PENDING_TIMEOUT).await;

        assert!(recv.is_err(), "expected timeout");

        raise_signal(SignalKind::user_defined2());

        // We should be able to receive the signal again
        let recv = (&mut handler).with_timeout(RECV_TIMEOUT).await.unwrap();

        assert_eq!(recv, SignalKind::user_defined2(), "expected SIGUSR2");
    }

    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn add_signal() {
        let mut handler = SignalHandler::new();

        handler
            .add_signal(SignalKind::user_defined1())
            .add_signal(SignalKind::user_defined2())
            .add_signal(SignalKind::user_defined2());

        raise_signal(SignalKind::user_defined1());

        let recv = handler.recv().with_timeout(RECV_TIMEOUT).await.unwrap();

        assert_eq!(recv, SignalKind::user_defined1(), "expected SIGUSR1");

        raise_signal(SignalKind::user_defined2());

        let recv = handler.recv().with_timeout(RECV_TIMEOUT).await.unwrap();

        assert_eq!(recv, SignalKind::user_defined2(), "expected SIGUSR2");
    }

    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn no_signals() {
        let mut handler = SignalHandler::default();

        // Expected to timeout: with no registered signals nothing can wake the handler.
        assert!(handler.recv().with_timeout(Duration::from_millis(50)).await.is_err());
    }
}