1use http::StatusCode;
61use serde::Serialize;
62use std::fmt;
63use std::sync::OnceLock;
64use uuid::Uuid;
65
66pub type Result<T, E = ApiError> = std::result::Result<T, E>;
68
/// Runtime environment the service is running in.
///
/// [`Environment::Development`] is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Environment {
    /// Non-production mode; this is the default.
    #[default]
    Development,
    /// Production mode.
    Production,
}

impl Environment {
    /// Determines the environment from the `RUSTAPI_ENV` environment variable.
    ///
    /// The values `production` and `prod` select [`Environment::Production`].
    /// The comparison ignores ASCII case and leading/trailing whitespace, so a
    /// value such as `"Production\n"` is still recognised. An unset variable,
    /// a non-Unicode value, or any other value yields
    /// [`Environment::Development`].
    pub fn from_env() -> Self {
        let Ok(raw) = std::env::var("RUSTAPI_ENV") else {
            return Environment::Development;
        };

        // Trim first: a stray space or newline in the variable must not
        // silently downgrade a production deployment to development mode.
        let name = raw.trim();
        let is_production = ["production", "prod"]
            .iter()
            .any(|candidate| name.eq_ignore_ascii_case(candidate));

        if is_production {
            Environment::Production
        } else {
            Environment::Development
        }
    }

    /// Returns `true` if this is [`Environment::Production`].
    pub fn is_production(&self) -> bool {
        matches!(self, Environment::Production)
    }

    /// Returns `true` if this is [`Environment::Development`].
    pub fn is_development(&self) -> bool {
        matches!(self, Environment::Development)
    }
}
134
135impl fmt::Display for Environment {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 match self {
138 Environment::Development => write!(f, "development"),
139 Environment::Production => write!(f, "production"),
140 }
141 }
142}
143
144static ENVIRONMENT: OnceLock<Environment> = OnceLock::new();
146
147pub fn get_environment() -> Environment {
152 *ENVIRONMENT.get_or_init(Environment::from_env)
153}
154
/// Test-only hook that stores `env` in the global environment cache.
///
/// # Errors
///
/// Returns `Err(env)` (handing the rejected value back) when the cache has
/// already been initialised, because the underlying `OnceLock` can only be
/// set once.
#[cfg(test)]
pub fn set_environment_for_test(env: Environment) -> Result<(), Environment> {
    ENVIRONMENT.set(env)
}
163
164pub fn generate_error_id() -> String {
179 format!("err_{}", Uuid::new_v4().simple())
180}
181
182#[derive(Debug, Clone)]
205pub struct ApiError {
206 pub status: StatusCode,
208 pub error_type: String,
210 pub message: String,
212 pub fields: Option<Vec<FieldError>>,
214 pub(crate) internal: Option<String>,
216}
217
218#[derive(Debug, Clone, Serialize)]
220pub struct FieldError {
221 pub field: String,
223 pub code: String,
225 pub message: String,
227}
228
229impl ApiError {
230 pub fn new(
232 status: StatusCode,
233 error_type: impl Into<String>,
234 message: impl Into<String>,
235 ) -> Self {
236 Self {
237 status,
238 error_type: error_type.into(),
239 message: message.into(),
240 fields: None,
241 internal: None,
242 }
243 }
244
245 pub fn validation(fields: Vec<FieldError>) -> Self {
247 Self {
248 status: StatusCode::UNPROCESSABLE_ENTITY,
249 error_type: "validation_error".to_string(),
250 message: "Request validation failed".to_string(),
251 fields: Some(fields),
252 internal: None,
253 }
254 }
255
256 pub fn bad_request(message: impl Into<String>) -> Self {
258 Self::new(StatusCode::BAD_REQUEST, "bad_request", message)
259 }
260
261 pub fn unauthorized(message: impl Into<String>) -> Self {
263 Self::new(StatusCode::UNAUTHORIZED, "unauthorized", message)
264 }
265
266 pub fn forbidden(message: impl Into<String>) -> Self {
268 Self::new(StatusCode::FORBIDDEN, "forbidden", message)
269 }
270
271 pub fn not_found(message: impl Into<String>) -> Self {
273 Self::new(StatusCode::NOT_FOUND, "not_found", message)
274 }
275
276 pub fn conflict(message: impl Into<String>) -> Self {
278 Self::new(StatusCode::CONFLICT, "conflict", message)
279 }
280
281 pub fn internal(message: impl Into<String>) -> Self {
283 Self::new(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", message)
284 }
285
286 pub fn with_internal(mut self, details: impl Into<String>) -> Self {
288 self.internal = Some(details.into());
289 self
290 }
291}
292
293impl fmt::Display for ApiError {
294 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
295 write!(f, "{}: {}", self.error_type, self.message)
296 }
297}
298
299impl std::error::Error for ApiError {}
300
301#[derive(Serialize)]
303pub struct ErrorResponse {
304 pub error: ErrorBody,
305 pub error_id: String,
307 #[serde(skip_serializing_if = "Option::is_none")]
308 pub request_id: Option<String>,
309}
310
311#[derive(Serialize)]
312pub struct ErrorBody {
313 #[serde(rename = "type")]
314 pub error_type: String,
315 pub message: String,
316 #[serde(skip_serializing_if = "Option::is_none")]
317 pub fields: Option<Vec<FieldError>>,
318}
319
320impl ErrorResponse {
321 pub fn from_api_error(err: ApiError, env: Environment) -> Self {
331 let error_id = generate_error_id();
332
333 if err.status.is_server_error() {
335 tracing::error!(
336 error_id = %error_id,
337 error_type = %err.error_type,
338 message = %err.message,
339 status = %err.status.as_u16(),
340 internal = ?err.internal,
341 environment = %env,
342 "Server error occurred"
343 );
344 } else if err.status.is_client_error() {
345 tracing::warn!(
346 error_id = %error_id,
347 error_type = %err.error_type,
348 message = %err.message,
349 status = %err.status.as_u16(),
350 environment = %env,
351 "Client error occurred"
352 );
353 } else {
354 tracing::info!(
355 error_id = %error_id,
356 error_type = %err.error_type,
357 message = %err.message,
358 status = %err.status.as_u16(),
359 environment = %env,
360 "Error response generated"
361 );
362 }
363
364 let (message, fields) = if env.is_production() && err.status.is_server_error() {
366 let masked_message = "An internal error occurred".to_string();
369 let fields = if err.error_type == "validation_error" {
371 err.fields
372 } else {
373 None
374 };
375 (masked_message, fields)
376 } else {
377 (err.message, err.fields)
379 };
380
381 Self {
382 error: ErrorBody {
383 error_type: err.error_type,
384 message,
385 fields,
386 },
387 error_id,
388 request_id: None,
389 }
390 }
391}
392
393impl From<ApiError> for ErrorResponse {
394 fn from(err: ApiError) -> Self {
395 let env = get_environment();
397 Self::from_api_error(err, env)
398 }
399}
400
401impl From<serde_json::Error> for ApiError {
403 fn from(err: serde_json::Error) -> Self {
404 ApiError::bad_request(format!("Invalid JSON: {}", err))
405 }
406}
407
408impl From<std::io::Error> for ApiError {
409 fn from(err: std::io::Error) -> Self {
410 ApiError::internal("I/O error").with_internal(err.to_string())
411 }
412}
413
414impl From<hyper::Error> for ApiError {
415 fn from(err: hyper::Error) -> Self {
416 ApiError::internal("HTTP error").with_internal(err.to_string())
417 }
418}
419
420impl From<rustapi_validate::ValidationError> for ApiError {
421 fn from(err: rustapi_validate::ValidationError) -> Self {
422 let fields = err
423 .fields
424 .into_iter()
425 .map(|f| FieldError {
426 field: f.field,
427 code: f.code,
428 message: f.message,
429 })
430 .collect();
431
432 ApiError::validation(fields)
433 }
434}
435
436impl ApiError {
437 pub fn from_validation_error(err: rustapi_validate::ValidationError) -> Self {
439 err.into()
440 }
441
442 pub fn service_unavailable(message: impl Into<String>) -> Self {
444 Self::new(
445 StatusCode::SERVICE_UNAVAILABLE,
446 "service_unavailable",
447 message,
448 )
449 }
450}
451
/// Maps `sqlx` failures onto API errors.
///
/// * Pool exhaustion/closure, I/O and TLS problems become
///   `service_unavailable`.
/// * `RowNotFound` becomes `not_found` (without internal details).
/// * Database errors are classified by their error code:
///   * unique / primary-key violations (`23505` PostgreSQL, `1062` MySQL,
///     `2067` and `1555` SQLite) become `conflict`;
///   * foreign-key violations (`23503` PostgreSQL, `1452` MySQL, `787`
///     SQLite) become `bad_request`;
///   * check-constraint violations (`23514` PostgreSQL, `275` SQLite) become
///     `bad_request`;
///   * anything else becomes `internal_error`.
/// * Every remaining variant becomes `internal_error`.
///
/// Except for `RowNotFound`, the textual form of the underlying error is
/// attached as internal details.
#[cfg(feature = "sqlx")]
impl From<sqlx::Error> for ApiError {
    fn from(err: sqlx::Error) -> Self {
        let base = match &err {
            // No internal details are attached for a plain missing row.
            sqlx::Error::RowNotFound => return ApiError::not_found("Resource not found"),

            // Database errors keep the driver's own message (not the
            // wrapping `sqlx::Error` text) as internal details.
            sqlx::Error::Database(db_err) => {
                let classified = match db_err.code().as_deref() {
                    Some("23505" | "1062" | "2067" | "1555") => {
                        ApiError::conflict("Resource already exists")
                    }
                    Some("23503" | "1452" | "787") => {
                        ApiError::bad_request("Referenced resource does not exist")
                    }
                    Some("23514" | "275") => ApiError::bad_request("Data validation failed"),
                    _ => ApiError::internal("Database error"),
                };
                return classified.with_internal(db_err.to_string());
            }

            sqlx::Error::PoolTimedOut => {
                ApiError::service_unavailable("Database connection pool exhausted")
            }
            sqlx::Error::PoolClosed => {
                ApiError::service_unavailable("Database connection pool is closed")
            }
            sqlx::Error::Io(_) => ApiError::service_unavailable("Database connection error"),
            sqlx::Error::Tls(_) => ApiError::service_unavailable("Database TLS error"),

            sqlx::Error::Protocol(_) => ApiError::internal("Database protocol error"),
            sqlx::Error::TypeNotFound { .. } => ApiError::internal("Database type error"),
            sqlx::Error::ColumnNotFound(_) => ApiError::internal("Database column not found"),
            sqlx::Error::ColumnIndexOutOfBounds { .. } => {
                ApiError::internal("Database column index error")
            }
            sqlx::Error::ColumnDecode { .. } => ApiError::internal("Database decode error"),
            sqlx::Error::Configuration(_) => ApiError::internal("Database configuration error"),
            sqlx::Error::Migrate(_) => ApiError::internal("Database migration error"),

            _ => ApiError::internal("Database error"),
        };

        // All remaining arms attach the full `sqlx::Error` text.
        base.with_internal(err.to_string())
    }
}
548
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::collections::HashSet;

    // Maps a raw value to an `Environment` with the lowercase comparison the
    // environment tests below exercise; it does not read any process state.
    fn classify(raw: &str) -> Environment {
        match raw.to_lowercase().as_str() {
            "production" | "prod" => Environment::Production,
            _ => Environment::Development,
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Every generated batch of IDs is collision-free and each ID is
        // `err_` followed by 32 hexadecimal characters.
        #[test]
        fn prop_error_id_uniqueness(
            num_errors in 10usize..200,
        ) {
            let mut batch: Vec<String> = Vec::with_capacity(num_errors);
            for _ in 0..num_errors {
                batch.push(generate_error_id());
            }

            let distinct: HashSet<&String> = batch.iter().collect();

            prop_assert_eq!(
                distinct.len(),
                batch.len(),
                "Generated {} error IDs but only {} were unique",
                batch.len(),
                distinct.len()
            );

            for id in &batch {
                prop_assert!(
                    id.starts_with("err_"),
                    "Error ID '{}' does not start with 'err_'",
                    id
                );

                let (_, uuid_part) = id.split_at(4);
                prop_assert_eq!(
                    uuid_part.len(),
                    32,
                    "UUID part '{}' should be 32 characters, got {}",
                    uuid_part,
                    uuid_part.len()
                );

                prop_assert!(
                    uuid_part.chars().all(|c| c.is_ascii_hexdigit()),
                    "UUID part '{}' contains non-hex characters",
                    uuid_part
                );
            }
        }

        // Any error converted through `From` carries a well-formed error ID.
        #[test]
        fn prop_error_response_contains_error_id(
            error_type in "[a-z_]{1,20}",
            message in "[a-zA-Z0-9 ]{1,100}",
        ) {
            let source = ApiError::new(StatusCode::INTERNAL_SERVER_ERROR, error_type, message);
            let response = ErrorResponse::from(source);

            prop_assert!(
                response.error_id.starts_with("err_"),
                "Error ID '{}' does not start with 'err_'",
                response.error_id
            );

            let (_, uuid_part) = response.error_id.split_at(4);
            prop_assert_eq!(uuid_part.len(), 32);
            prop_assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit()));
        }
    }

    #[test]
    fn test_error_id_format() {
        let id = generate_error_id();

        assert!(id.starts_with("err_"));
        assert_eq!(id.len(), 36);

        let (_, uuid_part) = id.split_at(4);
        assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_error_response_includes_error_id() {
        let response = ErrorResponse::from(ApiError::bad_request("test error"));

        assert!(response.error_id.starts_with("err_"));
        assert_eq!(response.error_id.len(), 36);
    }

    #[test]
    fn test_error_id_in_json_serialization() {
        let response = ErrorResponse::from(ApiError::internal("test error"));

        let json = serde_json::to_string(&response).unwrap();

        assert!(json.contains("\"error_id\":"));
        assert!(json.contains("err_"));
    }

    #[test]
    fn test_multiple_error_ids_are_unique() {
        let mut ids: Vec<String> = Vec::with_capacity(1000);
        for _ in 0..1000 {
            ids.push(generate_error_id());
        }
        let unique: HashSet<_> = ids.iter().collect();

        assert_eq!(ids.len(), unique.len(), "All error IDs should be unique");
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Production 5xx responses expose neither the original message nor
        // the internal details, but still carry an error ID.
        #[test]
        fn prop_production_error_masking(
            sensitive_message in "[a-zA-Z0-9_]{10,200}",
            internal_details in "[a-zA-Z0-9_]{10,200}",
            status_code in prop::sample::select(vec![500u16, 501, 502, 503, 504, 505]),
        ) {
            let status = StatusCode::from_u16(status_code).unwrap();
            let source = ApiError::new(status, "internal_error", sensitive_message.clone())
                .with_internal(internal_details.clone());

            let response = ErrorResponse::from_api_error(source, Environment::Production);

            prop_assert_eq!(
                &response.error.message,
                "An internal error occurred",
                "Production 5xx error should have masked message, got: {}",
                &response.error.message
            );

            // Both generators always yield at least 10 characters, so the
            // leak checks run unconditionally.
            prop_assert!(
                !response.error.message.contains(&sensitive_message),
                "Production error response should not contain original message"
            );

            let json = serde_json::to_string(&response).unwrap();
            prop_assert!(
                !json.contains(&internal_details),
                "Production error response should not contain internal details"
            );

            prop_assert!(
                response.error_id.starts_with("err_"),
                "Error ID should be present in production error response"
            );
        }

        // Development responses keep message and type untouched.
        #[test]
        fn prop_development_error_details(
            error_message in "[a-zA-Z0-9 ]{1,100}",
            error_type in "[a-z_]{1,20}",
            status_code in prop::sample::select(vec![400u16, 401, 403, 404, 500, 502, 503]),
        ) {
            let status = StatusCode::from_u16(status_code).unwrap();
            let source = ApiError::new(status, error_type.clone(), error_message.clone());

            let response = ErrorResponse::from_api_error(source, Environment::Development);

            prop_assert_eq!(
                response.error.message,
                error_message,
                "Development error should preserve original message"
            );

            prop_assert_eq!(
                response.error.error_type,
                error_type,
                "Development error should preserve error type"
            );

            prop_assert!(
                response.error_id.starts_with("err_"),
                "Error ID should be present in development error response"
            );
        }

        // Validation errors keep their field details in both environments.
        #[test]
        fn prop_validation_error_field_details(
            field_name in "[a-z_]{1,20}",
            field_code in "[a-z_]{1,15}",
            field_message in "[a-zA-Z0-9 ]{1,50}",
            is_production in proptest::bool::ANY,
        ) {
            let env = match is_production {
                true => Environment::Production,
                false => Environment::Development,
            };

            let source = ApiError::validation(vec![FieldError {
                field: field_name.clone(),
                code: field_code.clone(),
                message: field_message.clone(),
            }]);

            let response = ErrorResponse::from_api_error(source, env);

            prop_assert!(
                response.error.fields.is_some(),
                "Validation error should always include fields in {} mode",
                env
            );

            let fields = response.error.fields.as_ref().unwrap();
            prop_assert_eq!(
                fields.len(),
                1,
                "Should have exactly one field error"
            );

            let reported = &fields[0];

            prop_assert_eq!(
                &reported.field,
                &field_name,
                "Field name should be preserved in {} mode",
                env
            );

            prop_assert_eq!(
                &reported.code,
                &field_code,
                "Field code should be preserved in {} mode",
                env
            );

            prop_assert_eq!(
                &reported.message,
                &field_message,
                "Field message should be preserved in {} mode",
                env
            );

            let json = serde_json::to_string(&response).unwrap();
            prop_assert!(
                json.contains(&field_name),
                "JSON should contain field name in {} mode",
                env
            );
            prop_assert!(
                json.contains(&field_code),
                "JSON should contain field code in {} mode",
                env
            );
            prop_assert!(
                json.contains(&field_message),
                "JSON should contain field message in {} mode",
                env
            );
        }
    }

    #[test]
    fn test_environment_from_env_production() {
        for raw in ["production", "prod", "PRODUCTION", "PROD"] {
            assert!(matches!(classify(raw), Environment::Production));
        }
    }

    #[test]
    fn test_environment_from_env_development() {
        for raw in ["development", "dev", "test", "anything_else"] {
            assert!(matches!(classify(raw), Environment::Development));
        }
    }

    #[test]
    fn test_environment_default_is_development() {
        assert_eq!(Environment::default(), Environment::Development);
    }

    #[test]
    fn test_environment_display() {
        assert_eq!(format!("{}", Environment::Development), "development");
        assert_eq!(format!("{}", Environment::Production), "production");
    }

    #[test]
    fn test_environment_is_methods() {
        let production = Environment::Production;
        let development = Environment::Development;

        assert!(production.is_production());
        assert!(!production.is_development());
        assert!(development.is_development());
        assert!(!development.is_production());
    }

    #[test]
    fn test_production_masks_5xx_errors() {
        let source =
            ApiError::internal("Sensitive database connection string: postgres://user:pass@host");
        let response = ErrorResponse::from_api_error(source, Environment::Production);

        assert_eq!(response.error.message, "An internal error occurred");
        assert!(!response.error.message.contains("postgres"));
    }

    #[test]
    fn test_production_shows_4xx_errors() {
        let source = ApiError::bad_request("Invalid email format");
        let response = ErrorResponse::from_api_error(source, Environment::Production);

        assert_eq!(response.error.message, "Invalid email format");
    }

    #[test]
    fn test_development_shows_all_errors() {
        let source = ApiError::internal("Detailed error: connection refused to 192.168.1.1:5432");
        let response = ErrorResponse::from_api_error(source, Environment::Development);

        assert_eq!(
            response.error.message,
            "Detailed error: connection refused to 192.168.1.1:5432"
        );
    }

    #[test]
    fn test_validation_errors_always_show_fields() {
        let source = ApiError::validation(vec![
            FieldError {
                field: "email".to_string(),
                code: "invalid_format".to_string(),
                message: "Invalid email format".to_string(),
            },
            FieldError {
                field: "age".to_string(),
                code: "min".to_string(),
                message: "Must be at least 18".to_string(),
            },
        ]);

        let prod_response = ErrorResponse::from_api_error(source.clone(), Environment::Production);
        assert!(prod_response.error.fields.is_some());
        let prod_fields = prod_response.error.fields.unwrap();
        assert_eq!(prod_fields.len(), 2);
        assert_eq!(prod_fields[0].field, "email");
        assert_eq!(prod_fields[1].field, "age");

        let dev_response = ErrorResponse::from_api_error(source, Environment::Development);
        assert!(dev_response.error.fields.is_some());
        let dev_fields = dev_response.error.fields.unwrap();
        assert_eq!(dev_fields.len(), 2);
    }
}