Skip to main content

tor_basic_utils/
error_sources.rs

1//! Helpers for iterating over error sources.
2
3use std::{io, sync::Arc};
4
/// An iterator over the lower-level error sources, and possibly their wrapped errors, of
/// an [`std::error::Error`].
///
/// The main reason to use this instead of calling [`std::error::Error::source`] repeatedly is
/// that the `source` implementation of [`io::Error`] doesn't return wrapped errors unless you
/// call `get_ref` on them (see: <https://github.com/rust-lang/rust/pull/124536>). You can think
/// of this iterator as walking down the chain of how an error was constructed.
///
/// This iterator shouldn't be used to display or format errors. Doing so could result in
/// displaying the same error twice (due to the wrapping behavior of `io::Error`).
///
/// Each call to [`Iterator::next`] will attempt to peel off the outer layer of the error.
///
/// The first item returned is always the original error. Subsequent items are generated by
/// calling:
///   * [`io::Error::get_ref`] if the last error could be downcast to an [`io::Error`] or
///     [`Arc<io::Error>`], or
///   * [`std::error::Error::source`] in all other cases
///
/// The iterator only holds a shared reference to the current layer, so it is cheap to
/// [`Clone`]; a clone continues from the same position independently of the original.
///
/// # Limitations
///
/// This is currently not handling [`io::Error`]s that are wrapped in containers such as `Box`,
/// `Rc`, etc.
#[derive(Clone, Debug)]
pub struct ErrorSources<'a> {
    /// The last error we managed to get via `get_ref` or `source`, i.e. the error that the
    /// next call to [`Iterator::next`] will yield. `None` once the chain is exhausted.
    ///
    /// Initially this is set to the error passed in via [`Self::new`].
    error: Option<&'a (dyn std::error::Error + 'static)>,
}
31
32impl<'a> ErrorSources<'a> {
33    /// Create an iterator over the lower-level sources of this error.
34    pub fn new(error: &'a (dyn std::error::Error + 'static)) -> Self {
35        Self { error: Some(error) }
36    }
37}
38
39impl<'a> Iterator for ErrorSources<'a> {
40    type Item = &'a (dyn std::error::Error + 'static);
41
42    fn next(&mut self) -> Option<Self::Item> {
43        let error = self.error.take()?;
44
45        if let Some(io_error) = error.downcast_ref::<io::Error>() {
46            // This match is necessary to cast from `&dyn Error + Send + Sync` to `&dyn Error` :/
47            //
48            // The use of `get_ref` here is intentional because we want to save the error that
49            // this `io::Error` is wrapping. If we used `source` that would give us the source of
50            // the error that's being wrapped.
51            self.error = io_error.get_ref().map(|e| e as _);
52        } else if let Some(io_error) = error.downcast_ref::<Arc<io::Error>>() {
53            self.error = io_error.get_ref().map(|e| e as _);
54        } else {
55            self.error = error.source();
56        }
57
58        Some(error)
59    }
60}
61
62// ----------------------------------------------------------------------
63
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    /// A leaf error with no source of its own, used as the innermost layer.
    #[derive(thiserror::Error, Debug)]
    #[error("my error")]
    struct MyError;

    /// Advance `$errors` by one item and require that item to have concrete type `$ty`.
    ///
    /// Panics if the iterator is exhausted or the downcast fails.
    macro_rules! downcast_next {
        ($errors:expr, $ty:ty) => {
            $errors
                .next()
                .unwrap()
                .downcast_ref::<$ty>()
                .unwrap()
        };
    }

    #[test]
    fn error_sources() {
        // Build the chain inside-out: MyError <- io::Error <- Arc <- io::Error.
        let innermost = io::Error::new(io::ErrorKind::ConnectionReset, MyError);
        let shared = Arc::new(innermost);
        let wrapped_error = io::Error::new(io::ErrorKind::ConnectionReset, shared);

        let mut errors = ErrorSources::new(&wrapped_error);

        // The outer `io::Error` comes first, followed by the `Arc<io::Error>` it wraps.
        downcast_next!(errors, io::Error);
        downcast_next!(errors, Arc<io::Error>);
        // Peeling the `Arc<io::Error>` exposes the error wrapped by the inner `io::Error`.
        downcast_next!(errors, MyError);
        // `MyError` has no source, so the chain ends here.
        assert!(errors.next().is_none());
    }
}
104}