1use core::{convert::TryInto, fmt, ptr::NonNull, task::Poll};
5use errno::{errno, Errno};
6use libc::c_char;
7use s2n_tls_sys::*;
8use std::{convert::TryFrom, ffi::CStr};
9
/// Broad category of an error, as reported by `Error::kind`.
///
/// Library errors are classified from the integer returned by
/// `s2n_error_get_type`; bindings errors carry their category directly and
/// application errors always report [`ErrorType::Application`].
#[non_exhaustive]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum ErrorType {
    /// The s2n-tls error type did not match any known category.
    UnknownErrorType,
    /// Maps from `s2n_error_type::OK`.
    NoError,
    /// Maps from `s2n_error_type::IO`.
    IOError,
    /// Maps from `s2n_error_type::CLOSED`.
    ConnectionClosed,
    /// Maps from `s2n_error_type::BLOCKED`; such errors are treated as retryable.
    Blocked,
    /// Maps from `s2n_error_type::ALERT`.
    Alert,
    /// Maps from `s2n_error_type::PROTO`.
    ProtocolError,
    /// Maps from `s2n_error_type::INTERNAL`.
    InternalError,
    /// Maps from `s2n_error_type::USAGE`.
    UsageError,
    /// The error was produced by application code.
    Application,
}
25
/// Which layer produced an error, as reported by `Error::source`.
///
/// Derives the same value-type traits as `ErrorType` so it can be copied,
/// compared with `==`, and used as a key in hashed collections.
#[non_exhaustive]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum ErrorSource {
    /// The error was reported by the s2n-tls C library.
    Library,
    /// The error was raised by these Rust bindings.
    Bindings,
    /// The error was supplied by application code.
    Application,
}
33
34impl From<libc::c_int> for ErrorType {
35 fn from(input: libc::c_int) -> Self {
36 match input as s2n_error_type::Type {
37 s2n_error_type::OK => ErrorType::NoError,
38 s2n_error_type::IO => ErrorType::IOError,
39 s2n_error_type::CLOSED => ErrorType::ConnectionClosed,
40 s2n_error_type::BLOCKED => ErrorType::Blocked,
41 s2n_error_type::ALERT => ErrorType::Alert,
42 s2n_error_type::PROTO => ErrorType::ProtocolError,
43 s2n_error_type::INTERNAL => ErrorType::InternalError,
44 s2n_error_type::USAGE => ErrorType::UsageError,
45 _ => ErrorType::UnknownErrorType,
46 }
47 }
48}
49
50enum Context {
51 Bindings(ErrorType, &'static str, &'static str),
52 Code(s2n_status_code::Type, Errno),
53 Application(Box<dyn std::error::Error + Send + Sync + 'static>),
54}
55
56pub struct Error(Context);
57
58pub trait Fallible {
59 type Output;
60
61 fn into_result(self) -> Result<Self::Output, Error>;
62}
63
64impl Fallible for s2n_status_code::Type {
65 type Output = s2n_status_code::Type;
66
67 fn into_result(self) -> Result<Self::Output, Error> {
68 if self >= s2n_status_code::SUCCESS {
69 Ok(self)
70 } else {
71 Err(Error::capture())
72 }
73 }
74}
75
76impl Fallible for isize {
77 type Output = usize;
78
79 fn into_result(self) -> Result<Self::Output, Error> {
80 self.try_into().map_err(|_| Error::capture())
83 }
84}
85
86impl Fallible for u64 {
87 type Output = Self;
88
89 fn into_result(self) -> Result<Self::Output, Error> {
101 if self != Self::MAX {
102 Ok(self)
103 } else {
104 Err(Error::capture())
105 }
106 }
107}
108
109impl<T> Fallible for *mut T {
110 type Output = NonNull<T>;
111
112 fn into_result(self) -> Result<Self::Output, Error> {
113 if let Some(value) = NonNull::new(self) {
114 Ok(value)
115 } else {
116 Err(Error::capture())
117 }
118 }
119}
120
121impl<T> Fallible for *const T {
122 type Output = *const T;
123
124 fn into_result(self) -> Result<Self::Output, Error> {
125 if !self.is_null() {
126 Ok(self)
127 } else {
128 Err(Error::capture())
129 }
130 }
131}
132
133pub trait Pollable {
134 type Output;
135
136 fn into_poll(self) -> Poll<Result<Self::Output, Error>>;
137}
138
139impl<T: Fallible> Pollable for T {
140 type Output = T::Output;
141
142 fn into_poll(self) -> Poll<Result<Self::Output, Error>> {
143 match self.into_result() {
144 Ok(r) => Ok(r).into(),
145 Err(err) if err.is_retryable() => Poll::Pending,
146 Err(err) => Err(err).into(),
147 }
148 }
149}
150
151impl Error {
152 pub(crate) const INVALID_INPUT: Error = Self::bindings(
153 ErrorType::UsageError,
154 "InvalidInput",
155 "An input parameter was incorrect",
156 );
157 pub(crate) const MISSING_WAKER: Error = Self::bindings(
158 ErrorType::UsageError,
159 "MissingWaker",
160 "Tried to perform an asynchronous operation without a configured waker",
161 );
162
163 pub fn io_error(err: std::io::Error) -> Error {
165 let errno = err.raw_os_error().unwrap_or(1);
166 errno::set_errno(errno::Errno(errno));
167 s2n_status_code::FAILURE.into_result().unwrap_err()
168 }
169
170 pub fn application(error: Box<dyn std::error::Error + Send + Sync + 'static>) -> Self {
175 Self(Context::Application(error))
176 }
177
178 pub(crate) const fn bindings(
180 kind: ErrorType,
181 name: &'static str,
182 message: &'static str,
183 ) -> Self {
184 Self(Context::Bindings(kind, name, message))
185 }
186
187 fn capture() -> Self {
188 unsafe {
189 let s2n_errno = s2n_errno_location();
190
191 let code = *s2n_errno;
192
193 *s2n_errno = s2n_error_type::OK as _;
197
198 Self(Context::Code(code, errno()))
199 }
200 }
201
202 pub fn name(&self) -> &'static str {
204 match self.0 {
205 Context::Bindings(_, name, _) => name,
206 Context::Application(_) => "ApplicationError",
207 Context::Code(code, _) => unsafe {
208 cstr_to_str(s2n_strerror_name(code))
210 },
211 }
212 }
213
214 pub fn message(&self) -> &'static str {
216 match self.0 {
217 Context::Bindings(_, _, msg) => msg,
218 Context::Application(_) => "An error occurred while executing application code",
219 Context::Code(code, _) => unsafe {
220 cstr_to_str(s2n_strerror(code, core::ptr::null()))
222 },
223 }
224 }
225
226 pub fn debug(&self) -> Option<&'static str> {
228 match self.0 {
229 Context::Bindings(_, _, _) | Context::Application(_) => None,
230 Context::Code(code, _) => unsafe {
231 let debug_info = s2n_strerror_debug(code, core::ptr::null());
232
233 if debug_info.is_null() {
237 None
238 } else {
239 Some(cstr_to_str(debug_info))
242 }
243 },
244 }
245 }
246
247 pub fn kind(&self) -> ErrorType {
249 match self.0 {
250 Context::Bindings(error_type, _, _) => error_type,
251 Context::Application(_) => ErrorType::Application,
252 Context::Code(code, _) => unsafe { ErrorType::from(s2n_error_get_type(code)) },
253 }
254 }
255
256 pub fn source(&self) -> ErrorSource {
257 match self.0 {
258 Context::Bindings(_, _, _) => ErrorSource::Bindings,
259 Context::Application(_) => ErrorSource::Application,
260 Context::Code(_, _) => ErrorSource::Library,
261 }
262 }
263
264 #[allow(clippy::borrowed_box)]
265 pub fn application_error(&self) -> Option<&Box<dyn std::error::Error + Send + Sync + 'static>> {
268 if let Self(Context::Application(err)) = self {
269 Some(err)
270 } else {
271 None
272 }
273 }
274
275 pub fn is_retryable(&self) -> bool {
276 matches!(self.kind(), ErrorType::Blocked)
277 }
278}
279
#[cfg(feature = "quic")]
impl Error {
    /// TLS alert associated with this error, as reported by `s2n_error_get_alert`.
    ///
    /// Returns `None` for bindings and application errors, and when
    /// `s2n_error_get_alert` fails. In the failure case the resulting s2n-tls
    /// error is captured (which also resets `s2n_errno`) and then discarded.
    pub fn alert(&self) -> Option<u8> {
        let code = match self.0 {
            Context::Code(code, _) => code,
            Context::Bindings(..) | Context::Application(_) => return None,
        };
        let mut alert: u8 = 0;
        // SAFETY: `alert` is a live local that outlives the call.
        let status = unsafe { s2n_error_get_alert(code, &mut alert) };
        status.into_result().ok().map(|_| alert)
    }
}
304
/// Borrows a NUL-terminated C string as a `&'static str` without copying.
///
/// # Safety
///
/// `v` must be non-null and point to a NUL-terminated string that stays valid
/// and unmodified for the rest of the program. Its bytes must be valid UTF-8,
/// because they are not re-validated here.
unsafe fn cstr_to_str(v: *const c_char) -> &'static str {
    core::str::from_utf8_unchecked(CStr::from_ptr(v).to_bytes())
}
314
315impl TryFrom<std::io::Error> for Error {
316 type Error = Error;
317 fn try_from(value: std::io::Error) -> Result<Self, Self::Error> {
318 let io_inner = value.into_inner().ok_or(Error::INVALID_INPUT)?;
319 io_inner
320 .downcast::<Self>()
321 .map(|error| *error)
322 .map_err(|_| Error::INVALID_INPUT)
323 }
324}
325
326impl From<Error> for std::io::Error {
327 fn from(input: Error) -> Self {
328 let kind = match input.kind() {
329 ErrorType::IOError => {
330 if let Context::Code(_, errno) = input.0 {
331 let bare = std::io::Error::from_raw_os_error(errno.0);
332 bare.kind()
333 } else {
334 std::io::ErrorKind::Other
335 }
336 }
337 ErrorType::ConnectionClosed => std::io::ErrorKind::UnexpectedEof,
338 _ => std::io::ErrorKind::Other,
339 };
340 std::io::Error::new(kind, input)
341 }
342}
343
344impl fmt::Debug for Error {
345 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
346 if let Self(Context::Application(err)) = self {
349 return err.fmt(f);
350 }
351
352 let mut s = f.debug_struct("Error");
353 if let Context::Code(code, _) = self.0 {
354 s.field("code", &code);
355 }
356
357 s.field("name", &self.name());
358 s.field("message", &self.message());
359 s.field("kind", &self.kind());
360 s.field("source", &self.source());
361
362 if let Some(debug) = self.debug() {
363 s.field("debug", &debug);
364 }
365
366 if let Context::Code(_, errno) = self.0 {
370 s.field("errno", &errno.to_string());
371 }
372
373 s.finish()
374 }
375}
376
377impl fmt::Display for Error {
378 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
379 if let Self(Context::Application(err)) = self {
380 err.fmt(f)
381 } else {
382 f.write_str(self.message())
383 }
384 }
385}
386
387impl std::error::Error for Error {
388 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
389 if let Self(Context::Application(err)) = self {
392 err.source()
393 } else {
394 None
395 }
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{enums::Version, testing::client_hello::CustomError};
    use errno::set_errno;

    // Generic failure return value: `Fallible for isize` rejects negative values,
    // so `into_result` captures the pending s2n-tls error.
    const FAILURE: isize = -1;

    // s2n error code that `s2n_io_error_to_std_io_error` expects to decode as an
    // IO-type error. Presumably this is the IO error type shifted past s2n-tls's
    // error-value bits. TODO confirm against s2n-tls's errno definitions.
    const S2N_IO_ERROR_CODE: s2n_status_code::Type = 1 << 26;

    // An IO-type library error converts to a std::io::Error whose kind comes from
    // the captured errno (ECONNRESET -> ConnectionReset). The s2n error is kept
    // as the inner error.
    #[test]
    fn s2n_io_error_to_std_io_error() -> Result<(), Box<dyn std::error::Error>> {
        set_errno(Errno(libc::ECONNRESET));
        unsafe {
            let s2n_errno_ptr = s2n_errno_location();
            *s2n_errno_ptr = S2N_IO_ERROR_CODE;
        }

        let s2n_error = FAILURE.into_result().unwrap_err();
        assert_eq!(ErrorType::IOError, s2n_error.kind());

        let io_error = std::io::Error::from(s2n_error);
        assert_eq!(std::io::ErrorKind::ConnectionReset, io_error.kind());
        assert!(io_error.into_inner().is_some());
        Ok(())
    }

    // A code just below S2N_IO_ERROR_CODE is not an IO-type error. The errno is
    // therefore ignored and the conversion falls back to ErrorKind::Other.
    #[test]
    fn s2n_error_to_std_io_error() -> Result<(), Box<dyn std::error::Error>> {
        set_errno(Errno(libc::ECONNRESET));
        unsafe {
            let s2n_errno_ptr = s2n_errno_location();
            *s2n_errno_ptr = S2N_IO_ERROR_CODE - 1;
        }

        let s2n_error = FAILURE.into_result().unwrap_err();
        assert_ne!(ErrorType::IOError, s2n_error.kind());

        let io_error = std::io::Error::from(s2n_error);
        assert_eq!(std::io::ErrorKind::Other, io_error.kind());
        assert!(io_error.into_inner().is_some());
        Ok(())
    }

    // A bindings-side usage error (from an invalid Version value) converts to
    // ErrorKind::Other.
    #[test]
    fn invalid_input_to_std_io_error() -> Result<(), Box<dyn std::error::Error>> {
        let s2n_error = Version::try_from(0).unwrap_err();
        assert_eq!(ErrorType::UsageError, s2n_error.kind());

        let io_error = std::io::Error::from(s2n_error);
        assert_eq!(std::io::ErrorKind::Other, io_error.kind());
        assert!(io_error.into_inner().is_some());
        Ok(())
    }

    // Bindings errors and captured library errors report their respective
    // ErrorSource.
    #[test]
    fn error_source() -> Result<(), Box<dyn std::error::Error>> {
        let bindings_error = Version::try_from(0).unwrap_err();
        assert_eq!(ErrorSource::Bindings, bindings_error.source());

        let library_error = FAILURE.into_result().unwrap_err();
        assert_eq!(ErrorSource::Library, library_error.source());

        Ok(())
    }

    // Application errors are transparent: Display and Debug forward to the
    // wrapped error, and it can be downcast back to its concrete type.
    #[test]
    fn application_error() {
        {
            let error = Error::application(Box::new(CustomError));

            let app_error = error.application_error().unwrap();
            let _custom_error = app_error.downcast_ref::<CustomError>().unwrap();

            let display = format!("{error}");
            assert_eq!(display, "custom error");
            let debug = format!("{error:?}");
            assert_eq!(debug, "CustomError");
        }

        {
            // A wrapped std::io::Error keeps its own inner error reachable.
            let io_error = std::io::Error::new(std::io::ErrorKind::Other, CustomError);
            let error = Error::application(Box::new(io_error));

            let app_error = error.application_error().unwrap();
            let io_error = app_error.downcast_ref::<std::io::Error>().unwrap();
            let _custom_error = io_error
                .get_ref()
                .unwrap()
                .downcast_ref::<CustomError>()
                .unwrap();
        }
    }

    // Error::bindings reports exactly the kind, name, and message it was built
    // with, and has no debug text.
    #[test]
    fn bindings_error() {
        let name = "TestError";
        let message = "Custom error for test";
        let kind = ErrorType::InternalError;
        let error = Error::bindings(kind, name, message);
        assert_eq!(error.kind(), kind);
        assert_eq!(error.name(), name);
        assert_eq!(error.message(), message);
        assert_eq!(error.debug(), None);
        assert_eq!(error.source(), ErrorSource::Bindings);
    }
}