1use http::StatusCode;
61use serde::Serialize;
62use std::fmt;
63use std::sync::OnceLock;
64use uuid::Uuid;
65
66pub type Result<T, E = ApiError> = std::result::Result<T, E>;
68
/// Deployment environment the process runs in.
///
/// [`Environment::Development`] is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Environment {
    /// Development mode (the default).
    #[default]
    Development,
    /// Production mode.
    Production,
}

impl Environment {
    /// Detects the environment from the `RUSTAPI_ENV` variable.
    ///
    /// `production` and `prod` select [`Environment::Production`]; the
    /// comparison ignores ASCII case and surrounding whitespace. Any other
    /// value, an unset variable, or a non-Unicode value yields
    /// [`Environment::Development`].
    pub fn from_env() -> Self {
        std::env::var("RUSTAPI_ENV")
            .map(|raw| Self::from_name(&raw))
            .unwrap_or_default()
    }

    /// Maps an environment name to a variant.
    ///
    /// The name is trimmed first: values such as `"production "` (easy to
    /// produce from shells and env files) would otherwise silently fall back
    /// to development mode.
    fn from_name(name: &str) -> Self {
        let name = name.trim();
        if name.eq_ignore_ascii_case("production") || name.eq_ignore_ascii_case("prod") {
            Environment::Production
        } else {
            Environment::Development
        }
    }

    /// Returns `true` for [`Environment::Production`].
    pub fn is_production(&self) -> bool {
        matches!(self, Environment::Production)
    }

    /// Returns `true` for [`Environment::Development`].
    pub fn is_development(&self) -> bool {
        matches!(self, Environment::Development)
    }
}

impl fmt::Display for Environment {
    /// Writes the lowercase environment name (`development` / `production`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Environment::Development => "development",
            Environment::Production => "production",
        })
    }
}
143
144static ENVIRONMENT: OnceLock<Environment> = OnceLock::new();
146
147pub fn get_environment() -> Environment {
152 *ENVIRONMENT.get_or_init(Environment::from_env)
153}
154
/// Test-only hook that seeds the cached environment.
///
/// Returns `Err(env)`, handing the value back, when the cache has already been
/// initialized.
#[cfg(test)]
#[allow(dead_code)] // no caller is visible in this file
pub fn set_environment_for_test(env: Environment) -> std::result::Result<(), Environment> {
    ENVIRONMENT.set(env)
}
164
165pub fn generate_error_id() -> String {
180 format!("err_{}", Uuid::new_v4().simple())
181}
182
183#[derive(Debug, Clone)]
206pub struct ApiError {
207 pub status: StatusCode,
209 pub error_type: String,
211 pub message: String,
213 pub fields: Option<Vec<FieldError>>,
215 pub(crate) internal: Option<String>,
217}
218
219#[derive(Debug, Clone, Serialize)]
221pub struct FieldError {
222 pub field: String,
224 pub code: String,
226 pub message: String,
228}
229
230impl ApiError {
231 pub fn new(
233 status: StatusCode,
234 error_type: impl Into<String>,
235 message: impl Into<String>,
236 ) -> Self {
237 Self {
238 status,
239 error_type: error_type.into(),
240 message: message.into(),
241 fields: None,
242 internal: None,
243 }
244 }
245
246 pub fn validation(fields: Vec<FieldError>) -> Self {
248 Self {
249 status: StatusCode::UNPROCESSABLE_ENTITY,
250 error_type: "validation_error".to_string(),
251 message: "Request validation failed".to_string(),
252 fields: Some(fields),
253 internal: None,
254 }
255 }
256
257 pub fn bad_request(message: impl Into<String>) -> Self {
259 Self::new(StatusCode::BAD_REQUEST, "bad_request", message)
260 }
261
262 pub fn unauthorized(message: impl Into<String>) -> Self {
264 Self::new(StatusCode::UNAUTHORIZED, "unauthorized", message)
265 }
266
267 pub fn forbidden(message: impl Into<String>) -> Self {
269 Self::new(StatusCode::FORBIDDEN, "forbidden", message)
270 }
271
272 pub fn not_found(message: impl Into<String>) -> Self {
274 Self::new(StatusCode::NOT_FOUND, "not_found", message)
275 }
276
277 pub fn conflict(message: impl Into<String>) -> Self {
279 Self::new(StatusCode::CONFLICT, "conflict", message)
280 }
281
282 pub fn internal(message: impl Into<String>) -> Self {
284 Self::new(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", message)
285 }
286
287 pub fn with_internal(mut self, details: impl Into<String>) -> Self {
289 self.internal = Some(details.into());
290 self
291 }
292}
293
294impl fmt::Display for ApiError {
295 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296 write!(f, "{}: {}", self.error_type, self.message)
297 }
298}
299
300impl std::error::Error for ApiError {}
301
302#[derive(Serialize)]
304pub struct ErrorResponse {
305 pub error: ErrorBody,
306 pub error_id: String,
308 #[serde(skip_serializing_if = "Option::is_none")]
309 pub request_id: Option<String>,
310}
311
312#[derive(Serialize)]
313pub struct ErrorBody {
314 #[serde(rename = "type")]
315 pub error_type: String,
316 pub message: String,
317 #[serde(skip_serializing_if = "Option::is_none")]
318 pub fields: Option<Vec<FieldError>>,
319}
320
321impl ErrorResponse {
322 pub fn from_api_error(err: ApiError, env: Environment) -> Self {
332 let error_id = generate_error_id();
333
334 if err.status.is_server_error() {
336 crate::trace_error!(
337 error_id = %error_id,
338 error_type = %err.error_type,
339 message = %err.message,
340 status = %err.status.as_u16(),
341 internal = ?err.internal,
342 environment = %env,
343 "Server error occurred"
344 );
345 } else if err.status.is_client_error() {
346 crate::trace_warn!(
347 error_id = %error_id,
348 error_type = %err.error_type,
349 message = %err.message,
350 status = %err.status.as_u16(),
351 environment = %env,
352 "Client error occurred"
353 );
354 } else {
355 crate::trace_info!(
356 error_id = %error_id,
357 error_type = %err.error_type,
358 message = %err.message,
359 status = %err.status.as_u16(),
360 environment = %env,
361 "Error response generated"
362 );
363 }
364
365 let (message, fields) = if env.is_production() && err.status.is_server_error() {
367 let masked_message = "An internal error occurred".to_string();
370 let fields = if err.error_type == "validation_error" {
372 err.fields
373 } else {
374 None
375 };
376 (masked_message, fields)
377 } else {
378 (err.message, err.fields)
380 };
381
382 Self {
383 error: ErrorBody {
384 error_type: err.error_type,
385 message,
386 fields,
387 },
388 error_id,
389 request_id: None,
390 }
391 }
392}
393
394impl From<ApiError> for ErrorResponse {
395 fn from(err: ApiError) -> Self {
396 let env = get_environment();
398 Self::from_api_error(err, env)
399 }
400}
401
402impl From<serde_json::Error> for ApiError {
404 fn from(err: serde_json::Error) -> Self {
405 ApiError::bad_request(format!("Invalid JSON: {}", err))
406 }
407}
408
409impl From<crate::json::JsonError> for ApiError {
410 fn from(err: crate::json::JsonError) -> Self {
411 ApiError::bad_request(format!("Invalid JSON: {}", err))
412 }
413}
414
415impl From<std::io::Error> for ApiError {
416 fn from(err: std::io::Error) -> Self {
417 ApiError::internal("I/O error").with_internal(err.to_string())
418 }
419}
420
421impl From<hyper::Error> for ApiError {
422 fn from(err: hyper::Error) -> Self {
423 ApiError::internal("HTTP error").with_internal(err.to_string())
424 }
425}
426
427impl From<rustapi_validate::ValidationError> for ApiError {
428 fn from(err: rustapi_validate::ValidationError) -> Self {
429 let fields = err
430 .fields
431 .into_iter()
432 .map(|f| FieldError {
433 field: f.field,
434 code: f.code,
435 message: f.message,
436 })
437 .collect();
438
439 ApiError::validation(fields)
440 }
441}
442
443impl From<rustapi_validate::v2::ValidationErrors> for ApiError {
444 fn from(err: rustapi_validate::v2::ValidationErrors) -> Self {
445 let fields = err
446 .fields
447 .into_iter()
448 .flat_map(|(field, errors)| {
449 errors.into_iter().map(move |e| {
450 let message = e.interpolate_message();
451 FieldError {
452 field: field.clone(),
453 code: e.code,
454 message,
455 }
456 })
457 })
458 .collect();
459
460 ApiError::validation(fields)
461 }
462}
463
464impl ApiError {
465 pub fn from_validation_error(err: rustapi_validate::ValidationError) -> Self {
467 err.into()
468 }
469
470 pub fn service_unavailable(message: impl Into<String>) -> Self {
472 Self::new(
473 StatusCode::SERVICE_UNAVAILABLE,
474 "service_unavailable",
475 message,
476 )
477 }
478}
479
#[cfg(feature = "sqlx")]
impl From<sqlx::Error> for ApiError {
    /// Maps an `sqlx` error to an [`ApiError`].
    ///
    /// * `RowNotFound` becomes `404` with no internal detail.
    /// * Pool, I/O and TLS failures become `503`.
    /// * Database errors are classified by their reported code: uniqueness
    ///   violations become `409`, foreign-key and check violations become
    ///   `400`, and everything else becomes `500`.
    /// * All remaining variants become `500`.
    ///
    /// Except for `RowNotFound`, the error's display text is kept as internal
    /// detail. For database errors that is the driver error's own text.
    fn from(err: sqlx::Error) -> Self {
        use sqlx::Error as DbError;

        let base = match &err {
            DbError::RowNotFound => return ApiError::not_found("Resource not found"),

            DbError::Database(db_err) => {
                let classified = match db_err.code().as_deref() {
                    // PostgreSQL unique_violation, MySQL duplicate entry,
                    // SQLite UNIQUE (2067) and PRIMARY KEY (1555) constraints.
                    // NOTE(review): confirm the MySQL driver reports numeric
                    // error numbers (not SQLSTATE) through `code()`.
                    Some("23505" | "1062" | "2067" | "1555") => {
                        ApiError::conflict("Resource already exists")
                    }
                    // PostgreSQL foreign_key_violation, MySQL missing
                    // referenced row, SQLite FOREIGN KEY constraint (787).
                    Some("23503" | "1452" | "787") => {
                        ApiError::bad_request("Referenced resource does not exist")
                    }
                    // PostgreSQL check_violation, SQLite CHECK constraint (275).
                    Some("23514" | "275") => ApiError::bad_request("Data validation failed"),
                    _ => ApiError::internal("Database error"),
                };
                return classified.with_internal(db_err.to_string());
            }

            DbError::PoolTimedOut => {
                ApiError::service_unavailable("Database connection pool exhausted")
            }
            DbError::PoolClosed => {
                ApiError::service_unavailable("Database connection pool is closed")
            }
            DbError::Io(_) => ApiError::service_unavailable("Database connection error"),
            DbError::Tls(_) => ApiError::service_unavailable("Database TLS error"),

            DbError::Protocol(_) => ApiError::internal("Database protocol error"),
            DbError::TypeNotFound { .. } => ApiError::internal("Database type error"),
            DbError::ColumnNotFound(_) => ApiError::internal("Database column not found"),
            DbError::ColumnIndexOutOfBounds { .. } => {
                ApiError::internal("Database column index error")
            }
            DbError::ColumnDecode { .. } => ApiError::internal("Database decode error"),
            DbError::Configuration(_) => ApiError::internal("Database configuration error"),
            DbError::Migrate(_) => ApiError::internal("Database migration error"),

            _ => ApiError::internal("Database error"),
        };

        base.with_internal(err.to_string())
    }
}
576
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::collections::HashSet;

    /// Local copy of the lowercase name matching, so the environment tests
    /// never have to mutate the process environment.
    fn classify(name: &str) -> Environment {
        match name.to_lowercase().as_str() {
            "production" | "prod" => Environment::Production,
            _ => Environment::Development,
        }
    }

    // Generated ids are pairwise distinct and shaped `err_` + 32 hex digits.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_error_id_uniqueness(num_errors in 10usize..200) {
            let mut error_ids: Vec<String> = Vec::with_capacity(num_errors);
            for _ in 0..num_errors {
                error_ids.push(generate_error_id());
            }

            let unique_ids: HashSet<&str> = error_ids.iter().map(String::as_str).collect();

            prop_assert_eq!(
                unique_ids.len(),
                error_ids.len(),
                "Generated {} error IDs but only {} were unique",
                error_ids.len(),
                unique_ids.len()
            );

            for id in &error_ids {
                prop_assert!(
                    id.starts_with("err_"),
                    "Error ID '{}' does not start with 'err_'",
                    id
                );

                let (_, uuid_part) = id.split_at(4);
                prop_assert_eq!(
                    uuid_part.len(),
                    32,
                    "UUID part '{}' should be 32 characters, got {}",
                    uuid_part,
                    uuid_part.len()
                );

                prop_assert!(
                    uuid_part.chars().all(|c| c.is_ascii_hexdigit()),
                    "UUID part '{}' contains non-hex characters",
                    uuid_part
                );
            }
        }
    }

    // Every response built from an error carries a well-formed error id.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_error_response_contains_error_id(
            error_type in "[a-z_]{1,20}",
            message in "[a-zA-Z0-9 ]{1,100}",
        ) {
            let api_error = ApiError::new(StatusCode::INTERNAL_SERVER_ERROR, error_type, message);
            let error_response: ErrorResponse = api_error.into();
            let error_id = error_response.error_id.as_str();

            prop_assert!(
                error_id.starts_with("err_"),
                "Error ID '{}' does not start with 'err_'",
                error_id
            );

            let uuid_part = &error_id[4..];
            prop_assert_eq!(uuid_part.len(), 32);
            prop_assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit()));
        }
    }

    #[test]
    fn test_error_id_format() {
        let error_id = generate_error_id();

        assert!(error_id.starts_with("err_"));

        // 4 prefix characters plus 32 hex digits.
        assert_eq!(error_id.len(), 36);

        let uuid_part = &error_id[4..];
        assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_error_response_includes_error_id() {
        let error_response: ErrorResponse = ApiError::bad_request("test error").into();
        let error_id = &error_response.error_id;

        assert!(error_id.starts_with("err_"));
        assert_eq!(error_id.len(), 36);
    }

    #[test]
    fn test_error_id_in_json_serialization() {
        let error_response = ErrorResponse::from(ApiError::internal("test error"));

        let json = serde_json::to_string(&error_response).unwrap();

        assert!(json.contains("\"error_id\":"));
        assert!(json.contains("err_"));
    }

    #[test]
    fn test_multiple_error_ids_are_unique() {
        let ids: Vec<String> = std::iter::repeat_with(generate_error_id)
            .take(1000)
            .collect();
        let unique: HashSet<&String> = ids.iter().collect();

        assert_eq!(ids.len(), unique.len(), "All error IDs should be unique");
    }

    // Production responses for 5xx errors expose neither the original message
    // nor the internal detail.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_production_error_masking(
            sensitive_message in "[a-zA-Z0-9_]{10,200}",
            internal_details in "[a-zA-Z0-9_]{10,200}",
            status_code in prop::sample::select(vec![500u16, 501, 502, 503, 504, 505]),
        ) {
            let status = StatusCode::from_u16(status_code).unwrap();
            let api_error = ApiError::new(status, "internal_error", sensitive_message.clone())
                .with_internal(internal_details.clone());

            let error_response = ErrorResponse::from_api_error(api_error, Environment::Production);
            let shown_message = error_response.error.message.as_str();

            prop_assert_eq!(
                shown_message,
                "An internal error occurred",
                "Production 5xx error should have masked message, got: {}",
                shown_message
            );

            if sensitive_message.len() >= 10 {
                prop_assert!(
                    !shown_message.contains(sensitive_message.as_str()),
                    "Production error response should not contain original message"
                );
            }

            let json = serde_json::to_string(&error_response).unwrap();
            if internal_details.len() >= 10 {
                prop_assert!(
                    !json.contains(internal_details.as_str()),
                    "Production error response should not contain internal details"
                );
            }

            prop_assert!(
                error_response.error_id.starts_with("err_"),
                "Error ID should be present in production error response"
            );
        }
    }

    // Development responses keep the original message and type tag.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_development_error_details(
            error_message in "[a-zA-Z0-9 ]{1,100}",
            error_type in "[a-z_]{1,20}",
            status_code in prop::sample::select(vec![400u16, 401, 403, 404, 500, 502, 503]),
        ) {
            let status = StatusCode::from_u16(status_code).unwrap();
            let api_error = ApiError::new(status, error_type.clone(), error_message.clone());

            let error_response = ErrorResponse::from_api_error(api_error, Environment::Development);

            prop_assert_eq!(
                &error_response.error.message,
                &error_message,
                "Development error should preserve original message"
            );

            prop_assert_eq!(
                &error_response.error.error_type,
                &error_type,
                "Development error should preserve error type"
            );

            prop_assert!(
                error_response.error_id.starts_with("err_"),
                "Error ID should be present in development error response"
            );
        }
    }

    // Validation errors keep their field details in both environments.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_validation_error_field_details(
            field_name in "[a-z_]{1,20}",
            field_code in "[a-z_]{1,15}",
            field_message in "[a-zA-Z0-9 ]{1,50}",
            is_production in proptest::bool::ANY,
        ) {
            let env = if is_production { Environment::Production } else { Environment::Development };

            let api_error = ApiError::validation(vec![FieldError {
                field: field_name.clone(),
                code: field_code.clone(),
                message: field_message.clone(),
            }]);

            let error_response = ErrorResponse::from_api_error(api_error, env);

            let reported = error_response.error.fields.as_deref();
            prop_assert!(
                reported.is_some(),
                "Validation error should always include fields in {} mode",
                env
            );

            let reported = reported.unwrap();
            prop_assert_eq!(
                reported.len(),
                1,
                "Should have exactly one field error"
            );

            let field = &reported[0];

            prop_assert_eq!(
                &field.field,
                &field_name,
                "Field name should be preserved in {} mode",
                env
            );

            prop_assert_eq!(
                &field.code,
                &field_code,
                "Field code should be preserved in {} mode",
                env
            );

            prop_assert_eq!(
                &field.message,
                &field_message,
                "Field message should be preserved in {} mode",
                env
            );

            let json = serde_json::to_string(&error_response).unwrap();
            prop_assert!(
                json.contains(field_name.as_str()),
                "JSON should contain field name in {} mode",
                env
            );
            prop_assert!(
                json.contains(field_code.as_str()),
                "JSON should contain field code in {} mode",
                env
            );
            prop_assert!(
                json.contains(field_message.as_str()),
                "JSON should contain field message in {} mode",
                env
            );
        }
    }

    #[test]
    fn test_environment_from_env_production() {
        for name in ["production", "prod", "PRODUCTION", "PROD"] {
            assert!(matches!(classify(name), Environment::Production));
        }
    }

    #[test]
    fn test_environment_from_env_development() {
        for name in ["development", "dev", "test", "anything_else"] {
            assert!(matches!(classify(name), Environment::Development));
        }
    }

    #[test]
    fn test_environment_default_is_development() {
        let default_env = Environment::default();
        assert_eq!(default_env, Environment::Development);
    }

    #[test]
    fn test_environment_display() {
        assert_eq!(Environment::Development.to_string(), "development");
        assert_eq!(Environment::Production.to_string(), "production");
    }

    #[test]
    fn test_environment_is_methods() {
        let (prod, dev) = (Environment::Production, Environment::Development);

        assert!(prod.is_production());
        assert!(!prod.is_development());
        assert!(dev.is_development());
        assert!(!dev.is_production());
    }

    #[test]
    fn test_production_masks_5xx_errors() {
        let response = ErrorResponse::from_api_error(
            ApiError::internal("Sensitive database connection string: postgres://user:pass@host"),
            Environment::Production,
        );

        assert_eq!(response.error.message, "An internal error occurred");
        assert!(!response.error.message.contains("postgres"));
    }

    #[test]
    fn test_production_shows_4xx_errors() {
        let response = ErrorResponse::from_api_error(
            ApiError::bad_request("Invalid email format"),
            Environment::Production,
        );

        assert_eq!(response.error.message, "Invalid email format");
    }

    #[test]
    fn test_development_shows_all_errors() {
        let response = ErrorResponse::from_api_error(
            ApiError::internal("Detailed error: connection refused to 192.168.1.1:5432"),
            Environment::Development,
        );

        assert_eq!(
            response.error.message,
            "Detailed error: connection refused to 192.168.1.1:5432"
        );
    }

    #[test]
    fn test_validation_errors_always_show_fields() {
        let field_error = |field: &str, code: &str, message: &str| FieldError {
            field: field.to_string(),
            code: code.to_string(),
            message: message.to_string(),
        };
        let error = ApiError::validation(vec![
            field_error("email", "invalid_format", "Invalid email format"),
            field_error("age", "min", "Must be at least 18"),
        ]);

        let prod_response = ErrorResponse::from_api_error(error.clone(), Environment::Production);
        assert!(prod_response.error.fields.is_some());
        let prod_fields = prod_response.error.fields.unwrap();
        assert_eq!(prod_fields.len(), 2);
        assert_eq!(prod_fields[0].field, "email");
        assert_eq!(prod_fields[1].field, "age");

        let dev_response = ErrorResponse::from_api_error(error, Environment::Development);
        assert!(dev_response.error.fields.is_some());
        let dev_fields = dev_response.error.fields.unwrap();
        assert_eq!(dev_fields.len(), 2);
    }
}