s2n_tls/
error.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use core::{convert::TryInto, fmt, ptr::NonNull, task::Poll};
5use errno::{errno, Errno};
6use libc::c_char;
7use s2n_tls_sys::*;
8use std::{convert::TryFrom, ffi::CStr};
9
/// Corresponds to [s2n_error_type].
///
/// Classifies an [`Error`]; obtain it with [`Error::kind`].
#[non_exhaustive]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum ErrorType {
    /// The s2n-tls error type was not one these bindings recognize.
    UnknownErrorType,
    /// Maps from `s2n_error_type::OK`: no error.
    NoError,
    /// Maps from `s2n_error_type::IO`: an underlying IO operation failed.
    IOError,
    /// Maps from `s2n_error_type::CLOSED`: the connection was closed.
    ConnectionClosed,
    /// Maps from `s2n_error_type::BLOCKED`: the operation could not make
    /// progress yet and may be retried (see [`Error::is_retryable`]).
    Blocked,
    /// Maps from `s2n_error_type::ALERT`.
    Alert,
    /// Maps from `s2n_error_type::PROTO`.
    ProtocolError,
    /// Maps from `s2n_error_type::INTERNAL`.
    InternalError,
    /// Maps from `s2n_error_type::USAGE`: the API was used incorrectly.
    UsageError,
    /// The error was produced by application code (see [`Error::application`]).
    Application,
}
25
/// Identifies which layer produced an [`Error`]; obtain it with [`Error::source`].
#[non_exhaustive]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum ErrorSource {
    /// The error was reported by the s2n-tls C library.
    Library,
    /// The error was raised by these Rust bindings.
    Bindings,
    /// The error was returned by application code.
    Application,
}
33
34impl From<libc::c_int> for ErrorType {
35    fn from(input: libc::c_int) -> Self {
36        match input as s2n_error_type::Type {
37            s2n_error_type::OK => ErrorType::NoError,
38            s2n_error_type::IO => ErrorType::IOError,
39            s2n_error_type::CLOSED => ErrorType::ConnectionClosed,
40            s2n_error_type::BLOCKED => ErrorType::Blocked,
41            s2n_error_type::ALERT => ErrorType::Alert,
42            s2n_error_type::PROTO => ErrorType::ProtocolError,
43            s2n_error_type::INTERNAL => ErrorType::InternalError,
44            s2n_error_type::USAGE => ErrorType::UsageError,
45            _ => ErrorType::UnknownErrorType,
46        }
47    }
48}
49
50enum Context {
51    Bindings(ErrorType, &'static str, &'static str),
52    Code(s2n_status_code::Type, Errno),
53    Application(Box<dyn std::error::Error + Send + Sync + 'static>),
54}
55
56pub struct Error(Context);
57
58pub trait Fallible {
59    type Output;
60
61    fn into_result(self) -> Result<Self::Output, Error>;
62}
63
64impl Fallible for s2n_status_code::Type {
65    type Output = s2n_status_code::Type;
66
67    fn into_result(self) -> Result<Self::Output, Error> {
68        if self >= s2n_status_code::SUCCESS {
69            Ok(self)
70        } else {
71            Err(Error::capture())
72        }
73    }
74}
75
76impl Fallible for isize {
77    type Output = usize;
78
79    fn into_result(self) -> Result<Self::Output, Error> {
80        // Negative values can't be converted to a real size
81        // and instead indicate an error.
82        self.try_into().map_err(|_| Error::capture())
83    }
84}
85
86impl Fallible for u64 {
87    type Output = Self;
88
89    /// Converts a u64 to a Result by checking for u64::MAX.
90    ///
91    /// If a method that returns an unsigned int is fallible,
92    /// then the -1 error result wraps around to u64::MAX.
93    ///
94    /// For a u64 to be Fallible, a result of u64::MAX must not be
95    /// possible without an error. For example, [`s2n_connection_get_delay`]
96    /// can't return u64::MAX as a valid result because
97    /// s2n-tls blinding delays are limited to 30s, or a return value of 3^10 ns,
98    /// which is significantly less than u64::MAX. [`s2n_connection_get_delay`]
99    /// would therefore only return u64::MAX for a -1 error result.
100    fn into_result(self) -> Result<Self::Output, Error> {
101        if self != Self::MAX {
102            Ok(self)
103        } else {
104            Err(Error::capture())
105        }
106    }
107}
108
109impl<T> Fallible for *mut T {
110    type Output = NonNull<T>;
111
112    fn into_result(self) -> Result<Self::Output, Error> {
113        if let Some(value) = NonNull::new(self) {
114            Ok(value)
115        } else {
116            Err(Error::capture())
117        }
118    }
119}
120
121impl<T> Fallible for *const T {
122    type Output = *const T;
123
124    fn into_result(self) -> Result<Self::Output, Error> {
125        if !self.is_null() {
126            Ok(self)
127        } else {
128            Err(Error::capture())
129        }
130    }
131}
132
133pub trait Pollable {
134    type Output;
135
136    fn into_poll(self) -> Poll<Result<Self::Output, Error>>;
137}
138
139impl<T: Fallible> Pollable for T {
140    type Output = T::Output;
141
142    fn into_poll(self) -> Poll<Result<Self::Output, Error>> {
143        match self.into_result() {
144            Ok(r) => Ok(r).into(),
145            Err(err) if err.is_retryable() => Poll::Pending,
146            Err(err) => Err(err).into(),
147        }
148    }
149}
150
151impl Error {
152    pub(crate) const INVALID_INPUT: Error = Self::bindings(
153        ErrorType::UsageError,
154        "InvalidInput",
155        "An input parameter was incorrect",
156    );
157    pub(crate) const MISSING_WAKER: Error = Self::bindings(
158        ErrorType::UsageError,
159        "MissingWaker",
160        "Tried to perform an asynchronous operation without a configured waker",
161    );
162
163    /// Converts an io::Error into an s2n-tls Error
164    pub fn io_error(err: std::io::Error) -> Error {
165        let errno = err.raw_os_error().unwrap_or(1);
166        errno::set_errno(errno::Errno(errno));
167        s2n_status_code::FAILURE.into_result().unwrap_err()
168    }
169
170    /// An error occurred while running application code.
171    ///
172    /// Can be emitted from [`crate::callbacks::ConnectionFuture::poll()`] to indicate
173    /// async task failure.
174    pub fn application(error: Box<dyn std::error::Error + Send + Sync + 'static>) -> Self {
175        Self(Context::Application(error))
176    }
177
178    /// An error occurred while running bindings code.
179    pub(crate) const fn bindings(
180        kind: ErrorType,
181        name: &'static str,
182        message: &'static str,
183    ) -> Self {
184        Self(Context::Bindings(kind, name, message))
185    }
186
187    fn capture() -> Self {
188        unsafe {
189            let s2n_errno = s2n_errno_location();
190
191            let code = *s2n_errno;
192
193            // https://github.com/aws/s2n-tls/blob/main/docs/USAGE-GUIDE.md#error-handling
194            //# To avoid possible confusion, s2n_errno should be cleared after processing
195            //# an error: s2n_errno = S2N_ERR_T_OK
196            *s2n_errno = s2n_error_type::OK as _;
197
198            Self(Context::Code(code, errno()))
199        }
200    }
201
202    /// Corresponds to [s2n_strerror_name] for ErrorSource::Library errors.
203    pub fn name(&self) -> &'static str {
204        match self.0 {
205            Context::Bindings(_, name, _) => name,
206            Context::Application(_) => "ApplicationError",
207            Context::Code(code, _) => unsafe {
208                // Safety: we assume the string has a valid encoding coming from s2n
209                cstr_to_str(s2n_strerror_name(code))
210            },
211        }
212    }
213
214    /// Corresponds to [s2n_strerror] for ErrorSource::Library errors.
215    pub fn message(&self) -> &'static str {
216        match self.0 {
217            Context::Bindings(_, _, msg) => msg,
218            Context::Application(_) => "An error occurred while executing application code",
219            Context::Code(code, _) => unsafe {
220                // Safety: we assume the string has a valid encoding coming from s2n
221                cstr_to_str(s2n_strerror(code, core::ptr::null()))
222            },
223        }
224    }
225
226    /// Corresponds to [s2n_strerror_debug] for ErrorSource::Library errors.
227    pub fn debug(&self) -> Option<&'static str> {
228        match self.0 {
229            Context::Bindings(_, _, _) | Context::Application(_) => None,
230            Context::Code(code, _) => unsafe {
231                let debug_info = s2n_strerror_debug(code, core::ptr::null());
232
233                // The debug string should be set to a constant static string
234                // when an error occurs, but because it starts out as NULL
235                // we should defend against mistakes.
236                if debug_info.is_null() {
237                    None
238                } else {
239                    // If the string is not null, then we can assume that
240                    // it is constant and static.
241                    Some(cstr_to_str(debug_info))
242                }
243            },
244        }
245    }
246
247    /// Corresponds to [s2n_error_get_type] for ErrorSource::Library errors.
248    pub fn kind(&self) -> ErrorType {
249        match self.0 {
250            Context::Bindings(error_type, _, _) => error_type,
251            Context::Application(_) => ErrorType::Application,
252            Context::Code(code, _) => unsafe { ErrorType::from(s2n_error_get_type(code)) },
253        }
254    }
255
256    pub fn source(&self) -> ErrorSource {
257        match self.0 {
258            Context::Bindings(_, _, _) => ErrorSource::Bindings,
259            Context::Application(_) => ErrorSource::Application,
260            Context::Code(_, _) => ErrorSource::Library,
261        }
262    }
263
264    #[allow(clippy::borrowed_box)]
265    /// Returns an [`std::error::Error`] if the error source was [`ErrorSource::Application`],
266    /// otherwise returns None.
267    pub fn application_error(&self) -> Option<&Box<dyn std::error::Error + Send + Sync + 'static>> {
268        if let Self(Context::Application(err)) = self {
269            Some(err)
270        } else {
271            None
272        }
273    }
274
275    pub fn is_retryable(&self) -> bool {
276        matches!(self.kind(), ErrorType::Blocked)
277    }
278}
279
#[cfg(feature = "quic")]
impl Error {
    /// s2n-tls does not send specific errors.
    ///
    /// Local errors can, however, be mapped to the alerts that would have
    /// been sent if s2n-tls sent alerts.
    ///
    /// This API is currently incomplete and should not be relied upon.
    ///
    /// Corresponds to [s2n_error_get_alert] for ErrorSource::Library errors.
    /// Returns `None` for bindings and application errors, and when s2n-tls
    /// has no alert for the code.
    pub fn alert(&self) -> Option<u8> {
        let code = match self.0 {
            Context::Code(code, _) => code,
            Context::Bindings(_, _, _) | Context::Application(_) => return None,
        };

        let mut alert = 0;
        // SAFETY: `alert` is a live local that s2n_error_get_alert writes into.
        let status = unsafe { s2n_error_get_alert(code, &mut alert) };
        status.into_result().ok().map(|_| alert)
    }
}
304
/// Borrows a NUL-terminated C string as a `&'static str` without copying or
/// validating its contents.
///
/// # Safety
///
/// The caller must ensure that `v`:
/// * is non-null and points to a NUL-terminated string,
/// * contains valid UTF-8 (the bytes are not checked), and
/// * stays valid and unmodified for the rest of the program, because the
///   returned reference is `'static`.
///
/// Within this module, the pointers are assumed to come from s2n-tls's
/// error-string functions, which are trusted to meet these requirements.
unsafe fn cstr_to_str(v: *const c_char) -> &'static str {
    debug_assert!(!v.is_null(), "cstr_to_str called with a null pointer");
    // SAFETY: non-null, NUL-terminated, 'static and UTF-8 are all guaranteed
    // by the caller per the contract above.
    let bytes = CStr::from_ptr(v).to_bytes();
    core::str::from_utf8_unchecked(bytes)
}
314
315impl TryFrom<std::io::Error> for Error {
316    type Error = Error;
317    fn try_from(value: std::io::Error) -> Result<Self, Self::Error> {
318        let io_inner = value.into_inner().ok_or(Error::INVALID_INPUT)?;
319        io_inner
320            .downcast::<Self>()
321            .map(|error| *error)
322            .map_err(|_| Error::INVALID_INPUT)
323    }
324}
325
326impl From<Error> for std::io::Error {
327    fn from(input: Error) -> Self {
328        let kind = match input.kind() {
329            ErrorType::IOError => {
330                if let Context::Code(_, errno) = input.0 {
331                    let bare = std::io::Error::from_raw_os_error(errno.0);
332                    bare.kind()
333                } else {
334                    std::io::ErrorKind::Other
335                }
336            }
337            ErrorType::ConnectionClosed => std::io::ErrorKind::UnexpectedEof,
338            _ => std::io::ErrorKind::Other,
339        };
340        std::io::Error::new(kind, input)
341    }
342}
343
344impl fmt::Debug for Error {
345    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
346        // Application errors don't carry any interesting s2n context, so
347        // forward directly to the underyling error.
348        if let Self(Context::Application(err)) = self {
349            return err.fmt(f);
350        }
351
352        let mut s = f.debug_struct("Error");
353        if let Context::Code(code, _) = self.0 {
354            s.field("code", &code);
355        }
356
357        s.field("name", &self.name());
358        s.field("message", &self.message());
359        s.field("kind", &self.kind());
360        s.field("source", &self.source());
361
362        if let Some(debug) = self.debug() {
363            s.field("debug", &debug);
364        }
365
366        // "errno" is only known to be meaningful for IOErrors.
367        // However, it has occasionally proved useful for debugging
368        // other errors, so include it for all errors.
369        if let Context::Code(_, errno) = self.0 {
370            s.field("errno", &errno.to_string());
371        }
372
373        s.finish()
374    }
375}
376
377impl fmt::Display for Error {
378    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
379        if let Self(Context::Application(err)) = self {
380            err.fmt(f)
381        } else {
382            f.write_str(self.message())
383        }
384    }
385}
386
387impl std::error::Error for Error {
388    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
389        // implement `source` in the same way `std::io::Error` implements it:
390        // https://doc.rust-lang.org/std/io/struct.Error.html#method.source
391        if let Self(Context::Application(err)) = self {
392            err.source()
393        } else {
394            None
395        }
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{enums::Version, testing::client_hello::CustomError};
    use errno::set_errno;

    /// Raw return value used to simulate a failing s2n-tls call.
    const FAILURE: isize = -1;

    // This relies on an implementation detail of s2n-tls errors,
    // and could make these tests brittle. However, the alternative
    // is a real handshake producing a real IO error, so just updating
    // this value if the definition of an IO error changes might be easier.
    const S2N_IO_ERROR_CODE: s2n_status_code::Type = 1 << 26;

    /// Sets libc's errno and then s2n-tls's s2n_errno, so the next captured
    /// error is deterministic.
    fn prime_errors(os_errno: i32, s2n_code: s2n_status_code::Type) {
        set_errno(Errno(os_errno));
        unsafe {
            *s2n_errno_location() = s2n_code;
        }
    }

    #[test]
    fn s2n_io_error_to_std_io_error() -> Result<(), Box<dyn std::error::Error>> {
        prime_errors(libc::ECONNRESET, S2N_IO_ERROR_CODE);

        let s2n_error = FAILURE.into_result().unwrap_err();
        assert_eq!(ErrorType::IOError, s2n_error.kind());

        // IO errors carry errno through to the std::io::ErrorKind.
        let io_error: std::io::Error = s2n_error.into();
        assert_eq!(std::io::ErrorKind::ConnectionReset, io_error.kind());
        assert!(io_error.into_inner().is_some());
        Ok(())
    }

    #[test]
    fn s2n_error_to_std_io_error() -> Result<(), Box<dyn std::error::Error>> {
        // One below the first IO code, so the captured error is not an IO error.
        prime_errors(libc::ECONNRESET, S2N_IO_ERROR_CODE - 1);

        let s2n_error = FAILURE.into_result().unwrap_err();
        assert_ne!(ErrorType::IOError, s2n_error.kind());

        // Non-IO errors ignore errno and map to ErrorKind::Other.
        let io_error: std::io::Error = s2n_error.into();
        assert_eq!(std::io::ErrorKind::Other, io_error.kind());
        assert!(io_error.into_inner().is_some());
        Ok(())
    }

    #[test]
    fn invalid_input_to_std_io_error() -> Result<(), Box<dyn std::error::Error>> {
        let s2n_error = Version::try_from(0).unwrap_err();
        assert_eq!(ErrorType::UsageError, s2n_error.kind());

        let io_error: std::io::Error = s2n_error.into();
        assert_eq!(std::io::ErrorKind::Other, io_error.kind());
        assert!(io_error.into_inner().is_some());
        Ok(())
    }

    #[test]
    fn error_source() -> Result<(), Box<dyn std::error::Error>> {
        let from_bindings = Version::try_from(0).unwrap_err();
        assert_eq!(ErrorSource::Bindings, from_bindings.source());

        let from_library = FAILURE.into_result().unwrap_err();
        assert_eq!(ErrorSource::Library, from_library.source());

        Ok(())
    }

    #[test]
    fn application_error() {
        // A single-level application error is reachable and formats as itself.
        {
            let wrapped = Error::application(Box::new(CustomError));

            let inner = wrapped.application_error().unwrap();
            assert!(inner.downcast_ref::<CustomError>().is_some());

            assert_eq!(format!("{wrapped}"), "custom error");
            assert_eq!(format!("{wrapped:?}"), "CustomError");
        }

        // A nested error can still be unwrapped layer by layer.
        {
            let io_layer = std::io::Error::new(std::io::ErrorKind::Other, CustomError);
            let wrapped = Error::application(Box::new(io_layer));

            let inner = wrapped.application_error().unwrap();
            let io_layer = inner.downcast_ref::<std::io::Error>().unwrap();
            let innermost = io_layer.get_ref().unwrap();
            assert!(innermost.downcast_ref::<CustomError>().is_some());
        }
    }

    #[test]
    fn bindings_error() {
        let (name, message, kind) = (
            "TestError",
            "Custom error for test",
            ErrorType::InternalError,
        );
        let error = Error::bindings(kind, name, message);

        assert_eq!(error.kind(), kind);
        assert_eq!(error.name(), name);
        assert_eq!(error.message(), message);
        assert_eq!(error.debug(), None);
        assert_eq!(error.source(), ErrorSource::Bindings);
    }
}
511}