1use core::{convert::TryInto, fmt, ptr::NonNull, task::Poll};
5use errno::{errno, Errno};
6use libc::c_char;
7use s2n_tls_sys::*;
8use std::{convert::TryFrom, ffi::CStr};
9
/// Broad category of an [`Error`].
///
/// For errors raised by the s2n-tls C library the category is derived from
/// the library's `s2n_error_type`. Errors created by these bindings or by
/// application code carry their own category.
#[non_exhaustive]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum ErrorType {
    /// s2n-tls reported an error type that these bindings do not recognize.
    UnknownErrorType,
    /// Corresponds to `s2n_error_type::OK`.
    NoError,
    /// Corresponds to `s2n_error_type::IO`.
    ///
    /// When converted into a `std::io::Error`, the captured `errno` selects
    /// the resulting `std::io::ErrorKind`.
    IOError,
    /// Corresponds to `s2n_error_type::CLOSED`.
    ConnectionClosed,
    /// Corresponds to `s2n_error_type::BLOCKED`; the operation may be retried.
    Blocked,
    /// Corresponds to `s2n_error_type::ALERT`.
    Alert,
    /// Corresponds to `s2n_error_type::PROTO`.
    ProtocolError,
    /// Corresponds to `s2n_error_type::INTERNAL`.
    InternalError,
    /// Corresponds to `s2n_error_type::USAGE`.
    UsageError,
    /// The error was produced by application code rather than by s2n-tls.
    Application,
}
25
/// Identifies which layer produced an [`Error`].
///
/// Derives the same value-semantics traits as `ErrorType`. It is a small
/// fieldless enum returned by value, so callers can copy, compare and hash it
/// freely.
#[non_exhaustive]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum ErrorSource {
    /// Raised by the s2n-tls C library and captured from its error state.
    Library,
    /// Raised by these Rust bindings themselves.
    Bindings,
    /// Raised by application code and wrapped into an [`Error`].
    Application,
}
33
34impl From<libc::c_int> for ErrorType {
35 fn from(input: libc::c_int) -> Self {
36 match input as s2n_error_type::Type {
37 s2n_error_type::OK => ErrorType::NoError,
38 s2n_error_type::IO => ErrorType::IOError,
39 s2n_error_type::CLOSED => ErrorType::ConnectionClosed,
40 s2n_error_type::BLOCKED => ErrorType::Blocked,
41 s2n_error_type::ALERT => ErrorType::Alert,
42 s2n_error_type::PROTO => ErrorType::ProtocolError,
43 s2n_error_type::INTERNAL => ErrorType::InternalError,
44 s2n_error_type::USAGE => ErrorType::UsageError,
45 _ => ErrorType::UnknownErrorType,
46 }
47 }
48}
49
50enum Context {
51 Bindings(ErrorType, &'static str, &'static str),
52 Code(s2n_status_code::Type, Errno),
53 Application(Box<dyn std::error::Error + Send + Sync + 'static>),
54}
55
56pub struct Error(Context);
57
/// Interprets the raw return value of an s2n-tls C function as a `Result`.
///
/// Each implementation decides which raw values mean failure. When that
/// happens, the implementations in this module capture the pending s2n-tls
/// error state into an [`Error`].
pub trait Fallible {
    /// The value handed back to the caller on success.
    type Output;

    /// Converts `self` into `Ok(output)` or `Err(error)`.
    fn into_result(self) -> Result<Self::Output, Error>;
}
63
64impl Fallible for s2n_status_code::Type {
65 type Output = s2n_status_code::Type;
66
67 fn into_result(self) -> Result<Self::Output, Error> {
68 if self >= s2n_status_code::SUCCESS {
69 Ok(self)
70 } else {
71 Err(Error::capture())
72 }
73 }
74}
75
76impl Fallible for isize {
77 type Output = usize;
78
79 fn into_result(self) -> Result<Self::Output, Error> {
80 self.try_into().map_err(|_| Error::capture())
83 }
84}
85
86impl Fallible for u64 {
87 type Output = Self;
88
89 fn into_result(self) -> Result<Self::Output, Error> {
101 if self != Self::MAX {
102 Ok(self)
103 } else {
104 Err(Error::capture())
105 }
106 }
107}
108
109impl<T> Fallible for *mut T {
110 type Output = NonNull<T>;
111
112 fn into_result(self) -> Result<Self::Output, Error> {
113 if let Some(value) = NonNull::new(self) {
114 Ok(value)
115 } else {
116 Err(Error::capture())
117 }
118 }
119}
120
121impl<T> Fallible for *const T {
122 type Output = *const T;
123
124 fn into_result(self) -> Result<Self::Output, Error> {
125 if !self.is_null() {
126 Ok(self)
127 } else {
128 Err(Error::capture())
129 }
130 }
131}
132
/// Interprets a raw s2n-tls return value as a [`Poll`] for use in async code.
///
/// Every [`Fallible`] type has a blanket implementation. It turns retryable
/// errors (see `Error::is_retryable`) into `Poll::Pending` and every other
/// outcome into `Poll::Ready`.
pub trait Pollable {
    /// The value produced when the operation completes successfully.
    type Output;

    /// Converts `self` into `Poll::Ready(result)` or `Poll::Pending`.
    fn into_poll(self) -> Poll<Result<Self::Output, Error>>;
}
138
139impl<T: Fallible> Pollable for T {
140 type Output = T::Output;
141
142 fn into_poll(self) -> Poll<Result<Self::Output, Error>> {
143 match self.into_result() {
144 Ok(r) => Ok(r).into(),
145 Err(err) if err.is_retryable() => Poll::Pending,
146 Err(err) => Err(err).into(),
147 }
148 }
149}
150
151impl Error {
152 pub(crate) const INVALID_INPUT: Error = Self::bindings(
153 ErrorType::UsageError,
154 "InvalidInput",
155 "An input parameter was incorrect",
156 );
157 pub(crate) const MISSING_WAKER: Error = Self::bindings(
158 ErrorType::UsageError,
159 "MissingWaker",
160 "Tried to perform an asynchronous operation without a configured waker",
161 );
162
163 pub fn io_error(err: std::io::Error) -> Error {
165 let errno = err.raw_os_error().unwrap_or(1);
166 errno::set_errno(errno::Errno(errno));
167 s2n_status_code::FAILURE.into_result().unwrap_err()
168 }
169
170 pub fn application(error: Box<dyn std::error::Error + Send + Sync + 'static>) -> Self {
175 Self(Context::Application(error))
176 }
177
178 pub(crate) const fn bindings(
180 kind: ErrorType,
181 name: &'static str,
182 message: &'static str,
183 ) -> Self {
184 Self(Context::Bindings(kind, name, message))
185 }
186
187 fn capture() -> Self {
188 unsafe {
189 let s2n_errno = s2n_errno_location();
190
191 let code = *s2n_errno;
192
193 *s2n_errno = s2n_error_type::OK as _;
197
198 Self(Context::Code(code, errno()))
199 }
200 }
201
202 pub fn name(&self) -> &'static str {
204 match self.0 {
205 Context::Bindings(_, name, _) => name,
206 Context::Application(_) => "ApplicationError",
207 Context::Code(code, _) => unsafe {
208 cstr_to_str(s2n_strerror_name(code))
210 },
211 }
212 }
213
214 pub fn message(&self) -> &'static str {
216 match self.0 {
217 Context::Bindings(_, _, msg) => msg,
218 Context::Application(_) => "An error occurred while executing application code",
219 Context::Code(code, _) => unsafe {
220 cstr_to_str(s2n_strerror(code, core::ptr::null()))
222 },
223 }
224 }
225
226 pub fn debug(&self) -> Option<&'static str> {
228 match self.0 {
229 Context::Bindings(_, _, _) | Context::Application(_) => None,
230 Context::Code(code, _) => unsafe {
231 let debug_info = s2n_strerror_debug(code, core::ptr::null());
232
233 if debug_info.is_null() {
237 None
238 } else {
239 Some(cstr_to_str(debug_info))
242 }
243 },
244 }
245 }
246
247 pub fn kind(&self) -> ErrorType {
249 match self.0 {
250 Context::Bindings(error_type, _, _) => error_type,
251 Context::Application(_) => ErrorType::Application,
252 Context::Code(code, _) => unsafe { ErrorType::from(s2n_error_get_type(code)) },
253 }
254 }
255
256 pub fn source(&self) -> ErrorSource {
257 match self.0 {
258 Context::Bindings(_, _, _) => ErrorSource::Bindings,
259 Context::Application(_) => ErrorSource::Application,
260 Context::Code(_, _) => ErrorSource::Library,
261 }
262 }
263
264 #[allow(clippy::borrowed_box)]
265 pub fn application_error(&self) -> Option<&Box<dyn std::error::Error + Send + Sync + 'static>> {
268 if let Self(Context::Application(err)) = self {
269 Some(err)
270 } else {
271 None
272 }
273 }
274
275 pub fn is_retryable(&self) -> bool {
276 matches!(self.kind(), ErrorType::Blocked)
277 }
278}
279
#[cfg(feature = "quic")]
impl Error {
    /// The TLS alert code that s2n-tls associates with this error, if any.
    ///
    /// Bindings and application errors never carry an alert. For library
    /// errors, `None` is returned when `s2n_error_get_alert` reports failure.
    pub fn alert(&self) -> Option<u8> {
        let code = match self.0 {
            Context::Code(code, _) => code,
            Context::Bindings(..) | Context::Application(_) => return None,
        };

        let mut alert = 0;
        // SAFETY: `alert` is a live, writable local for the whole call.
        let status = unsafe { s2n_error_get_alert(code, &mut alert) };
        status.into_result().ok().map(|_| alert)
    }
}
304
/// Borrows a NUL-terminated C string as a `&'static str` without copying or
/// validating it.
///
/// # Safety
///
/// `v` must satisfy all of the following:
///
/// - it is non-null;
/// - it points to a NUL-terminated string that stays valid and unmodified for
///   the rest of the program;
/// - the string is valid UTF-8.
///
/// NOTE(review): the strings returned by the s2n-tls error APIs are assumed to
/// meet this contract; confirm if those APIs change.
unsafe fn cstr_to_str(v: *const c_char) -> &'static str {
    core::str::from_utf8_unchecked(CStr::from_ptr(v).to_bytes())
}
314
315impl TryFrom<std::io::Error> for Error {
316 type Error = Error;
317 fn try_from(value: std::io::Error) -> Result<Self, Self::Error> {
318 let io_inner = value.into_inner().ok_or(Error::INVALID_INPUT)?;
319 io_inner
320 .downcast::<Self>()
321 .map(|error| *error)
322 .map_err(|_| Error::INVALID_INPUT)
323 }
324}
325
326impl From<Error> for std::io::Error {
327 fn from(input: Error) -> Self {
328 let kind = match input.kind() {
329 ErrorType::IOError => {
330 if let Context::Code(_, errno) = input.0 {
331 let bare = std::io::Error::from_raw_os_error(errno.0);
332 bare.kind()
333 } else {
334 std::io::ErrorKind::Other
335 }
336 }
337 ErrorType::ConnectionClosed => std::io::ErrorKind::UnexpectedEof,
338 _ => std::io::ErrorKind::Other,
339 };
340 std::io::Error::new(kind, input)
341 }
342}
343
344impl fmt::Debug for Error {
345 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
346 let mut s = f.debug_struct("Error");
347 if let Context::Code(code, _) = self.0 {
348 s.field("code", &code);
349 }
350
351 s.field("name", &self.name());
352 s.field("message", &self.message());
353 s.field("kind", &self.kind());
354 s.field("source", &self.source());
355
356 if let Some(debug) = self.debug() {
357 s.field("debug", &debug);
358 }
359
360 if let Context::Code(_, errno) = self.0 {
364 s.field("errno", &errno.to_string());
365 }
366
367 s.finish()
368 }
369}
370
371impl fmt::Display for Error {
372 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
373 if let Self(Context::Application(err)) = self {
374 err.fmt(f)
375 } else {
376 f.write_str(self.message())
377 }
378 }
379}
380
381impl std::error::Error for Error {
382 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
383 if let Self(Context::Application(err)) = self {
386 err.source()
387 } else {
388 None
389 }
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{enums::Version, testing::client_hello::CustomError};
    use errno::set_errno;

    // Raw failure value fed to `Fallible::into_result` for `isize`. It is
    // negative, so it always converts to `Err`.
    const FAILURE: isize = -1;

    // An s2n-tls error code that the tests below expect to have the IO error
    // type. The code one below it is expected not to.
    // NOTE(review): presumably mirrors S2N_ERR_T_IO_START from s2n_errno.h;
    // keep in sync if that header changes.
    const S2N_IO_ERROR_CODE: s2n_status_code::Type = 1 << 26;

    /// An IO-typed s2n-tls error converts into a `std::io::Error` whose kind
    /// comes from the captured `errno` and which still wraps the original
    /// error.
    #[test]
    fn s2n_io_error_to_std_io_error() -> Result<(), Box<dyn std::error::Error>> {
        set_errno(Errno(libc::ECONNRESET));
        unsafe {
            // Simulate the C library reporting an IO error.
            let s2n_errno_ptr = s2n_errno_location();
            *s2n_errno_ptr = S2N_IO_ERROR_CODE;
        }

        let s2n_error = FAILURE.into_result().unwrap_err();
        assert_eq!(ErrorType::IOError, s2n_error.kind());

        let io_error = std::io::Error::from(s2n_error);
        assert_eq!(std::io::ErrorKind::ConnectionReset, io_error.kind());
        assert!(io_error.into_inner().is_some());
        Ok(())
    }

    /// A non-IO s2n-tls error ignores `errno` when converted, and becomes an
    /// `Other` `std::io::Error` that still wraps the original error.
    #[test]
    fn s2n_error_to_std_io_error() -> Result<(), Box<dyn std::error::Error>> {
        set_errno(Errno(libc::ECONNRESET));
        unsafe {
            // A code just outside the IO range.
            let s2n_errno_ptr = s2n_errno_location();
            *s2n_errno_ptr = S2N_IO_ERROR_CODE - 1;
        }

        let s2n_error = FAILURE.into_result().unwrap_err();
        assert_ne!(ErrorType::IOError, s2n_error.kind());

        let io_error = std::io::Error::from(s2n_error);
        assert_eq!(std::io::ErrorKind::Other, io_error.kind());
        assert!(io_error.into_inner().is_some());
        Ok(())
    }

    /// A bindings usage error converts into an `Other` `std::io::Error`.
    #[test]
    fn invalid_input_to_std_io_error() -> Result<(), Box<dyn std::error::Error>> {
        let s2n_error = Version::try_from(0).unwrap_err();
        assert_eq!(ErrorType::UsageError, s2n_error.kind());

        let io_error = std::io::Error::from(s2n_error);
        assert_eq!(std::io::ErrorKind::Other, io_error.kind());
        assert!(io_error.into_inner().is_some());
        Ok(())
    }

    /// Bindings errors and captured library errors report the correct
    /// `ErrorSource`.
    #[test]
    fn error_source() -> Result<(), Box<dyn std::error::Error>> {
        let bindings_error = Version::try_from(0).unwrap_err();
        assert_eq!(ErrorSource::Bindings, bindings_error.source());

        let library_error = FAILURE.into_result().unwrap_err();
        assert_eq!(ErrorSource::Library, library_error.source());

        Ok(())
    }

    /// Application errors can be recovered and downcast. This is checked both
    /// directly and through a wrapping `std::io::Error`.
    #[test]
    fn application_error() {
        {
            // A single level of wrapping.
            let error = Error::application(Box::new(CustomError));

            let app_error = error.application_error().unwrap();
            let _custom_error = app_error.downcast_ref::<CustomError>().unwrap();
        }

        {
            // A custom error nested inside a `std::io::Error`.
            let io_error = std::io::Error::new(std::io::ErrorKind::Other, CustomError);
            let error = Error::application(Box::new(io_error));

            let app_error = error.application_error().unwrap();
            let io_error = app_error.downcast_ref::<std::io::Error>().unwrap();
            let _custom_error = io_error
                .get_ref()
                .unwrap()
                .downcast_ref::<CustomError>()
                .unwrap();
        }
    }

    /// `Error::bindings` reports back exactly the kind, name and message it
    /// was built with, and has no s2n-tls debug information.
    #[test]
    fn bindings_error() {
        let name = "TestError";
        let message = "Custom error for test";
        let kind = ErrorType::InternalError;
        let error = Error::bindings(kind, name, message);
        assert_eq!(error.kind(), kind);
        assert_eq!(error.name(), name);
        assert_eq!(error.message(), message);
        assert_eq!(error.debug(), None);
        assert_eq!(error.source(), ErrorSource::Bindings);
    }
}