1use std::any::TypeId;
4
5use bytes::Bytes;
6use http::StatusCode;
7use serde::Serialize;
8
9use crate::constants::{APPLICATION_JSON, INTERNAL_ERROR_MESSAGE};
10use crate::response::{with_body, IntoResponse, Response};
11
/// Result alias for this module whose error type defaults to [`Error`].
///
/// A different error type can still be named through the second parameter,
/// e.g. `Result<T, OtherError>`.
pub type Result<T, E = Error> = core::result::Result<T, E>;
15
/// Code attached to errors built by `Error::from_garde_report`; also the value
/// `Error::is_validation` compares against.
const VALIDATION_ERROR_CODE: &str = "VALIDATION_ERROR";
/// Message carried by errors built by `Error::from_garde_report`.
const VALIDATION_ERROR_MESSAGE: &str = "The submitted data failed validation.";
/// Issue label returned by `classify_issue` when no known phrase matches.
const GENERIC_ISSUE: &str = "INVALID";
/// Prefix placed in front of the UUID produced by `generate_trace_id`.
const TRACE_ID_PREFIX: &str = "req-";
24
/// Category of an [`Error`].
///
/// Each variant maps to exactly one HTTP status (see `ErrorKind::status`) and
/// one default machine-readable code (see `ErrorKind::code`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorKind {
    /// Maps to `StatusCode::BAD_REQUEST`.
    BadRequest,
    /// Maps to `StatusCode::UNAUTHORIZED`.
    Unauthorized,
    /// Maps to `StatusCode::FORBIDDEN`.
    Forbidden,
    /// Maps to `StatusCode::NOT_FOUND`.
    NotFound,
    /// Maps to `StatusCode::METHOD_NOT_ALLOWED`.
    MethodNotAllowed,
    /// Maps to `StatusCode::CONFLICT`.
    Conflict,
    /// Maps to `StatusCode::PAYLOAD_TOO_LARGE`.
    PayloadTooLarge,
    /// Maps to `StatusCode::UNPROCESSABLE_ENTITY`.
    Unprocessable,
    /// Maps to `StatusCode::TOO_MANY_REQUESTS`.
    TooManyRequests,
    /// Maps to `StatusCode::INTERNAL_SERVER_ERROR`.
    Internal,
    /// Maps to `StatusCode::SERVICE_UNAVAILABLE`.
    ServiceUnavailable,
    /// Maps to `StatusCode::GATEWAY_TIMEOUT`.
    GatewayTimeout,
}
53
54impl ErrorKind {
55 pub fn status(self) -> StatusCode {
57 match self {
58 ErrorKind::BadRequest => StatusCode::BAD_REQUEST,
59 ErrorKind::Unauthorized => StatusCode::UNAUTHORIZED,
60 ErrorKind::Forbidden => StatusCode::FORBIDDEN,
61 ErrorKind::NotFound => StatusCode::NOT_FOUND,
62 ErrorKind::MethodNotAllowed => StatusCode::METHOD_NOT_ALLOWED,
63 ErrorKind::Conflict => StatusCode::CONFLICT,
64 ErrorKind::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
65 ErrorKind::Unprocessable => StatusCode::UNPROCESSABLE_ENTITY,
66 ErrorKind::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
67 ErrorKind::Internal => StatusCode::INTERNAL_SERVER_ERROR,
68 ErrorKind::ServiceUnavailable => StatusCode::SERVICE_UNAVAILABLE,
69 ErrorKind::GatewayTimeout => StatusCode::GATEWAY_TIMEOUT,
70 }
71 }
72
73 pub fn code(self) -> &'static str {
78 match self {
79 ErrorKind::BadRequest => "BAD_REQUEST",
80 ErrorKind::Unauthorized => "UNAUTHORIZED",
81 ErrorKind::Forbidden => "FORBIDDEN",
82 ErrorKind::NotFound => "NOT_FOUND",
83 ErrorKind::MethodNotAllowed => "METHOD_NOT_ALLOWED",
84 ErrorKind::Conflict => "CONFLICT",
85 ErrorKind::PayloadTooLarge => "PAYLOAD_TOO_LARGE",
86 ErrorKind::Unprocessable => "UNPROCESSABLE_ENTITY",
87 ErrorKind::TooManyRequests => "TOO_MANY_REQUESTS",
88 ErrorKind::Internal => "INTERNAL_SERVER_ERROR",
89 ErrorKind::ServiceUnavailable => "SERVICE_UNAVAILABLE",
90 ErrorKind::GatewayTimeout => "GATEWAY_TIMEOUT",
91 }
92 }
93}
94
/// Error value that can be rendered as an HTTP response.
///
/// Built through `Error::new` or one of the per-kind constructors and then
/// optionally refined with the `with_*` builder methods.
#[derive(Debug)]
pub struct Error {
    /// Category of the error; selects the HTTP status and the default code.
    kind: ErrorKind,
    /// Custom code set by `with_code`; when `None` the kind's code is used.
    code: Option<&'static str>,
    /// Message supplied at construction time.
    message: String,
    /// Underlying cause attached by `with_source`, if any.
    source: Option<Box<dyn std::error::Error + Send + Sync>>,
    /// `TypeId` of the concrete source type recorded by `with_source`;
    /// `take_source` checks it before attempting the downcast.
    source_type: Option<TypeId>,
    /// Field-level details set by `with_details`; empty by default.
    details: Vec<ErrorDetail>,
}
112
/// One entry of the `details` list carried by an [`Error`].
///
/// Serialized as an object with the keys `field`, `issue` and `message`.
#[derive(Debug, Clone, Serialize)]
pub struct ErrorDetail {
    /// Field the detail refers to (`from_garde_report` stores the report path here).
    pub field: String,
    /// Short issue label (`from_garde_report` stores the `classify_issue` result here).
    pub issue: String,
    /// Human-readable description of the problem.
    pub message: String,
}
124
125impl ErrorDetail {
126 pub fn new(
128 field: impl Into<String>,
129 issue: impl Into<String>,
130 message: impl Into<String>,
131 ) -> Self {
132 Self {
133 field: field.into(),
134 issue: issue.into(),
135 message: message.into(),
136 }
137 }
138}
139
140impl Error {
141 pub fn new(kind: ErrorKind, message: impl Into<String>) -> Self {
143 Self {
144 kind,
145 code: None,
146 message: message.into(),
147 source: None,
148 source_type: None,
149 details: Vec::new(),
150 }
151 }
152
153 pub fn bad_request(message: impl Into<String>) -> Self {
155 Self::new(ErrorKind::BadRequest, message)
156 }
157
158 pub fn unauthorized(message: impl Into<String>) -> Self {
160 Self::new(ErrorKind::Unauthorized, message)
161 }
162
163 pub fn forbidden(message: impl Into<String>) -> Self {
165 Self::new(ErrorKind::Forbidden, message)
166 }
167
168 pub fn not_found(message: impl Into<String>) -> Self {
170 Self::new(ErrorKind::NotFound, message)
171 }
172
173 pub fn method_not_allowed(message: impl Into<String>) -> Self {
175 Self::new(ErrorKind::MethodNotAllowed, message)
176 }
177
178 pub fn conflict(message: impl Into<String>) -> Self {
180 Self::new(ErrorKind::Conflict, message)
181 }
182
183 pub fn unprocessable(message: impl Into<String>) -> Self {
185 Self::new(ErrorKind::Unprocessable, message)
186 }
187
188 pub fn payload_too_large(message: impl Into<String>) -> Self {
190 Self::new(ErrorKind::PayloadTooLarge, message)
191 }
192
193 pub fn too_many_requests(message: impl Into<String>) -> Self {
195 Self::new(ErrorKind::TooManyRequests, message)
196 }
197
198 pub fn internal(message: impl Into<String>) -> Self {
202 Self::new(ErrorKind::Internal, message)
203 }
204
205 pub fn service_unavailable(message: impl Into<String>) -> Self {
207 Self::new(ErrorKind::ServiceUnavailable, message)
208 }
209
210 pub fn gateway_timeout(message: impl Into<String>) -> Self {
212 Self::new(ErrorKind::GatewayTimeout, message)
213 }
214
215 pub fn with_code(mut self, code: &'static str) -> Self {
217 self.code = Some(code);
218 self
219 }
220
221 pub fn with_source<E>(mut self, source: E) -> Self
228 where
229 E: std::error::Error + Send + Sync + 'static,
230 {
231 self.source = Some(Box::new(source));
232 self.source_type = Some(TypeId::of::<E>());
233 self
234 }
235
236 pub fn with_details(mut self, details: Vec<ErrorDetail>) -> Self {
238 self.details = details;
239 self
240 }
241
242 pub fn from_garde_report(report: garde::error::Report) -> Self {
248 let details = report
249 .iter()
250 .map(|(path, error)| {
251 let message = error.to_string();
252 ErrorDetail::new(path.to_string(), classify_issue(&message), message)
253 })
254 .collect();
255 Self::unprocessable(VALIDATION_ERROR_MESSAGE)
256 .with_code(VALIDATION_ERROR_CODE)
257 .with_details(details)
258 }
259
260 pub fn kind(&self) -> ErrorKind {
262 self.kind
263 }
264
265 pub fn code(&self) -> &str {
267 self.code.unwrap_or_else(|| self.kind.code())
268 }
269
270 pub(crate) fn static_code(&self) -> &'static str {
275 self.code.unwrap_or_else(|| self.kind.code())
276 }
277
278 pub fn details(&self) -> &[ErrorDetail] {
280 &self.details
281 }
282
283 pub fn message(&self) -> &str {
285 &self.message
286 }
287
288 pub(crate) fn source_type(&self) -> Option<TypeId> {
292 self.source_type
293 }
294
295 pub(crate) fn is_validation(&self) -> bool {
297 self.code() == VALIDATION_ERROR_CODE
298 }
299
300 pub fn take_source<E>(&mut self) -> Option<E>
308 where
309 E: std::error::Error + Send + Sync + 'static,
310 {
311 if self.source_type != Some(TypeId::of::<E>()) {
312 return None;
313 }
314 let source = self.source.take()?;
315 self.source_type = None;
316 match source.downcast::<E>() {
317 Ok(typed) => Some(*typed),
318 Err(restored) => {
319 self.source = Some(restored);
321 self.source_type = Some(TypeId::of::<E>());
322 None
323 }
324 }
325 }
326}
327
328impl std::fmt::Display for Error {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 write!(f, "{}: {}", self.code(), self.message)
331 }
332}
333
334impl std::error::Error for Error {
335 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
336 self.source
337 .as_ref()
338 .map(|boxed| boxed.as_ref() as &(dyn std::error::Error + 'static))
339 }
340}
341
/// Borrowed view of an error that `into_response` serializes as the JSON body.
#[derive(Serialize)]
struct ErrorBody<'a> {
    /// Numeric HTTP status of the response.
    status: u16,
    /// Machine-readable error code.
    code: &'a str,
    /// Canonical reason phrase of the status, or `"Error"` when there is none.
    title: &'a str,
    /// Message shown to the client (redacted for server errors).
    message: &'a str,
    /// Field-level details; the key is omitted from the JSON when the slice is empty.
    #[serde(skip_serializing_if = "slice_is_empty")]
    details: &'a [ErrorDetail],
    /// Identifier generated for this response; serialized under the key `traceId`.
    #[serde(rename = "traceId")]
    trace_id: &'a str,
    /// Timestamp string produced by `now_rfc3339`.
    timestamp: String,
}
355
356fn slice_is_empty(details: &&[ErrorDetail]) -> bool {
358 details.is_empty()
359}
360
/// Pre-rendered JSON body used by `into_response` when serializing the regular
/// body fails. It carries no `traceId` or `timestamp` field.
const FALLBACK_ERROR_BODY: &[u8] = br#"{"status":500,"code":"INTERNAL_SERVER_ERROR","title":"Internal Server Error","message":"Internal server error"}"#;
363
364impl IntoResponse for Error {
365 fn into_response(self) -> Response {
366 let status = self.kind.status();
367 let trace_id = generate_trace_id();
368
369 let message: &str = if status.is_server_error() {
373 log_server_error(&self, &trace_id);
374 INTERNAL_ERROR_MESSAGE
375 } else {
376 &self.message
377 };
378
379 let details: &[ErrorDetail] = if status.is_server_error() {
381 &[]
382 } else {
383 &self.details
384 };
385
386 let body = ErrorBody {
387 status: status.as_u16(),
388 code: self.code(),
389 title: status.canonical_reason().unwrap_or("Error"),
390 message,
391 details,
392 trace_id: &trace_id,
393 timestamp: now_rfc3339(),
394 };
395
396 let mut response = match serde_json::to_vec(&body) {
397 Ok(buffer) => with_body(status, APPLICATION_JSON, Bytes::from(buffer)),
398 Err(_) => with_body(
399 status,
400 APPLICATION_JSON,
401 Bytes::from_static(FALLBACK_ERROR_BODY),
402 ),
403 };
404 response.headers_mut().insert(
407 http::header::CACHE_CONTROL,
408 http::HeaderValue::from_static("no-store"),
409 );
410 response
411 }
412}
413
414fn generate_trace_id() -> String {
419 format!("{TRACE_ID_PREFIX}{}", uuid::Uuid::new_v4())
420}
421
422fn now_rfc3339() -> String {
424 use time::format_description::well_known::Rfc3339;
425 time::OffsetDateTime::now_utc()
426 .replace_nanosecond(0)
427 .ok()
428 .and_then(|stamp| stamp.format(&Rfc3339).ok())
429 .unwrap_or_default()
430}
431
432fn classify_issue(message: &str) -> &'static str {
437 let lower = message.to_ascii_lowercase();
438 if lower.contains("email") {
439 "INVALID_FORMAT"
440 } else if lower.contains("length is lower") {
441 "TOO_SHORT"
442 } else if lower.contains("length is greater") {
443 "TOO_LONG"
444 } else if lower.contains("must be greater than") {
445 "TOO_SMALL"
446 } else if lower.contains("must be less than") {
447 "TOO_LARGE"
448 } else if lower.contains("lower than") {
449 "TOO_SMALL"
450 } else if lower.contains("greater than") {
451 "TOO_LARGE"
452 } else {
453 GENERIC_ISSUE
454 }
455}
456
457fn log_server_error(error: &Error, trace_id: &str) {
462 match &error.source {
463 Some(source) => eprintln!(
464 "tork: server error [{trace_id}]: {}: {} (cause: {source})",
465 error.code(),
466 error.message,
467 ),
468 None => eprintln!(
469 "tork: server error [{trace_id}]: {}: {}",
470 error.code(),
471 error.message,
472 ),
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::Response;
    use http_body_util::BodyExt;
    use serde_json::Value;

    /// Collects the response body and parses it as JSON.
    async fn body_json(response: Response) -> Value {
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[test]
    fn status_mapping_matches_kind() {
        assert_eq!(ErrorKind::Forbidden.status(), StatusCode::FORBIDDEN);
        assert_eq!(ErrorKind::NotFound.status(), StatusCode::NOT_FOUND);
        assert_eq!(
            ErrorKind::Internal.status(),
            StatusCode::INTERNAL_SERVER_ERROR
        );
        assert_eq!(
            ErrorKind::PayloadTooLarge.status(),
            StatusCode::PAYLOAD_TOO_LARGE
        );
        assert_eq!(
            ErrorKind::GatewayTimeout.status(),
            StatusCode::GATEWAY_TIMEOUT
        );
    }

    #[tokio::test]
    async fn client_error_uses_problem_format() {
        let response = Error::forbidden("Access denied").into_response();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);

        let body = body_json(response).await;
        assert_eq!(body["status"], 403);
        assert_eq!(body["code"], "FORBIDDEN");
        assert_eq!(body["title"], "Forbidden");
        assert_eq!(body["message"], "Access denied");
        assert!(body.get("details").is_none(), "no details expected: {body}");
        assert!(
            body["traceId"].as_str().unwrap().starts_with("req-"),
            "traceId expected: {body}"
        );
        assert!(
            body["timestamp"].as_str().unwrap().ends_with('Z'),
            "timestamp: {body}"
        );
    }

    #[tokio::test]
    async fn server_error_is_redacted() {
        let response = Error::internal("database password is hunter2").into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let body = body_json(response).await;
        assert_eq!(body["code"], "INTERNAL_SERVER_ERROR");
        assert_eq!(body["message"], INTERNAL_ERROR_MESSAGE);
        assert!(
            !serde_json::to_string(&body).unwrap().contains("hunter2"),
            "internal detail must not leak"
        );
        assert!(body["traceId"].as_str().unwrap().starts_with("req-"));
    }

    // Details attached to a 5xx error must be stripped together with the message.
    #[tokio::test]
    async fn server_error_drops_details() {
        let response = Error::internal("boom")
            .with_details(vec![ErrorDetail::new("password", "LEAKED", "hunter2")])
            .into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let body = body_json(response).await;
        assert!(body.get("details").is_none(), "details must be dropped: {body}");
        assert!(
            !serde_json::to_string(&body).unwrap().contains("hunter2"),
            "detail content must not leak"
        );
    }

    #[tokio::test]
    async fn validation_details_are_serialized() {
        let response = Error::unprocessable(VALIDATION_ERROR_MESSAGE)
            .with_code(VALIDATION_ERROR_CODE)
            .with_details(vec![ErrorDetail::new(
                "price",
                "TOO_SMALL",
                "must be greater than 0",
            )])
            .into_response();
        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);

        let body = body_json(response).await;
        assert_eq!(body["code"], "VALIDATION_ERROR");
        assert_eq!(body["details"][0]["field"], "price");
        assert_eq!(body["details"][0]["issue"], "TOO_SMALL");
        assert_eq!(body["details"][0]["message"], "must be greater than 0");
    }

    /// Cause type used to exercise `with_source` / `take_source`.
    #[derive(Debug, PartialEq)]
    struct SampleCause(&'static str);
    impl std::fmt::Display for SampleCause {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(self.0)
        }
    }
    impl std::error::Error for SampleCause {}

    /// Second cause type, used for type-mismatch scenarios.
    #[derive(Debug)]
    struct OtherCause;
    impl std::fmt::Display for OtherCause {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("other")
        }
    }
    impl std::error::Error for OtherCause {}

    #[test]
    fn with_source_records_the_type() {
        let error = Error::internal("boom").with_source(SampleCause("cause"));
        assert_eq!(error.source_type, Some(TypeId::of::<SampleCause>()));
    }

    #[test]
    fn take_source_round_trips_the_typed_cause() {
        let mut error = Error::internal("boom").with_source(SampleCause("cause"));
        assert_eq!(
            error.take_source::<SampleCause>(),
            Some(SampleCause("cause"))
        );
        // A second take finds nothing and the recorded type is cleared.
        assert_eq!(error.take_source::<SampleCause>(), None);
        assert_eq!(error.source_type, None);
    }

    #[test]
    fn take_source_rejects_a_mismatched_type() {
        let mut error = Error::internal("boom").with_source(SampleCause("cause"));
        assert!(error.take_source::<OtherCause>().is_none());
        assert_eq!(error.source_type, Some(TypeId::of::<SampleCause>()));
        // The failed attempt must not have consumed the real source.
        assert_eq!(
            error.take_source::<SampleCause>(),
            Some(SampleCause("cause"))
        );
    }

    #[test]
    fn take_source_is_none_without_a_source() {
        let mut error = Error::internal("boom");
        assert!(error.take_source::<SampleCause>().is_none());
    }

    #[test]
    fn from_garde_report_classifies_field_errors() {
        use garde::Validate;

        #[derive(garde::Validate)]
        struct Sample {
            #[garde(length(min = 3))]
            name: String,
        }

        let report = Sample {
            name: String::new(),
        }
        .validate()
        .unwrap_err();
        let error = Error::from_garde_report(report);

        assert_eq!(error.code(), "VALIDATION_ERROR");
        assert_eq!(error.details().len(), 1);
        assert_eq!(error.details()[0].field, "name");
        assert_eq!(error.details()[0].issue, "TOO_SHORT");
    }

    #[test]
    fn status_mapping_covers_every_kind() {
        use ErrorKind::*;
        assert_eq!(BadRequest.status(), StatusCode::BAD_REQUEST);
        assert_eq!(Unauthorized.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(Forbidden.status(), StatusCode::FORBIDDEN);
        assert_eq!(NotFound.status(), StatusCode::NOT_FOUND);
        assert_eq!(MethodNotAllowed.status(), StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(Conflict.status(), StatusCode::CONFLICT);
        assert_eq!(Unprocessable.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(PayloadTooLarge.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(TooManyRequests.status(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(Internal.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(ServiceUnavailable.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(GatewayTimeout.status(), StatusCode::GATEWAY_TIMEOUT);
    }

    #[test]
    fn code_mapping_covers_every_kind() {
        use ErrorKind::*;
        assert_eq!(BadRequest.code(), "BAD_REQUEST");
        assert_eq!(Unauthorized.code(), "UNAUTHORIZED");
        assert_eq!(Forbidden.code(), "FORBIDDEN");
        assert_eq!(NotFound.code(), "NOT_FOUND");
        assert_eq!(MethodNotAllowed.code(), "METHOD_NOT_ALLOWED");
        assert_eq!(Conflict.code(), "CONFLICT");
        assert_eq!(Unprocessable.code(), "UNPROCESSABLE_ENTITY");
        assert_eq!(PayloadTooLarge.code(), "PAYLOAD_TOO_LARGE");
        assert_eq!(TooManyRequests.code(), "TOO_MANY_REQUESTS");
        assert_eq!(Internal.code(), "INTERNAL_SERVER_ERROR");
        assert_eq!(ServiceUnavailable.code(), "SERVICE_UNAVAILABLE");
        assert_eq!(GatewayTimeout.code(), "GATEWAY_TIMEOUT");
    }

    #[test]
    fn method_not_allowed_constructor_uses_method_not_allowed_kind() {
        let error = Error::method_not_allowed("GET not allowed");
        assert_eq!(error.kind(), ErrorKind::MethodNotAllowed);
        assert_eq!(error.message(), "GET not allowed");
    }

    #[test]
    fn conflict_constructor_uses_conflict_kind() {
        let error = Error::conflict("duplicate key");
        assert_eq!(error.kind(), ErrorKind::Conflict);
        assert_eq!(error.message(), "duplicate key");
    }

    #[test]
    fn too_many_requests_constructor_uses_too_many_requests_kind() {
        let error = Error::too_many_requests("slow down");
        assert_eq!(error.kind(), ErrorKind::TooManyRequests);
        assert_eq!(error.message(), "slow down");
    }

    #[test]
    fn service_unavailable_constructor_uses_service_unavailable_kind() {
        let error = Error::service_unavailable("maintenance");
        assert_eq!(error.kind(), ErrorKind::ServiceUnavailable);
        assert_eq!(error.message(), "maintenance");
    }

    // Constructors that had no dedicated test.
    #[test]
    fn remaining_constructors_use_their_kind() {
        assert_eq!(Error::bad_request("x").kind(), ErrorKind::BadRequest);
        assert_eq!(Error::unauthorized("x").kind(), ErrorKind::Unauthorized);
        assert_eq!(Error::forbidden("x").kind(), ErrorKind::Forbidden);
        assert_eq!(Error::not_found("x").kind(), ErrorKind::NotFound);
        assert_eq!(Error::unprocessable("x").kind(), ErrorKind::Unprocessable);
        assert_eq!(
            Error::payload_too_large("x").kind(),
            ErrorKind::PayloadTooLarge
        );
        assert_eq!(Error::internal("x").kind(), ErrorKind::Internal);
        assert_eq!(Error::gateway_timeout("x").kind(), ErrorKind::GatewayTimeout);
    }

    #[test]
    fn with_code_overrides_the_default_code() {
        let error = Error::bad_request("nope");
        assert_eq!(error.code(), "BAD_REQUEST");
        assert_eq!(error.static_code(), "BAD_REQUEST");

        let error = error.with_code("CUSTOM");
        assert_eq!(error.code(), "CUSTOM");
        assert_eq!(error.static_code(), "CUSTOM");
    }

    #[test]
    fn is_validation_depends_on_the_effective_code() {
        assert!(Error::unprocessable("x")
            .with_code(VALIDATION_ERROR_CODE)
            .is_validation());
        assert!(!Error::unprocessable("x").is_validation());
    }

    #[test]
    fn display_joins_code_and_message() {
        assert_eq!(Error::not_found("missing").to_string(), "NOT_FOUND: missing");
        assert_eq!(
            Error::bad_request("nope").with_code("CUSTOM").to_string(),
            "CUSTOM: nope"
        );
    }

    #[test]
    fn error_trait_source_returns_attached_source() {
        use std::error::Error as _;
        let error = Error::internal("boom").with_source(SampleCause("inner"));
        let source = error.source().expect("source should be present");
        assert_eq!(source.to_string(), "inner");
    }

    #[test]
    fn error_trait_source_is_none_when_unset() {
        use std::error::Error as _;
        let error = Error::internal("boom");
        assert!(error.source().is_none());
    }

    #[test]
    fn take_source_restores_state_when_downcast_defensively_fails() {
        let mut error = Error::internal("boom");
        // Force the inconsistent state `with_source` can never produce: the
        // recorded type id does not match the boxed value.
        error.source = Some(Box::new(OtherCause));
        error.source_type = Some(TypeId::of::<SampleCause>());

        assert!(error.take_source::<SampleCause>().is_none());
        // Both halves of the state must be put back, not only the type id.
        assert!(error.source.is_some(), "boxed source must be restored");
        assert_eq!(error.source_type, Some(TypeId::of::<SampleCause>()));
    }

    #[test]
    fn sample_cause_display_formats_inner_message() {
        assert_eq!(SampleCause("payload").to_string(), "payload");
    }

    #[test]
    fn other_cause_display_formats_inner_message() {
        assert_eq!(OtherCause.to_string(), "other");
    }

    #[test]
    fn fallback_body_constant_is_valid_json() {
        let parsed: Value = serde_json::from_slice(FALLBACK_ERROR_BODY).unwrap();
        assert_eq!(parsed["status"], 500);
        assert_eq!(parsed["code"], "INTERNAL_SERVER_ERROR");
    }

    #[test]
    fn classify_issue_recognizes_email_format() {
        assert_eq!(classify_issue("email is not valid"), "INVALID_FORMAT");
        assert_eq!(classify_issue("Email is invalid"), "INVALID_FORMAT");
    }

    #[test]
    fn classify_issue_recognizes_too_short() {
        assert_eq!(classify_issue("length is lower than 3"), "TOO_SHORT");
    }

    #[test]
    fn classify_issue_recognizes_too_long() {
        assert_eq!(classify_issue("length is greater than 10"), "TOO_LONG");
    }

    #[test]
    fn classify_issue_recognizes_strict_numeric_bounds() {
        assert_eq!(classify_issue("value must be greater than 0"), "TOO_SMALL");
        assert_eq!(classify_issue("value must be less than 100"), "TOO_LARGE");
    }

    // Bare range phrases, i.e. without the "length is" / "must be" lead-in.
    #[test]
    fn classify_issue_recognizes_range_bounds() {
        assert_eq!(classify_issue("lower than 3"), "TOO_SMALL");
        assert_eq!(classify_issue("greater than 10"), "TOO_LARGE");
    }

    #[test]
    fn classify_issue_falls_back_to_generic() {
        assert_eq!(classify_issue("something unrelated"), "INVALID");
        assert_eq!(classify_issue(""), "INVALID");
    }
}