1use http::StatusCode;
61use serde::Serialize;
62use std::fmt;
63use std::sync::OnceLock;
64use uuid::Uuid;
65
/// Convenience result type for handlers whose error type defaults to [`ApiError`].
///
/// A different error type can still be given explicitly: `Result<T, MyError>`.
pub type Result<T, E = ApiError> = std::result::Result<T, E>;
68
/// Runtime environment the application is running in.
///
/// [`ErrorResponse::from_api_error`] uses it to decide how much error detail reaches
/// clients: in `Production`, the messages of 5xx responses are masked.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Environment {
    /// Error messages are returned to clients verbatim. This is the `Default` variant.
    #[default]
    Development,
    /// 5xx error messages are replaced with a generic message before being sent.
    Production,
}
96
97impl Environment {
98 pub fn from_env() -> Self {
115 match std::env::var("RUSTAPI_ENV")
116 .map(|s| s.to_lowercase())
117 .as_deref()
118 {
119 Ok("production") | Ok("prod") => Environment::Production,
120 _ => Environment::Development,
121 }
122 }
123
124 pub fn is_production(&self) -> bool {
126 matches!(self, Environment::Production)
127 }
128
129 pub fn is_development(&self) -> bool {
131 matches!(self, Environment::Development)
132 }
133}
134
135impl fmt::Display for Environment {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 match self {
138 Environment::Development => write!(f, "development"),
139 Environment::Production => write!(f, "production"),
140 }
141 }
142}
143
144static ENVIRONMENT: OnceLock<Environment> = OnceLock::new();
146
147pub fn get_environment() -> Environment {
152 *ENVIRONMENT.get_or_init(Environment::from_env)
153}
154
155#[cfg(test)]
160#[allow(dead_code)]
161pub fn set_environment_for_test(env: Environment) -> Result<(), Environment> {
162 ENVIRONMENT.set(env)
163}
164
165pub fn generate_error_id() -> String {
180 format!("err_{}", Uuid::new_v4().simple())
181}
182
/// An error produced by request handling, carrying everything needed to build an HTTP error
/// response.
///
/// Turn it into an [`ErrorResponse`] to get the serializable body.
#[derive(Debug, Clone)]
pub struct ApiError {
    /// HTTP status code of the response.
    pub status: StatusCode,
    /// Machine-readable error category, serialized as `error.type` (e.g. `"not_found"`).
    pub error_type: String,
    /// Client-facing message. Masked for 5xx errors in production.
    pub message: String,
    /// Per-field validation failures, if any.
    pub fields: Option<Vec<FieldError>>,
    /// Diagnostic details. Logged for server errors and never part of the serialized
    /// response.
    pub(crate) internal: Option<String>,
}
218
/// A single validation failure for one request field, serialized as-is into responses.
#[derive(Debug, Clone, Serialize)]
pub struct FieldError {
    /// Name of the offending field.
    pub field: String,
    /// Machine-readable failure code (this module's tests use e.g. `"min"`, `"invalid_format"`).
    pub code: String,
    /// Human-readable explanation of the failure.
    pub message: String,
}
229
230impl ApiError {
231 pub fn new(
233 status: StatusCode,
234 error_type: impl Into<String>,
235 message: impl Into<String>,
236 ) -> Self {
237 Self {
238 status,
239 error_type: error_type.into(),
240 message: message.into(),
241 fields: None,
242 internal: None,
243 }
244 }
245
246 pub fn validation(fields: Vec<FieldError>) -> Self {
248 Self {
249 status: StatusCode::UNPROCESSABLE_ENTITY,
250 error_type: "validation_error".to_string(),
251 message: "Request validation failed".to_string(),
252 fields: Some(fields),
253 internal: None,
254 }
255 }
256
257 pub fn bad_request(message: impl Into<String>) -> Self {
259 Self::new(StatusCode::BAD_REQUEST, "bad_request", message)
260 }
261
262 pub fn unauthorized(message: impl Into<String>) -> Self {
264 Self::new(StatusCode::UNAUTHORIZED, "unauthorized", message)
265 }
266
267 pub fn forbidden(message: impl Into<String>) -> Self {
269 Self::new(StatusCode::FORBIDDEN, "forbidden", message)
270 }
271
272 pub fn not_found(message: impl Into<String>) -> Self {
274 Self::new(StatusCode::NOT_FOUND, "not_found", message)
275 }
276
277 pub fn conflict(message: impl Into<String>) -> Self {
279 Self::new(StatusCode::CONFLICT, "conflict", message)
280 }
281
282 pub fn internal(message: impl Into<String>) -> Self {
284 Self::new(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", message)
285 }
286
287 pub fn with_internal(mut self, details: impl Into<String>) -> Self {
289 self.internal = Some(details.into());
290 self
291 }
292}
293
294impl fmt::Display for ApiError {
295 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296 write!(f, "{}: {}", self.error_type, self.message)
297 }
298}
299
300impl std::error::Error for ApiError {}
301
/// JSON envelope sent to clients for an error.
#[derive(Serialize)]
pub struct ErrorResponse {
    /// The client-visible part of the error.
    pub error: ErrorBody,
    /// Unique `err_…` id that `from_api_error` also writes into the log entry, so a client
    /// report can be matched to server logs.
    pub error_id: String,
    /// Request id, if one is attached; omitted from the JSON when `None`.
    /// NOTE(review): `from_api_error` always leaves this `None`; presumably set by callers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
}
311
/// The client-visible description of an error inside an [`ErrorResponse`].
#[derive(Serialize)]
pub struct ErrorBody {
    /// Machine-readable error category; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub error_type: String,
    /// Client-facing message (possibly masked in production).
    pub message: String,
    /// Per-field validation failures; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fields: Option<Vec<FieldError>>,
}
320
321impl ErrorResponse {
322 pub fn from_api_error(err: ApiError, env: Environment) -> Self {
332 let error_id = generate_error_id();
333
334 if err.status.is_server_error() {
336 crate::trace_error!(
337 error_id = %error_id,
338 error_type = %err.error_type,
339 message = %err.message,
340 status = %err.status.as_u16(),
341 internal = ?err.internal,
342 environment = %env,
343 "Server error occurred"
344 );
345 } else if err.status.is_client_error() {
346 crate::trace_warn!(
347 error_id = %error_id,
348 error_type = %err.error_type,
349 message = %err.message,
350 status = %err.status.as_u16(),
351 environment = %env,
352 "Client error occurred"
353 );
354 } else {
355 crate::trace_info!(
356 error_id = %error_id,
357 error_type = %err.error_type,
358 message = %err.message,
359 status = %err.status.as_u16(),
360 environment = %env,
361 "Error response generated"
362 );
363 }
364
365 let (message, fields) = if env.is_production() && err.status.is_server_error() {
367 let masked_message = "An internal error occurred".to_string();
370 let fields = if err.error_type == "validation_error" {
372 err.fields
373 } else {
374 None
375 };
376 (masked_message, fields)
377 } else {
378 (err.message, err.fields)
380 };
381
382 Self {
383 error: ErrorBody {
384 error_type: err.error_type,
385 message,
386 fields,
387 },
388 error_id,
389 request_id: None,
390 }
391 }
392}
393
394impl From<ApiError> for ErrorResponse {
395 fn from(err: ApiError) -> Self {
396 let env = get_environment();
398 Self::from_api_error(err, env)
399 }
400}
401
402impl From<serde_json::Error> for ApiError {
404 fn from(err: serde_json::Error) -> Self {
405 ApiError::bad_request(format!("Invalid JSON: {}", err))
406 }
407}
408
409impl From<crate::json::JsonError> for ApiError {
410 fn from(err: crate::json::JsonError) -> Self {
411 ApiError::bad_request(format!("Invalid JSON: {}", err))
412 }
413}
414
415impl From<std::io::Error> for ApiError {
416 fn from(err: std::io::Error) -> Self {
417 ApiError::internal("I/O error").with_internal(err.to_string())
418 }
419}
420
421impl From<hyper::Error> for ApiError {
422 fn from(err: hyper::Error) -> Self {
423 ApiError::internal("HTTP error").with_internal(err.to_string())
424 }
425}
426
427impl From<rustapi_validate::ValidationError> for ApiError {
428 fn from(err: rustapi_validate::ValidationError) -> Self {
429 let fields = err
430 .fields
431 .into_iter()
432 .map(|f| FieldError {
433 field: f.field,
434 code: f.code,
435 message: f.message,
436 })
437 .collect();
438
439 ApiError::validation(fields)
440 }
441}
442
443impl ApiError {
444 pub fn from_validation_error(err: rustapi_validate::ValidationError) -> Self {
446 err.into()
447 }
448
449 pub fn service_unavailable(message: impl Into<String>) -> Self {
451 Self::new(
452 StatusCode::SERVICE_UNAVAILABLE,
453 "service_unavailable",
454 message,
455 )
456 }
457}
458
#[cfg(feature = "sqlx")]
impl From<sqlx::Error> for ApiError {
    /// Maps a [`sqlx::Error`] onto an [`ApiError`] with an appropriate HTTP status.
    ///
    /// - Pool timeout, pool closed, I/O and TLS failures become `503 service_unavailable`.
    /// - `RowNotFound` becomes `404 not_found`.
    /// - Database constraint violations recognised by their driver-reported code become
    ///   `409 conflict` (unique / primary key) or `400 bad_request` (foreign key, check).
    /// - Everything else becomes `500 internal_error`.
    ///
    /// Except for `RowNotFound`, the original error text is attached via
    /// [`ApiError::with_internal`]. Internal details are never serialized into the client
    /// response.
    fn from(err: sqlx::Error) -> Self {
        match &err {
            sqlx::Error::PoolTimedOut => {
                ApiError::service_unavailable("Database connection pool exhausted")
                    .with_internal(err.to_string())
            }

            sqlx::Error::PoolClosed => {
                ApiError::service_unavailable("Database connection pool is closed")
                    .with_internal(err.to_string())
            }

            sqlx::Error::RowNotFound => ApiError::not_found("Resource not found"),

            sqlx::Error::Database(db_err) => {
                // Constraint violations are recognised by the code the driver reports:
                // - PostgreSQL SQLSTATE: 23505 unique, 23503 foreign key, 23514 check.
                // - SQLite extended result codes: 2067 CONSTRAINT_UNIQUE,
                //   1555 CONSTRAINT_PRIMARYKEY, 787 CONSTRAINT_FOREIGNKEY, 275 CONSTRAINT_CHECK.
                // - 1062 / 1452 are MySQL error numbers (duplicate entry / missing parent row).
                //   NOTE(review): sqlx's MySQL driver returns the SQLSTATE (e.g. "23000") from
                //   `code()`, not the error number, so these may never match. Confirm against
                //   the sqlx version in use; sqlx >= 0.7 offers `DatabaseError::kind()`.
                let classified = db_err.code().and_then(|code| match &*code {
                    "23505" | "1062" | "2067" | "1555" => {
                        Some(ApiError::conflict("Resource already exists"))
                    }
                    "23503" | "1452" | "787" => {
                        Some(ApiError::bad_request("Referenced resource does not exist"))
                    }
                    "23514" | "275" => Some(ApiError::bad_request("Data validation failed")),
                    _ => None,
                });

                classified
                    .unwrap_or_else(|| ApiError::internal("Database error"))
                    .with_internal(db_err.to_string())
            }

            sqlx::Error::Io(_) => ApiError::service_unavailable("Database connection error")
                .with_internal(err.to_string()),

            sqlx::Error::Tls(_) => {
                ApiError::service_unavailable("Database TLS error").with_internal(err.to_string())
            }

            sqlx::Error::Protocol(_) => {
                ApiError::internal("Database protocol error").with_internal(err.to_string())
            }

            sqlx::Error::TypeNotFound { .. } => {
                ApiError::internal("Database type error").with_internal(err.to_string())
            }

            sqlx::Error::ColumnNotFound(_) => {
                ApiError::internal("Database column not found").with_internal(err.to_string())
            }

            sqlx::Error::ColumnIndexOutOfBounds { .. } => {
                ApiError::internal("Database column index error").with_internal(err.to_string())
            }

            sqlx::Error::ColumnDecode { .. } => {
                ApiError::internal("Database decode error").with_internal(err.to_string())
            }

            sqlx::Error::Configuration(_) => {
                ApiError::internal("Database configuration error").with_internal(err.to_string())
            }

            sqlx::Error::Migrate(_) => {
                ApiError::internal("Database migration error").with_internal(err.to_string())
            }

            // `sqlx::Error` is non-exhaustive; anything unrecognised is a plain 500.
            _ => ApiError::internal("Database error").with_internal(err.to_string()),
        }
    }
}
555
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::collections::HashSet;

    proptest! {
        // Property: a batch of generated error ids contains no duplicates, and every id is
        // `err_` followed by 32 hex digits.
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_error_id_uniqueness(
            num_errors in 10usize..200,
        ) {
            let error_ids: Vec<String> = (0..num_errors)
                .map(|_| generate_error_id())
                .collect();

            // Collapsing into a set removes duplicates; equal sizes mean all ids are unique.
            let unique_ids: HashSet<&String> = error_ids.iter().collect();

            prop_assert_eq!(
                unique_ids.len(),
                error_ids.len(),
                "Generated {} error IDs but only {} were unique",
                error_ids.len(),
                unique_ids.len()
            );

            for id in &error_ids {
                prop_assert!(
                    id.starts_with("err_"),
                    "Error ID '{}' does not start with 'err_'",
                    id
                );

                // Everything after the 4-byte `err_` prefix is the simple-form UUID.
                let uuid_part = &id[4..];
                prop_assert_eq!(
                    uuid_part.len(),
                    32,
                    "UUID part '{}' should be 32 characters, got {}",
                    uuid_part,
                    uuid_part.len()
                );

                prop_assert!(
                    uuid_part.chars().all(|c| c.is_ascii_hexdigit()),
                    "UUID part '{}' contains non-hex characters",
                    uuid_part
                );
            }
        }
    }

    proptest! {
        // Property: responses built through `From<ApiError>` always carry a well-formed id.
        // This path uses the process-wide `get_environment()`, so it holds in either mode.
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_error_response_contains_error_id(
            error_type in "[a-z_]{1,20}",
            message in "[a-zA-Z0-9 ]{1,100}",
        ) {
            let api_error = ApiError::new(StatusCode::INTERNAL_SERVER_ERROR, error_type, message);
            let error_response = ErrorResponse::from(api_error);

            prop_assert!(
                error_response.error_id.starts_with("err_"),
                "Error ID '{}' does not start with 'err_'",
                error_response.error_id
            );

            let uuid_part = &error_response.error_id[4..];
            prop_assert_eq!(uuid_part.len(), 32);
            prop_assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit()));
        }
    }

    #[test]
    fn test_error_id_format() {
        let error_id = generate_error_id();

        assert!(error_id.starts_with("err_"));

        // 4-byte prefix + 32 hex digits.
        assert_eq!(error_id.len(), 36);

        let uuid_part = &error_id[4..];
        assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_error_response_includes_error_id() {
        let api_error = ApiError::bad_request("test error");
        let error_response = ErrorResponse::from(api_error);

        assert!(error_response.error_id.starts_with("err_"));
        // 4-byte prefix + 32 hex digits.
        assert_eq!(error_response.error_id.len(), 36);
    }

    #[test]
    fn test_error_id_in_json_serialization() {
        let api_error = ApiError::internal("test error");
        let error_response = ErrorResponse::from(api_error);

        let json = serde_json::to_string(&error_response).unwrap();

        // `error_id` is a top-level key of the serialized envelope.
        assert!(json.contains("\"error_id\":"));
        assert!(json.contains("err_"));
    }

    #[test]
    fn test_multiple_error_ids_are_unique() {
        let ids: Vec<String> = (0..1000).map(|_| generate_error_id()).collect();
        let unique: HashSet<_> = ids.iter().collect();

        assert_eq!(ids.len(), unique.len(), "All error IDs should be unique");
    }

    proptest! {
        // Property: in production, 5xx responses carry only the generic message. Neither the
        // original message nor the internal details appear in the serialized JSON.
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_production_error_masking(
            sensitive_message in "[a-zA-Z0-9_]{10,200}",
            internal_details in "[a-zA-Z0-9_]{10,200}",
            status_code in prop::sample::select(vec![500u16, 501, 502, 503, 504, 505]),
        ) {
            let api_error = ApiError::new(
                StatusCode::from_u16(status_code).unwrap(),
                "internal_error",
                sensitive_message.clone()
            ).with_internal(internal_details.clone());

            let error_response = ErrorResponse::from_api_error(api_error, Environment::Production);

            prop_assert_eq!(
                &error_response.error.message,
                "An internal error occurred",
                "Production 5xx error should have masked message, got: {}",
                &error_response.error.message
            );

            // NOTE(review): the strategies above already guarantee length >= 10, so these
            // length guards are always true.
            if sensitive_message.len() >= 10 {
                prop_assert!(
                    !error_response.error.message.contains(&sensitive_message),
                    "Production error response should not contain original message"
                );
            }

            let json = serde_json::to_string(&error_response).unwrap();
            if internal_details.len() >= 10 {
                prop_assert!(
                    !json.contains(&internal_details),
                    "Production error response should not contain internal details"
                );
            }

            prop_assert!(
                error_response.error_id.starts_with("err_"),
                "Error ID should be present in production error response"
            );
        }
    }

    proptest! {
        // Property: in development, message and error type pass through unchanged for both
        // client (4xx) and server (5xx) errors.
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_development_error_details(
            error_message in "[a-zA-Z0-9 ]{1,100}",
            error_type in "[a-z_]{1,20}",
            status_code in prop::sample::select(vec![400u16, 401, 403, 404, 500, 502, 503]),
        ) {
            let api_error = ApiError::new(
                StatusCode::from_u16(status_code).unwrap(),
                error_type.clone(),
                error_message.clone()
            );

            let error_response = ErrorResponse::from_api_error(api_error, Environment::Development);

            prop_assert_eq!(
                error_response.error.message,
                error_message,
                "Development error should preserve original message"
            );

            prop_assert_eq!(
                error_response.error.error_type,
                error_type,
                "Development error should preserve error type"
            );

            prop_assert!(
                error_response.error_id.starts_with("err_"),
                "Error ID should be present in development error response"
            );
        }
    }

    proptest! {
        // Property: validation field errors survive into the response and its JSON in both
        // environments.
        // NOTE(review): `ApiError::validation` uses 422, so the production masking branch
        // (5xx only) is never exercised here; production passes via the 4xx path.
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_validation_error_field_details(
            field_name in "[a-z_]{1,20}",
            field_code in "[a-z_]{1,15}",
            field_message in "[a-zA-Z0-9 ]{1,50}",
            is_production in proptest::bool::ANY,
        ) {
            let env = if is_production {
                Environment::Production
            } else {
                Environment::Development
            };

            let field_error = FieldError {
                field: field_name.clone(),
                code: field_code.clone(),
                message: field_message.clone(),
            };
            let api_error = ApiError::validation(vec![field_error]);

            let error_response = ErrorResponse::from_api_error(api_error, env);

            prop_assert!(
                error_response.error.fields.is_some(),
                "Validation error should always include fields in {} mode",
                env
            );

            let fields = error_response.error.fields.as_ref().unwrap();
            prop_assert_eq!(
                fields.len(),
                1,
                "Should have exactly one field error"
            );

            let field = &fields[0];

            prop_assert_eq!(
                &field.field,
                &field_name,
                "Field name should be preserved in {} mode",
                env
            );

            prop_assert_eq!(
                &field.code,
                &field_code,
                "Field code should be preserved in {} mode",
                env
            );

            prop_assert_eq!(
                &field.message,
                &field_message,
                "Field message should be preserved in {} mode",
                env
            );

            // The same values must also reach the serialized JSON.
            let json = serde_json::to_string(&error_response).unwrap();
            prop_assert!(
                json.contains(&field_name),
                "JSON should contain field name in {} mode",
                env
            );
            prop_assert!(
                json.contains(&field_code),
                "JSON should contain field code in {} mode",
                env
            );
            prop_assert!(
                json.contains(&field_message),
                "JSON should contain field message in {} mode",
                env
            );
        }
    }

    // NOTE(review): the two `test_environment_from_env_*` tests re-implement the parsing
    // `match` inline instead of calling `Environment::from_env`. They would keep passing even
    // if `from_env` changed or broke. Consider exercising the real parser.
    #[test]
    fn test_environment_from_env_production() {
        assert!(matches!(
            match "production".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Production
        ));

        assert!(matches!(
            match "prod".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Production
        ));

        // Upper-case inputs are lowered before matching.
        assert!(matches!(
            match "PRODUCTION".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Production
        ));

        assert!(matches!(
            match "PROD".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Production
        ));
    }

    #[test]
    fn test_environment_from_env_development() {
        // Any value other than production/prod falls back to Development.
        assert!(matches!(
            match "development".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Development
        ));

        assert!(matches!(
            match "dev".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Development
        ));

        assert!(matches!(
            match "test".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Development
        ));

        assert!(matches!(
            match "anything_else".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Development
        ));
    }

    #[test]
    fn test_environment_default_is_development() {
        assert_eq!(Environment::default(), Environment::Development);
    }

    #[test]
    fn test_environment_display() {
        assert_eq!(format!("{}", Environment::Development), "development");
        assert_eq!(format!("{}", Environment::Production), "production");
    }

    #[test]
    fn test_environment_is_methods() {
        assert!(Environment::Production.is_production());
        assert!(!Environment::Production.is_development());
        assert!(Environment::Development.is_development());
        assert!(!Environment::Development.is_production());
    }

    #[test]
    fn test_production_masks_5xx_errors() {
        let error =
            ApiError::internal("Sensitive database connection string: postgres://user:pass@host");
        let response = ErrorResponse::from_api_error(error, Environment::Production);

        assert_eq!(response.error.message, "An internal error occurred");
        assert!(!response.error.message.contains("postgres"));
    }

    #[test]
    fn test_production_shows_4xx_errors() {
        let error = ApiError::bad_request("Invalid email format");
        let response = ErrorResponse::from_api_error(error, Environment::Production);

        // Client errors are not masked, even in production.
        assert_eq!(response.error.message, "Invalid email format");
    }

    #[test]
    fn test_development_shows_all_errors() {
        let error = ApiError::internal("Detailed error: connection refused to 192.168.1.1:5432");
        let response = ErrorResponse::from_api_error(error, Environment::Development);

        assert_eq!(
            response.error.message,
            "Detailed error: connection refused to 192.168.1.1:5432"
        );
    }

    #[test]
    fn test_validation_errors_always_show_fields() {
        let fields = vec![
            FieldError {
                field: "email".to_string(),
                code: "invalid_format".to_string(),
                message: "Invalid email format".to_string(),
            },
            FieldError {
                field: "age".to_string(),
                code: "min".to_string(),
                message: "Must be at least 18".to_string(),
            },
        ];

        let error = ApiError::validation(fields.clone());

        // Production: field errors are kept, in their original order.
        let prod_response = ErrorResponse::from_api_error(error.clone(), Environment::Production);
        assert!(prod_response.error.fields.is_some());
        let prod_fields = prod_response.error.fields.unwrap();
        assert_eq!(prod_fields.len(), 2);
        assert_eq!(prod_fields[0].field, "email");
        assert_eq!(prod_fields[1].field, "age");

        // Development: field errors are kept as well.
        let dev_response = ErrorResponse::from_api_error(error, Environment::Development);
        assert!(dev_response.error.fields.is_some());
        let dev_fields = dev_response.error.fields.unwrap();
        assert_eq!(dev_fields.len(), 2);
    }
}