1use http::StatusCode;
61use serde::Serialize;
62use std::fmt;
63use std::sync::OnceLock;
64use uuid::Uuid;
65
/// Crate-wide `Result` alias whose error type defaults to [`ApiError`].
///
/// A different error type can still be supplied explicitly, e.g. `Result<(), Environment>`.
pub type Result<T, E = ApiError> = std::result::Result<T, E>;
68
/// Deployment environment; decides how much error detail is returned to clients.
///
/// Defaults to [`Environment::Development`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Environment {
    /// Error messages are passed through to clients unchanged.
    #[default]
    Development,
    /// Messages of 5xx errors are replaced by a generic message before being sent.
    Production,
}

impl Environment {
    /// Reads the environment from the `RUSTAPI_ENV` variable.
    ///
    /// `production` or `prod` (case-insensitive, surrounding whitespace ignored) selects
    /// [`Environment::Production`]. Any other value selects [`Environment::Development`],
    /// as does an unset or non-UTF-8 variable.
    pub fn from_env() -> Self {
        match std::env::var("RUSTAPI_ENV") {
            Ok(value) => Self::from_value(&value),
            Err(_) => Environment::Development,
        }
    }

    /// Maps a raw `RUSTAPI_ENV` value onto an environment (rules as in [`Environment::from_env`]).
    // Trimming matters: a stray space or trailing newline (common in `.env` files and
    // container configs) must not silently downgrade a production deployment to
    // Development, which would expose unmasked 5xx error messages.
    fn from_value(raw: &str) -> Self {
        match raw.trim().to_lowercase().as_str() {
            "production" | "prod" => Environment::Production,
            _ => Environment::Development,
        }
    }

    /// Returns `true` for [`Environment::Production`].
    pub fn is_production(&self) -> bool {
        matches!(self, Environment::Production)
    }

    /// Returns `true` for [`Environment::Development`].
    pub fn is_development(&self) -> bool {
        matches!(self, Environment::Development)
    }
}

impl fmt::Display for Environment {
    /// Writes the lowercase name: `development` or `production`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Environment::Development => "development",
            Environment::Production => "production",
        })
    }
}
143
144static ENVIRONMENT: OnceLock<Environment> = OnceLock::new();
146
147pub fn get_environment() -> Environment {
152 *ENVIRONMENT.get_or_init(Environment::from_env)
153}
154
/// Test-only hook that pins the cached environment before anything reads it.
///
/// Returns `Err(env)`, handing the value back, when the cache is already initialised.
/// That happens after an earlier call to this function or to `get_environment`.
#[cfg(test)]
// The helper may be unused in a given test build; silence the dead-code lint for it.
#[allow(dead_code)]
pub fn set_environment_for_test(env: Environment) -> Result<(), Environment> {
    OnceLock::set(&ENVIRONMENT, env)
}
164
165pub fn generate_error_id() -> String {
180 format!("err_{}", Uuid::new_v4().simple())
181}
182
/// Error value produced by request handling and turned into an HTTP error response.
///
/// `status`, `error_type`, `message` and `fields` feed the response body (see
/// `ErrorResponse::from_api_error`, which may mask the message in production). `internal`
/// holds server-side diagnostic text and is not part of the serialized response.
#[derive(Debug, Clone)]
pub struct ApiError {
    /// HTTP status code of the response.
    pub status: StatusCode,
    /// Machine-readable error category (e.g. `"not_found"`); serialized as `type`.
    pub error_type: String,
    /// Human-readable description of the error.
    pub message: String,
    /// Per-field validation failures; populated by `ApiError::validation`, `None` for
    /// errors built with the other constructors.
    pub fields: Option<Vec<FieldError>>,
    /// Server-side detail (e.g. an underlying error's text) set via `with_internal`.
    pub(crate) internal: Option<String>,
}
218
/// A single field-level validation failure, included in validation error responses.
#[derive(Debug, Clone, Serialize)]
pub struct FieldError {
    /// Name of the offending field.
    pub field: String,
    /// Short machine-readable code identifying the failure (e.g. `"min"`).
    pub code: String,
    /// Human-readable description of the failure.
    pub message: String,
}
229
230impl ApiError {
231 pub fn new(
233 status: StatusCode,
234 error_type: impl Into<String>,
235 message: impl Into<String>,
236 ) -> Self {
237 Self {
238 status,
239 error_type: error_type.into(),
240 message: message.into(),
241 fields: None,
242 internal: None,
243 }
244 }
245
246 pub fn validation(fields: Vec<FieldError>) -> Self {
248 Self {
249 status: StatusCode::UNPROCESSABLE_ENTITY,
250 error_type: "validation_error".to_string(),
251 message: "Request validation failed".to_string(),
252 fields: Some(fields),
253 internal: None,
254 }
255 }
256
257 pub fn bad_request(message: impl Into<String>) -> Self {
259 Self::new(StatusCode::BAD_REQUEST, "bad_request", message)
260 }
261
262 pub fn unauthorized(message: impl Into<String>) -> Self {
264 Self::new(StatusCode::UNAUTHORIZED, "unauthorized", message)
265 }
266
267 pub fn forbidden(message: impl Into<String>) -> Self {
269 Self::new(StatusCode::FORBIDDEN, "forbidden", message)
270 }
271
272 pub fn not_found(message: impl Into<String>) -> Self {
274 Self::new(StatusCode::NOT_FOUND, "not_found", message)
275 }
276
277 pub fn conflict(message: impl Into<String>) -> Self {
279 Self::new(StatusCode::CONFLICT, "conflict", message)
280 }
281
282 pub fn internal(message: impl Into<String>) -> Self {
284 Self::new(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", message)
285 }
286
287 pub fn with_internal(mut self, details: impl Into<String>) -> Self {
289 self.internal = Some(details.into());
290 self
291 }
292}
293
294impl fmt::Display for ApiError {
295 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296 write!(f, "{}: {}", self.error_type, self.message)
297 }
298}
299
300impl std::error::Error for ApiError {}
301
/// JSON body sent to the client for a failed request:
/// `{"error": {...}, "error_id": "...", "request_id": "..."}`.
#[derive(Serialize)]
pub struct ErrorResponse {
    /// The error details shown to the client.
    pub error: ErrorBody,
    /// Identifier also attached to the server-side log event, for correlation.
    pub error_id: String,
    /// Request identifier, when known; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
}
311
/// The `error` object nested inside an [`ErrorResponse`].
#[derive(Serialize)]
pub struct ErrorBody {
    /// Machine-readable error category; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub error_type: String,
    /// Human-readable message (a generic replacement for masked production 5xx errors).
    pub message: String,
    /// Per-field validation failures; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fields: Option<Vec<FieldError>>,
}
320
321impl ErrorResponse {
322 pub fn from_api_error(err: ApiError, env: Environment) -> Self {
332 let error_id = generate_error_id();
333
334 if err.status.is_server_error() {
336 tracing::error!(
337 error_id = %error_id,
338 error_type = %err.error_type,
339 message = %err.message,
340 status = %err.status.as_u16(),
341 internal = ?err.internal,
342 environment = %env,
343 "Server error occurred"
344 );
345 } else if err.status.is_client_error() {
346 tracing::warn!(
347 error_id = %error_id,
348 error_type = %err.error_type,
349 message = %err.message,
350 status = %err.status.as_u16(),
351 environment = %env,
352 "Client error occurred"
353 );
354 } else {
355 tracing::info!(
356 error_id = %error_id,
357 error_type = %err.error_type,
358 message = %err.message,
359 status = %err.status.as_u16(),
360 environment = %env,
361 "Error response generated"
362 );
363 }
364
365 let (message, fields) = if env.is_production() && err.status.is_server_error() {
367 let masked_message = "An internal error occurred".to_string();
370 let fields = if err.error_type == "validation_error" {
372 err.fields
373 } else {
374 None
375 };
376 (masked_message, fields)
377 } else {
378 (err.message, err.fields)
380 };
381
382 Self {
383 error: ErrorBody {
384 error_type: err.error_type,
385 message,
386 fields,
387 },
388 error_id,
389 request_id: None,
390 }
391 }
392}
393
394impl From<ApiError> for ErrorResponse {
395 fn from(err: ApiError) -> Self {
396 let env = get_environment();
398 Self::from_api_error(err, env)
399 }
400}
401
402impl From<serde_json::Error> for ApiError {
404 fn from(err: serde_json::Error) -> Self {
405 ApiError::bad_request(format!("Invalid JSON: {}", err))
406 }
407}
408
409impl From<std::io::Error> for ApiError {
410 fn from(err: std::io::Error) -> Self {
411 ApiError::internal("I/O error").with_internal(err.to_string())
412 }
413}
414
415impl From<hyper::Error> for ApiError {
416 fn from(err: hyper::Error) -> Self {
417 ApiError::internal("HTTP error").with_internal(err.to_string())
418 }
419}
420
421impl From<rustapi_validate::ValidationError> for ApiError {
422 fn from(err: rustapi_validate::ValidationError) -> Self {
423 let fields = err
424 .fields
425 .into_iter()
426 .map(|f| FieldError {
427 field: f.field,
428 code: f.code,
429 message: f.message,
430 })
431 .collect();
432
433 ApiError::validation(fields)
434 }
435}
436
437impl ApiError {
438 pub fn from_validation_error(err: rustapi_validate::ValidationError) -> Self {
440 err.into()
441 }
442
443 pub fn service_unavailable(message: impl Into<String>) -> Self {
445 Self::new(
446 StatusCode::SERVICE_UNAVAILABLE,
447 "service_unavailable",
448 message,
449 )
450 }
451}
452
#[cfg(feature = "sqlx")]
impl From<sqlx::Error> for ApiError {
    /// Translates a `sqlx` error into an [`ApiError`] with a matching HTTP status.
    ///
    /// - Pool timeout / pool closed, I/O and TLS failures → `503 Service Unavailable`.
    /// - `RowNotFound` → `404 Not Found` (no internal details attached).
    /// - Database errors whose driver code marks a constraint violation:
    ///   unique or primary key → `409 Conflict`; foreign key or check → `400 Bad Request`.
    /// - Everything else → `500 Internal Server Error`.
    ///
    /// Except for `RowNotFound`, the original error text is attached via
    /// [`ApiError::with_internal`] and is therefore not part of the response body.
    fn from(err: sqlx::Error) -> Self {
        match &err {
            sqlx::Error::PoolTimedOut => {
                ApiError::service_unavailable("Database connection pool exhausted")
                    .with_internal(err.to_string())
            }

            sqlx::Error::PoolClosed => {
                ApiError::service_unavailable("Database connection pool is closed")
                    .with_internal(err.to_string())
            }

            sqlx::Error::RowNotFound => ApiError::not_found("Resource not found"),

            sqlx::Error::Database(db_err) => {
                let details = db_err.to_string();
                // Constraint violations are recognised by driver-specific codes: PostgreSQL
                // SQLSTATE values, MySQL error numbers and SQLite *extended* result codes.
                // NOTE(review): depending on the sqlx version, the MySQL driver's `code()`
                // may return the SQLSTATE (e.g. "23000") instead of the error number, in
                // which case the MySQL entries below never match; `db_err.kind()` (sqlx
                // 0.7+) is driver-independent — confirm against the sqlx version in use.
                match db_err.code().as_deref() {
                    // PG unique_violation | MySQL ER_DUP_ENTRY |
                    // SQLITE_CONSTRAINT_UNIQUE | SQLITE_CONSTRAINT_PRIMARYKEY
                    Some("23505" | "1062" | "2067" | "1555") => {
                        ApiError::conflict("Resource already exists").with_internal(details)
                    }
                    // PG foreign_key_violation | MySQL ER_NO_REFERENCED_ROW_2 |
                    // SQLITE_CONSTRAINT_FOREIGNKEY
                    Some("23503" | "1452" | "787") => {
                        ApiError::bad_request("Referenced resource does not exist")
                            .with_internal(details)
                    }
                    // PG check_violation | SQLITE_CONSTRAINT_CHECK
                    Some("23514" | "275") => {
                        ApiError::bad_request("Data validation failed").with_internal(details)
                    }
                    _ => ApiError::internal("Database error").with_internal(details),
                }
            }

            sqlx::Error::Io(_) => ApiError::service_unavailable("Database connection error")
                .with_internal(err.to_string()),

            sqlx::Error::Tls(_) => {
                ApiError::service_unavailable("Database TLS error").with_internal(err.to_string())
            }

            sqlx::Error::Protocol(_) => {
                ApiError::internal("Database protocol error").with_internal(err.to_string())
            }

            sqlx::Error::TypeNotFound { .. } => {
                ApiError::internal("Database type error").with_internal(err.to_string())
            }

            sqlx::Error::ColumnNotFound(_) => {
                ApiError::internal("Database column not found").with_internal(err.to_string())
            }

            sqlx::Error::ColumnIndexOutOfBounds { .. } => {
                ApiError::internal("Database column index error").with_internal(err.to_string())
            }

            sqlx::Error::ColumnDecode { .. } => {
                ApiError::internal("Database decode error").with_internal(err.to_string())
            }

            sqlx::Error::Configuration(_) => {
                ApiError::internal("Database configuration error").with_internal(err.to_string())
            }

            sqlx::Error::Migrate(_) => {
                ApiError::internal("Database migration error").with_internal(err.to_string())
            }

            _ => ApiError::internal("Database error").with_internal(err.to_string()),
        }
    }
}
549
// Unit and property tests for error IDs, environment handling and error-response masking.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::collections::HashSet;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Property: a batch of generated error IDs contains no duplicates, and each ID is
        // `err_` followed by 32 hex characters.
        #[test]
        fn prop_error_id_uniqueness(
            num_errors in 10usize..200,
        ) {
            let error_ids: Vec<String> = (0..num_errors)
                .map(|_| generate_error_id())
                .collect();

            let unique_ids: HashSet<&String> = error_ids.iter().collect();

            prop_assert_eq!(
                unique_ids.len(),
                error_ids.len(),
                "Generated {} error IDs but only {} were unique",
                error_ids.len(),
                unique_ids.len()
            );

            for id in &error_ids {
                prop_assert!(
                    id.starts_with("err_"),
                    "Error ID '{}' does not start with 'err_'",
                    id
                );

                // Skip the 4-byte "err_" prefix.
                let uuid_part = &id[4..];
                prop_assert_eq!(
                    uuid_part.len(),
                    32,
                    "UUID part '{}' should be 32 characters, got {}",
                    uuid_part,
                    uuid_part.len()
                );

                prop_assert!(
                    uuid_part.chars().all(|c| c.is_ascii_hexdigit()),
                    "UUID part '{}' contains non-hex characters",
                    uuid_part
                );
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Property: responses built through `From<ApiError>` (which uses the process-wide
        // environment) always carry a well-formed error ID.
        #[test]
        fn prop_error_response_contains_error_id(
            error_type in "[a-z_]{1,20}",
            message in "[a-zA-Z0-9 ]{1,100}",
        ) {
            let api_error = ApiError::new(StatusCode::INTERNAL_SERVER_ERROR, error_type, message);
            let error_response = ErrorResponse::from(api_error);

            prop_assert!(
                error_response.error_id.starts_with("err_"),
                "Error ID '{}' does not start with 'err_'",
                error_response.error_id
            );

            let uuid_part = &error_response.error_id[4..];
            prop_assert_eq!(uuid_part.len(), 32);
            prop_assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit()));
        }
    }

    #[test]
    fn test_error_id_format() {
        let error_id = generate_error_id();

        assert!(error_id.starts_with("err_"));

        // "err_" (4) + 32 hex characters.
        assert_eq!(error_id.len(), 36);

        let uuid_part = &error_id[4..];
        assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_error_response_includes_error_id() {
        let api_error = ApiError::bad_request("test error");
        let error_response = ErrorResponse::from(api_error);

        assert!(error_response.error_id.starts_with("err_"));
        assert_eq!(error_response.error_id.len(), 36);
    }

    #[test]
    fn test_error_id_in_json_serialization() {
        let api_error = ApiError::internal("test error");
        let error_response = ErrorResponse::from(api_error);

        let json = serde_json::to_string(&error_response).unwrap();

        assert!(json.contains("\"error_id\":"));
        assert!(json.contains("err_"));
    }

    #[test]
    fn test_multiple_error_ids_are_unique() {
        let ids: Vec<String> = (0..1000).map(|_| generate_error_id()).collect();
        let unique: HashSet<_> = ids.iter().collect();

        assert_eq!(ids.len(), unique.len(), "All error IDs should be unique");
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Property: in Production, every 5xx response carries the generic message and its
        // serialized JSON contains neither the original message nor the internal details.
        #[test]
        fn prop_production_error_masking(
            sensitive_message in "[a-zA-Z0-9_]{10,200}",
            internal_details in "[a-zA-Z0-9_]{10,200}",
            status_code in prop::sample::select(vec![500u16, 501, 502, 503, 504, 505]),
        ) {
            let api_error = ApiError::new(
                StatusCode::from_u16(status_code).unwrap(),
                "internal_error",
                sensitive_message.clone()
            ).with_internal(internal_details.clone());

            let error_response = ErrorResponse::from_api_error(api_error, Environment::Production);

            prop_assert_eq!(
                &error_response.error.message,
                "An internal error occurred",
                "Production 5xx error should have masked message, got: {}",
                &error_response.error.message
            );

            // Always true for the `{10,200}` generators above; kept as a guard in case
            // shorter strings are generated later.
            if sensitive_message.len() >= 10 {
                prop_assert!(
                    !error_response.error.message.contains(&sensitive_message),
                    "Production error response should not contain original message"
                );
            }

            let json = serde_json::to_string(&error_response).unwrap();
            if internal_details.len() >= 10 {
                prop_assert!(
                    !json.contains(&internal_details),
                    "Production error response should not contain internal details"
                );
            }

            prop_assert!(
                error_response.error_id.starts_with("err_"),
                "Error ID should be present in production error response"
            );
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Property: in Development, message and error type pass through unchanged for both
        // 4xx and 5xx statuses.
        #[test]
        fn prop_development_error_details(
            error_message in "[a-zA-Z0-9 ]{1,100}",
            error_type in "[a-z_]{1,20}",
            status_code in prop::sample::select(vec![400u16, 401, 403, 404, 500, 502, 503]),
        ) {
            let api_error = ApiError::new(
                StatusCode::from_u16(status_code).unwrap(),
                error_type.clone(),
                error_message.clone()
            );

            let error_response = ErrorResponse::from_api_error(api_error, Environment::Development);

            prop_assert_eq!(
                error_response.error.message,
                error_message,
                "Development error should preserve original message"
            );

            prop_assert_eq!(
                error_response.error.error_type,
                error_type,
                "Development error should preserve error type"
            );

            prop_assert!(
                error_response.error_id.starts_with("err_"),
                "Error ID should be present in development error response"
            );
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Property: validation errors (422) keep their field details, both in the struct and
        // in the JSON, regardless of environment.
        #[test]
        fn prop_validation_error_field_details(
            field_name in "[a-z_]{1,20}",
            field_code in "[a-z_]{1,15}",
            field_message in "[a-zA-Z0-9 ]{1,50}",
            is_production in proptest::bool::ANY,
        ) {
            let env = if is_production {
                Environment::Production
            } else {
                Environment::Development
            };

            let field_error = FieldError {
                field: field_name.clone(),
                code: field_code.clone(),
                message: field_message.clone(),
            };
            let api_error = ApiError::validation(vec![field_error]);

            let error_response = ErrorResponse::from_api_error(api_error, env);

            prop_assert!(
                error_response.error.fields.is_some(),
                "Validation error should always include fields in {} mode",
                env
            );

            let fields = error_response.error.fields.as_ref().unwrap();
            prop_assert_eq!(
                fields.len(),
                1,
                "Should have exactly one field error"
            );

            let field = &fields[0];

            prop_assert_eq!(
                &field.field,
                &field_name,
                "Field name should be preserved in {} mode",
                env
            );

            prop_assert_eq!(
                &field.code,
                &field_code,
                "Field code should be preserved in {} mode",
                env
            );

            prop_assert_eq!(
                &field.message,
                &field_message,
                "Field message should be preserved in {} mode",
                env
            );

            let json = serde_json::to_string(&error_response).unwrap();
            prop_assert!(
                json.contains(&field_name),
                "JSON should contain field name in {} mode",
                env
            );
            prop_assert!(
                json.contains(&field_code),
                "JSON should contain field code in {} mode",
                env
            );
            prop_assert!(
                json.contains(&field_message),
                "JSON should contain field message in {} mode",
                env
            );
        }
    }

    // NOTE(review): this test and `test_environment_from_env_development` re-implement the
    // `match` from `Environment::from_env` inline instead of calling it, so they never read
    // `RUSTAPI_ENV` and keep passing even if `from_env` changes. They only document the
    // intended mapping.
    #[test]
    fn test_environment_from_env_production() {
        assert!(matches!(
            match "production".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Production
        ));

        assert!(matches!(
            match "prod".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Production
        ));

        assert!(matches!(
            match "PRODUCTION".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Production
        ));

        assert!(matches!(
            match "PROD".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Production
        ));
    }

    // See the NOTE(review) above: mirrors the mapping rather than calling `from_env`.
    #[test]
    fn test_environment_from_env_development() {
        assert!(matches!(
            match "development".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Development
        ));

        assert!(matches!(
            match "dev".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Development
        ));

        assert!(matches!(
            match "test".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Development
        ));

        assert!(matches!(
            match "anything_else".to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                _ => Environment::Development,
            },
            Environment::Development
        ));
    }

    #[test]
    fn test_environment_default_is_development() {
        assert_eq!(Environment::default(), Environment::Development);
    }

    #[test]
    fn test_environment_display() {
        assert_eq!(format!("{}", Environment::Development), "development");
        assert_eq!(format!("{}", Environment::Production), "production");
    }

    #[test]
    fn test_environment_is_methods() {
        assert!(Environment::Production.is_production());
        assert!(!Environment::Production.is_development());
        assert!(Environment::Development.is_development());
        assert!(!Environment::Development.is_production());
    }

    #[test]
    fn test_production_masks_5xx_errors() {
        let error =
            ApiError::internal("Sensitive database connection string: postgres://user:pass@host");
        let response = ErrorResponse::from_api_error(error, Environment::Production);

        assert_eq!(response.error.message, "An internal error occurred");
        assert!(!response.error.message.contains("postgres"));
    }

    #[test]
    fn test_production_shows_4xx_errors() {
        let error = ApiError::bad_request("Invalid email format");
        let response = ErrorResponse::from_api_error(error, Environment::Production);

        // Client errors are not masked even in production.
        assert_eq!(response.error.message, "Invalid email format");
    }

    #[test]
    fn test_development_shows_all_errors() {
        let error = ApiError::internal("Detailed error: connection refused to 192.168.1.1:5432");
        let response = ErrorResponse::from_api_error(error, Environment::Development);

        assert_eq!(
            response.error.message,
            "Detailed error: connection refused to 192.168.1.1:5432"
        );
    }

    #[test]
    fn test_validation_errors_always_show_fields() {
        let fields = vec![
            FieldError {
                field: "email".to_string(),
                code: "invalid_format".to_string(),
                message: "Invalid email format".to_string(),
            },
            FieldError {
                field: "age".to_string(),
                code: "min".to_string(),
                message: "Must be at least 18".to_string(),
            },
        ];

        let error = ApiError::validation(fields.clone());

        let prod_response = ErrorResponse::from_api_error(error.clone(), Environment::Production);
        // Field order is preserved.
        assert!(prod_response.error.fields.is_some());
        let prod_fields = prod_response.error.fields.unwrap();
        assert_eq!(prod_fields.len(), 2);
        assert_eq!(prod_fields[0].field, "email");
        assert_eq!(prod_fields[1].field, "age");

        let dev_response = ErrorResponse::from_api_error(error, Environment::Development);
        assert!(dev_response.error.fields.is_some());
        let dev_fields = dev_response.error.fields.unwrap();
        assert_eq!(dev_fields.len(), 2);
    }
}