spirit/
utils.rs

//! Various utilities.
//!
//! All the little things that are useful throughout spirit's or the user's code and don't really
//! fit anywhere else.
5
6use std::env;
7use std::error::Error;
8use std::ffi::OsStr;
9use std::fmt::{Debug, Display, Formatter, Result as FmtResult};
10use std::ops::{Deref, DerefMut};
11use std::path::PathBuf;
12use std::str::FromStr;
13use std::sync::atomic::AtomicBool;
14use std::sync::Arc;
15use std::time::Duration;
16
17use err_context::prelude::*;
18use libc::c_int;
19use log::{debug, error, warn};
20use serde::de::{Deserializer, Error as DeError, Unexpected};
21use serde::ser::Serializer;
22use serde::{Deserialize, Serialize};
23
24use crate::AnyError;
25
26/// Tries to read an absolute path from the given OS string.
27///
28/// This converts the path to PathBuf. Then it tries to make it absolute and canonical, so changing
29/// current directory later on doesn't make it invalid.
30///
31/// The function never fails. However, the substeps (finding current directory to make it absolute
32/// and canonization) might fail. In such case, the failing step is skipped.
33///
34/// The motivation is parsing command line arguments using the
35/// [`structopt`](https://lib.rs/crates/structopt) crate. Users are used
36/// to passing relative paths to command line (as opposed to configuration files). However, if the
37/// daemon changes the current directory (for example during daemonization), the relative paths now
38/// point somewhere else.
39///
40/// # Examples
41///
42/// ```rust
43/// use std::path::PathBuf;
44///
45/// use structopt::StructOpt;
46///
47/// # #[allow(dead_code)]
48/// #[derive(Debug, StructOpt)]
49/// struct MyOpts {
50///     #[structopt(short = "p", parse(from_os_str = spirit::utils::absolute_from_os_str))]
51///     path: PathBuf,
52/// }
53///
54/// # fn main() { }
55/// ```
56pub fn absolute_from_os_str(path: &OsStr) -> PathBuf {
57    let mut current = env::current_dir().unwrap_or_else(|e| {
58        warn!(
59            "Some paths may not be turned to absolute. Couldn't read current dir: {}",
60            e,
61        );
62        PathBuf::new()
63    });
64    current.push(path);
65    if let Ok(canonicized) = current.canonicalize() {
66        canonicized
67    } else {
68        current
69    }
70}
71
/// Error produced when a key-value command line option lacks the `=` separator.
///
/// Some internal options accept `key=value` pairs on the command line. When such an option is
/// given a value that doesn't contain any equal sign, parsing fails with this error.
#[derive(Clone, Copy, Debug)]
pub struct MissingEquals;

impl Display for MissingEquals {
    fn fmt(&self, f: &mut Formatter) -> FmtResult {
        f.write_str("Missing = in map option")
    }
}

impl Error for MissingEquals {}
86
87/// A helper for deserializing map-like command line arguments.
88///
89/// # Examples
90///
91/// ```rust
92/// # use structopt::StructOpt;
93/// # #[allow(dead_code)]
94/// #[derive(Debug, StructOpt)]
95/// struct MyOpts {
96///     #[structopt(
97///         short = "D",
98///         long = "define",
99///         parse(try_from_str = spirit::utils::key_val),
100///         number_of_values(1),
101///     )]
102///     defines: Vec<(String, String)>,
103/// }
104///
105/// # fn main() {}
106/// ```
107pub fn key_val<K, V>(opt: &str) -> Result<(K, V), AnyError>
108where
109    K: FromStr,
110    K::Err: Error + Send + Sync + 'static,
111    V: FromStr,
112    V::Err: Error + Send + Sync + 'static,
113{
114    let pos = opt.find('=').ok_or(MissingEquals)?;
115    Ok((opt[..pos].parse()?, opt[pos + 1..].parse()?))
116}
117
118/// A wrapper to hide a configuration field from logs.
119///
120/// This acts in as much transparent way as possible towards the field inside. It only replaces the
121/// [`Debug`] and [`Serialize`] implementations with returning `"******"`.
122///
123/// The idea is if the configuration contains passwords, they shouldn't leak into the logs.
124/// Therefore, wrap them in this, eg:
125///
126/// ```rust
127/// use std::io::Write;
128/// use std::str;
129///
130/// use spirit::utils::Hidden;
131///
132/// # #[allow(dead_code)]
133/// #[derive(Debug)]
134/// struct Cfg {
135///     username: String,
136///     password: Hidden<String>,
137/// }
138///
139/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
140/// let cfg = Cfg {
141///     username: "me".to_owned(),
142///     password: "secret".to_owned().into(),
143/// };
144///
145/// let mut buffer: Vec<u8> = Vec::new();
146/// write!(&mut buffer, "{:?}", cfg)?;
147/// assert_eq!(r#"Cfg { username: "me", password: "******" }"#, str::from_utf8(&buffer)?);
148/// # Ok(())
149/// # }
150/// ```
151#[derive(Clone, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash)]
152#[cfg_attr(feature = "cfg-help", derive(structdoc::StructDoc))]
153#[repr(transparent)]
154#[serde(transparent)]
155pub struct Hidden<T>(pub T);
156
157impl<T> From<T> for Hidden<T> {
158    fn from(val: T) -> Self {
159        Hidden(val)
160    }
161}
162
163impl<T> Deref for Hidden<T> {
164    type Target = T;
165    fn deref(&self) -> &T {
166        &self.0
167    }
168}
169
170impl<T> DerefMut for Hidden<T> {
171    fn deref_mut(&mut self) -> &mut T {
172        &mut self.0
173    }
174}
175
176impl<T> Debug for Hidden<T> {
177    fn fmt(&self, fmt: &mut Formatter) -> FmtResult {
178        write!(fmt, "\"******\"")
179    }
180}
181
182impl<T> Serialize for Hidden<T> {
183    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
184        s.serialize_str("******")
185    }
186}
187
188/// Serialize a duration.
189///
190/// This can be used in configuration structures containing durations. See [`deserialize_duration`]
191/// for the counterpart.
192///
193/// The default serialization produces human unreadable values, this is more suitable for dumping
194/// configuration users will read.
195///
196/// # Examples
197///
198/// ```rust
199/// use std::time::Duration;
200///
201/// use serde::{Deserialize, Serialize};
202///
203/// #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
204/// struct Cfg {
205///     #[serde(
206///         serialize_with = "spirit::utils::serialize_duration",
207///         deserialize_with = "spirit::utils::deserialize_duration",
208///     )]
209///     how_long: Duration,
210/// }
211/// ```
212pub fn serialize_duration<S: Serializer>(dur: &Duration, s: S) -> Result<S::Ok, S::Error> {
213    s.serialize_str(&humantime::format_duration(*dur).to_string())
214}
215
216/// Deserialize a human-readable duration.
217///
218/// # Examples
219///
220/// ```rust
221/// use std::time::Duration;
222///
223/// use serde::{Deserialize, Serialize};
224///
225/// #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
226/// struct Cfg {
227///     #[serde(
228///         serialize_with = "spirit::utils::serialize_duration",
229///         deserialize_with = "spirit::utils::deserialize_duration",
230///     )]
231///     how_long: Duration,
232/// }
233/// ```
234pub fn deserialize_duration<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
235    let s = String::deserialize(d)?;
236
237    humantime::parse_duration(&s)
238        .map_err(|_| DeError::invalid_value(Unexpected::Str(&s), &"Human readable duration"))
239}
240
241/// Deserialize an `Option<Duration>` using the [`humantime`](https://lib.rs/crates/humantime) crate.
242///
243/// This allows reading human-friendly representations of time, like `30s` or `5days`. It should be
244/// paired with [`serialize_opt_duration`]. Also, to act like [`Option`] does when deserializing by
245/// default, the `#[serde(default)]` is recommended.
246///
247/// # Examples
248///
249/// ```rust
250/// use std::time::Duration;
251///
252/// use serde::{Deserialize, Serialize};
253///
254/// #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
255/// struct Cfg {
256///     #[serde(
257///         serialize_with = "spirit::utils::serialize_opt_duration",
258///         deserialize_with = "spirit::utils::deserialize_opt_duration",
259///         default,
260///     )]
261///     how_long: Option<Duration>,
262/// }
263/// ```
264pub fn deserialize_opt_duration<'de, D: Deserializer<'de>>(
265    d: D,
266) -> Result<Option<Duration>, D::Error> {
267    if let Some(dur) = Option::<String>::deserialize(d)? {
268        humantime::parse_duration(&dur)
269            .map_err(|_| DeError::invalid_value(Unexpected::Str(&dur), &"Human readable duration"))
270            .map(Some)
271    } else {
272        Ok(None)
273    }
274}
275
276/// Serialize an `Option<Duration>` in a human friendly form.
277///
278/// See the [`deserialize_opt_duration`] for more details and an example.
279pub fn serialize_opt_duration<S: Serializer>(
280    dur: &Option<Duration>,
281    s: S,
282) -> Result<S::Ok, S::Error> {
283    match dur {
284        Some(d) => serialize_duration(d, s),
285        None => s.serialize_none(),
286    }
287}
288
289#[deprecated(note = "Abstraction at the wrong place. Use support_emergency_shutdown instead.")]
290#[doc(hidden)]
291pub fn cleanup_signals() {
292    debug!("Resetting termination signal handlers to defaults");
293    // Originally, this was done by removing all signals and resetting to defaults. We now install
294    // default-handler emulation instead. That's a little bit problematic, if it's the signal
295    // handlers that get stuck, but folks are recommended to use the support_emergency_shutdown
296    // instead anyway.
297    for sig in signal_hook::consts::TERM_SIGNALS {
298        let registered =
299            signal_hook::flag::register_conditional_default(*sig, Arc::new(AtomicBool::new(true)));
300        if let Err(e) = registered {
301            let name = signal_hook::low_level::signal_name(*sig).unwrap_or_default();
302            error!(
303                "Failed to register forced shutdown signal {}/{}: {}",
304                name, sig, e
305            );
306        }
307    }
308}
309
310/// Installs a stage-shutdown handling.
311///
312/// If CTRL+C (or some other similar signal) is received for the first time, a graceful shutdown is
313/// initiated and a flag is set to true. If it is received for a second time, the application is
314/// terminated abruptly.
315///
316/// The flag handle is returned to the caller, so the graceful shutdown and second stage kill can
317/// be aborted.
318///
319/// Note that this API doesn't allow for removing the staged shutdown (due to the needed API
320/// clutter). If that is needed, you can use [`signal_hook`] directly.
321///
322/// # Usage
323///
324/// This is supposed to be called early in the program (usually as the first thing in `main`). This
325/// is for two reasons:
326///
327/// * One usually wants this kind of emergency handling even during startup ‒ if something gets
328///   stuck during the initialization.
329/// * Installing signal handlers once there are multiple threads is inherently racy, therefore it
330///   is better to be done before any additional threads are started.
331///
332/// # Examples
333///
334/// ```rust
335/// use spirit::prelude::*;
336/// use spirit::{utils, Empty, Spirit};
337///
338/// fn main() {
339///     // Do this first, so double CTRL+C works from the very beginning.
340///     utils::support_emergency_shutdown().expect("This doesn't fail on healthy systems");
341///     // Proceed to doing something useful.
342///     Spirit::<Empty, Empty>::new()
343///         .run(|_spirit| {
344///             println!("Hello world");
345///             Ok(())
346///         });
347/// }
348/// ```
349///
350/// # Errors
351///
352/// This manipulates low-level signal handlers internally, so in theory this can fail. But this is
353/// not expected to fail in practice (not on a system that isn't severely broken anyway). As such,
354/// it is probably reasonable to unwrap here.
355pub fn support_emergency_shutdown() -> Result<Arc<AtomicBool>, AnyError> {
356    let flag = Arc::new(AtomicBool::new(false));
357
358    let install = |sig: c_int| -> Result<(), AnyError> {
359        signal_hook::flag::register_conditional_shutdown(sig, 2, Arc::clone(&flag))?;
360        signal_hook::flag::register(sig, Arc::clone(&flag))?;
361        Ok(())
362    };
363
364    for sig in signal_hook::consts::TERM_SIGNALS {
365        let name = signal_hook::low_level::signal_name(*sig).unwrap_or_default();
366        debug!("Installing emergency shutdown support for {}/{}", name, sig);
367        install(*sig).with_context(|_| {
368            format!(
369                "Failed to install staged shutdown handler for {}/{}",
370                name, sig
371            )
372        })?
373    }
374
375    Ok(flag)
376}
377
/// Returns `true` when `v` equals `T::default()`.
///
/// Handy as `#[serde(skip_serializing_if = "is_default")]` to leave default-valued fields out of
/// the serialized output.
pub fn is_default<T: Default + PartialEq>(v: &T) -> bool {
    let default = T::default();
    *v == default
}
384
/// Returns the flag behind the reference, i.e. `true` exactly when it is set.
///
/// Handy as `#[serde(skip_serializing_if = "is_true")]` to omit flags that are `true`.
pub fn is_true(&flag: &bool) -> bool {
    flag
}
391
392pub(crate) struct FlushGuard;
393
394impl Drop for FlushGuard {
395    fn drop(&mut self) {
396        log::logger().flush();
397    }
398}
399
#[cfg(test)]
mod tests {
    use std::ffi::OsString;
    use std::net::{AddrParseError, IpAddr};
    use std::num::ParseIntError;

    use super::*;

    /// Relative paths are resolved against the current directory, existing or not.
    #[test]
    fn abs() {
        let current = env::current_dir().unwrap();
        let parent = absolute_from_os_str(&OsString::from(".."));
        assert!(parent.is_absolute());
        assert!(current.starts_with(parent));

        let child = absolute_from_os_str(&OsString::from("this-likely-doesn't-exist"));
        assert!(child.is_absolute());
        assert!(child.starts_with(current));
    }

    /// Valid inputs for the key-value parser
    #[test]
    fn key_val_success() {
        assert_eq!(
            ("hello".to_owned(), "world".to_owned()),
            key_val("hello=world").unwrap()
        );
        let ip: IpAddr = "192.0.2.1".parse().unwrap();
        assert_eq!(("ip".to_owned(), ip), key_val("ip=192.0.2.1").unwrap());
        assert_eq!(("count".to_owned(), 4), key_val("count=4").unwrap());
    }

    /// Any extra equals signs go into the value part.
    #[test]
    fn key_val_extra_equals() {
        assert_eq!(
            ("greeting".to_owned(), "hello=world".to_owned()),
            key_val("greeting=hello=world").unwrap(),
        );
    }

    /// An empty key or value around the `=` is still a valid split.
    #[test]
    fn key_val_empty_parts() {
        assert_eq!(
            (String::new(), String::new()),
            key_val::<String, String>("=").unwrap()
        );
        assert_eq!(
            ("key".to_owned(), String::new()),
            key_val::<String, String>("key=").unwrap()
        );
        assert_eq!(
            (String::new(), "value".to_owned()),
            key_val::<String, String>("=value").unwrap()
        );
    }

    /// Test when the key or value doesn't get parsed.
    #[test]
    fn key_val_parse_fail() {
        key_val::<String, IpAddr>("hello=192.0.2.1.0")
            .unwrap_err()
            .downcast_ref::<AddrParseError>()
            .expect("Different error returned");
        key_val::<usize, String>("hello=world")
            .unwrap_err()
            .downcast_ref::<ParseIntError>()
            .expect("Different error returned");
    }

    #[test]
    fn key_val_missing_eq() {
        key_val::<String, String>("no equal sign")
            .unwrap_err()
            .downcast_ref::<MissingEquals>()
            .expect("Different error returned");
    }

    #[test]
    fn missing_equals_display() {
        assert_eq!("Missing = in map option", MissingEquals.to_string());
    }

    /// The wrapped secret must never show up in the debug output, yet stay reachable via deref.
    #[test]
    fn hidden_debug_masks_value() {
        let mut secret = Hidden("secret".to_owned());
        assert_eq!("\"******\"", format!("{:?}", secret));
        assert_eq!("\"******\"", format!("{:?}", Hidden(42)));
        assert_eq!("secret", secret.as_str());
        secret.push('!');
        assert_eq!("secret!", secret.0);
    }

    #[test]
    fn default_and_true_helpers() {
        assert!(is_default(&0u32));
        assert!(!is_default(&1u32));
        assert!(is_default(&String::new()));
        assert!(is_true(&true));
        assert!(!is_true(&false));
    }
}