1use crate::error::{ApiError, Result};
58use crate::request::Request;
59use crate::response::IntoResponse;
60use bytes::Bytes;
61use http::{header, StatusCode};
62use http_body_util::Full;
63use serde::de::DeserializeOwned;
64use serde::Serialize;
65use std::future::Future;
66use std::ops::{Deref, DerefMut};
67use std::str::FromStr;
68
69pub trait FromRequestParts: Sized {
73 fn from_request_parts(req: &Request) -> Result<Self>;
75}
76
77pub trait FromRequest: Sized {
81 fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
83}
84
85impl<T: FromRequestParts> FromRequest for T {
87 async fn from_request(req: &mut Request) -> Result<Self> {
88 T::from_request_parts(req)
89 }
90}
91
92#[derive(Debug, Clone, Copy, Default)]
111pub struct Json<T>(pub T);
112
113impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
114 async fn from_request(req: &mut Request) -> Result<Self> {
115 let body = req
116 .take_body()
117 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
118
119 let value: T = serde_json::from_slice(&body)?;
120 Ok(Json(value))
121 }
122}
123
124impl<T> Deref for Json<T> {
125 type Target = T;
126
127 fn deref(&self) -> &Self::Target {
128 &self.0
129 }
130}
131
132impl<T> DerefMut for Json<T> {
133 fn deref_mut(&mut self) -> &mut Self::Target {
134 &mut self.0
135 }
136}
137
138impl<T> From<T> for Json<T> {
139 fn from(value: T) -> Self {
140 Json(value)
141 }
142}
143
144impl<T: Serialize> IntoResponse for Json<T> {
146 fn into_response(self) -> crate::response::Response {
147 match serde_json::to_vec(&self.0) {
148 Ok(body) => http::Response::builder()
149 .status(StatusCode::OK)
150 .header(header::CONTENT_TYPE, "application/json")
151 .body(Full::new(Bytes::from(body)))
152 .unwrap(),
153 Err(err) => {
154 ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
155 }
156 }
157 }
158}
159
/// Like [`Json`], but the deserialized value is also checked with
/// `rustapi_validate::Validate::validate` before reaching the handler.
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);

impl<T> ValidatedJson<T> {
    /// Wraps an existing value. No validation is performed here.
    pub fn new(value: T) -> Self {
        ValidatedJson(value)
    }

    /// Unwraps and returns the inner value.
    pub fn into_inner(self) -> T {
        let ValidatedJson(value) = self;
        value
    }
}
199
200impl<T: DeserializeOwned + rustapi_validate::Validate + Send> FromRequest for ValidatedJson<T> {
201 async fn from_request(req: &mut Request) -> Result<Self> {
202 let body = req
203 .take_body()
204 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
205
206 let value: T = serde_json::from_slice(&body)?;
207
208 if let Err(validation_error) = rustapi_validate::Validate::validate(&value) {
210 return Err(validation_error.into());
212 }
213
214 Ok(ValidatedJson(value))
215 }
216}
217
218impl<T> Deref for ValidatedJson<T> {
219 type Target = T;
220
221 fn deref(&self) -> &Self::Target {
222 &self.0
223 }
224}
225
226impl<T> DerefMut for ValidatedJson<T> {
227 fn deref_mut(&mut self) -> &mut Self::Target {
228 &mut self.0
229 }
230}
231
232impl<T> From<T> for ValidatedJson<T> {
233 fn from(value: T) -> Self {
234 ValidatedJson(value)
235 }
236}
237
238impl<T: Serialize> IntoResponse for ValidatedJson<T> {
239 fn into_response(self) -> crate::response::Response {
240 Json(self.0).into_response()
241 }
242}
243
244#[derive(Debug, Clone)]
262pub struct Query<T>(pub T);
263
264impl<T: DeserializeOwned> FromRequestParts for Query<T> {
265 fn from_request_parts(req: &Request) -> Result<Self> {
266 let query = req.query_string().unwrap_or("");
267 let value: T = serde_urlencoded::from_str(query)
268 .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
269 Ok(Query(value))
270 }
271}
272
273impl<T> Deref for Query<T> {
274 type Target = T;
275
276 fn deref(&self) -> &Self::Target {
277 &self.0
278 }
279}
280
281#[derive(Debug, Clone)]
303pub struct Path<T>(pub T);
304
305impl<T: FromStr> FromRequestParts for Path<T>
306where
307 T::Err: std::fmt::Display,
308{
309 fn from_request_parts(req: &Request) -> Result<Self> {
310 let params = req.path_params();
311
312 if let Some((_, value)) = params.iter().next() {
314 let parsed = value
315 .parse::<T>()
316 .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
317 return Ok(Path(parsed));
318 }
319
320 Err(ApiError::internal("Missing path parameter"))
321 }
322}
323
324impl<T> Deref for Path<T> {
325 type Target = T;
326
327 fn deref(&self) -> &Self::Target {
328 &self.0
329 }
330}
331
332#[derive(Debug, Clone)]
349pub struct State<T>(pub T);
350
351impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
352 fn from_request_parts(req: &Request) -> Result<Self> {
353 req.state().get::<T>().cloned().map(State).ok_or_else(|| {
354 ApiError::internal(format!(
355 "State of type `{}` not found. Did you forget to call .state()?",
356 std::any::type_name::<T>()
357 ))
358 })
359 }
360}
361
362impl<T> Deref for State<T> {
363 type Target = T;
364
365 fn deref(&self) -> &Self::Target {
366 &self.0
367 }
368}
369
370#[derive(Debug, Clone)]
372pub struct Body(pub Bytes);
373
374impl FromRequest for Body {
375 async fn from_request(req: &mut Request) -> Result<Self> {
376 let body = req
377 .take_body()
378 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
379 Ok(Body(body))
380 }
381}
382
383impl Deref for Body {
384 type Target = Bytes;
385
386 fn deref(&self) -> &Self::Target {
387 &self.0
388 }
389}
390
391impl<T: FromRequestParts> FromRequestParts for Option<T> {
395 fn from_request_parts(req: &Request) -> Result<Self> {
396 Ok(T::from_request_parts(req).ok())
397 }
398}
399
400#[derive(Debug, Clone)]
418pub struct Headers(pub http::HeaderMap);
419
420impl Headers {
421 pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
423 self.0.get(name)
424 }
425
426 pub fn contains(&self, name: &str) -> bool {
428 self.0.contains_key(name)
429 }
430
431 pub fn len(&self) -> usize {
433 self.0.len()
434 }
435
436 pub fn is_empty(&self) -> bool {
438 self.0.is_empty()
439 }
440
441 pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
443 self.0.iter()
444 }
445}
446
447impl FromRequestParts for Headers {
448 fn from_request_parts(req: &Request) -> Result<Self> {
449 Ok(Headers(req.headers().clone()))
450 }
451}
452
453impl Deref for Headers {
454 type Target = http::HeaderMap;
455
456 fn deref(&self) -> &Self::Target {
457 &self.0
458 }
459}
460
461#[derive(Debug, Clone)]
480pub struct HeaderValue(pub String, pub &'static str);
481
482impl HeaderValue {
483 pub fn new(name: &'static str, value: String) -> Self {
485 Self(value, name)
486 }
487
488 pub fn value(&self) -> &str {
490 &self.0
491 }
492
493 pub fn name(&self) -> &'static str {
495 self.1
496 }
497
498 pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
500 req.headers()
501 .get(name)
502 .and_then(|v| v.to_str().ok())
503 .map(|s| HeaderValue(s.to_string(), name))
504 .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
505 }
506}
507
508impl Deref for HeaderValue {
509 type Target = String;
510
511 fn deref(&self) -> &Self::Target {
512 &self.0
513 }
514}
515
516#[derive(Debug, Clone)]
534pub struct Extension<T>(pub T);
535
536impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
537 fn from_request_parts(req: &Request) -> Result<Self> {
538 req.extensions()
539 .get::<T>()
540 .cloned()
541 .map(Extension)
542 .ok_or_else(|| {
543 ApiError::internal(format!(
544 "Extension of type `{}` not found. Did middleware insert it?",
545 std::any::type_name::<T>()
546 ))
547 })
548 }
549}
550
551impl<T> Deref for Extension<T> {
552 type Target = T;
553
554 fn deref(&self) -> &Self::Target {
555 &self.0
556 }
557}
558
559impl<T> DerefMut for Extension<T> {
560 fn deref_mut(&mut self) -> &mut Self::Target {
561 &mut self.0
562 }
563}
564
565#[derive(Debug, Clone)]
580pub struct ClientIp(pub std::net::IpAddr);
581
582impl ClientIp {
583 pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
585 if trust_proxy {
586 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
588 if let Ok(forwarded_str) = forwarded.to_str() {
589 if let Some(first_ip) = forwarded_str.split(',').next() {
591 if let Ok(ip) = first_ip.trim().parse() {
592 return Ok(ClientIp(ip));
593 }
594 }
595 }
596 }
597 }
598
599 if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
601 return Ok(ClientIp(addr.ip()));
602 }
603
604 Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
606 127, 0, 0, 1,
607 ))))
608 }
609}
610
611impl FromRequestParts for ClientIp {
612 fn from_request_parts(req: &Request) -> Result<Self> {
613 Self::extract_with_config(req, true)
615 }
616}
617
/// Cookies parsed from the request's `Cookie` header.
///
/// `;`-separated pairs that fail to parse are silently skipped.
#[cfg(feature = "cookies")]
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);

#[cfg(feature = "cookies")]
impl Cookies {
    /// Looks up a cookie by name.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.get(name)
    }

    /// Iterates over every cookie in the jar.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.iter()
    }

    /// Whether a cookie with `name` is present.
    pub fn contains(&self, name: &str) -> bool {
        self.get(name).is_some()
    }
}

#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    /// Builds a jar from the `Cookie` header; never fails. A missing header, or
    /// one that is not valid visible ASCII, yields an empty jar.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let headers = req.headers();
        let header_text = headers
            .get(header::COOKIE)
            .and_then(|value| value.to_str().ok())
            .unwrap_or_default();

        let mut jar = cookie::CookieJar::new();
        let parsed = header_text
            .split(';')
            .map(str::trim)
            .filter(|pair| !pair.is_empty())
            .filter_map(|pair| cookie::Cookie::parse(pair.to_string()).ok());
        for c in parsed {
            jar.add_original(c.into_owned());
        }

        Ok(Cookies(jar))
    }
}

#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    fn deref(&self) -> &cookie::CookieJar {
        let Cookies(jar) = self;
        jar
    }
}
688
689macro_rules! impl_from_request_parts_for_primitives {
691 ($($ty:ty),*) => {
692 $(
693 impl FromRequestParts for $ty {
694 fn from_request_parts(req: &Request) -> Result<Self> {
695 let Path(value) = Path::<$ty>::from_request_parts(req)?;
696 Ok(value)
697 }
698 }
699 )*
700 };
701}
702
703impl_from_request_parts_for_primitives!(
704 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
705);
706
707use rustapi_openapi::utoipa_types::openapi;
710use rustapi_openapi::{
711 IntoParams, MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier,
712 ResponseSpec, Schema, SchemaRef,
713};
714use std::collections::HashMap;
715
716impl<T: for<'a> Schema<'a>> OperationModifier for ValidatedJson<T> {
718 fn update_operation(op: &mut Operation) {
719 let (name, _) = T::schema();
720
721 let schema_ref = SchemaRef::Ref {
722 reference: format!("#/components/schemas/{}", name),
723 };
724
725 let mut content = HashMap::new();
726 content.insert(
727 "application/json".to_string(),
728 MediaType { schema: schema_ref },
729 );
730
731 op.request_body = Some(RequestBody {
732 required: true,
733 content,
734 });
735
736 op.responses.insert(
738 "422".to_string(),
739 ResponseSpec {
740 description: "Validation Error".to_string(),
741 content: {
742 let mut map = HashMap::new();
743 map.insert(
744 "application/json".to_string(),
745 MediaType {
746 schema: SchemaRef::Ref {
747 reference: "#/components/schemas/ValidationErrorSchema".to_string(),
748 },
749 },
750 );
751 Some(map)
752 },
753 },
754 );
755 }
756}
757
758impl<T: for<'a> Schema<'a>> OperationModifier for Json<T> {
760 fn update_operation(op: &mut Operation) {
761 let (name, _) = T::schema();
762
763 let schema_ref = SchemaRef::Ref {
764 reference: format!("#/components/schemas/{}", name),
765 };
766
767 let mut content = HashMap::new();
768 content.insert(
769 "application/json".to_string(),
770 MediaType { schema: schema_ref },
771 );
772
773 op.request_body = Some(RequestBody {
774 required: true,
775 content,
776 });
777 }
778}
779
780impl<T> OperationModifier for Path<T> {
784 fn update_operation(_op: &mut Operation) {
785 }
792}
793
794impl<T: IntoParams> OperationModifier for Query<T> {
796 fn update_operation(op: &mut Operation) {
797 let params = T::into_params(|| Some(openapi::path::ParameterIn::Query));
798
799 let new_params: Vec<Parameter> = params
800 .into_iter()
801 .map(|p| {
802 let schema = match p.schema {
803 Some(schema) => match schema {
804 openapi::RefOr::Ref(r) => SchemaRef::Ref {
805 reference: r.ref_location,
806 },
807 openapi::RefOr::T(s) => {
808 let value = serde_json::to_value(s).unwrap_or(serde_json::Value::Null);
809 SchemaRef::Inline(value)
810 }
811 },
812 None => SchemaRef::Inline(serde_json::Value::Null),
813 };
814
815 let required = match p.required {
816 openapi::Required::True => true,
817 openapi::Required::False => false,
818 };
819
820 Parameter {
821 name: p.name,
822 location: "query".to_string(), required,
824 description: p.description,
825 schema,
826 }
827 })
828 .collect();
829
830 if let Some(existing) = &mut op.parameters {
831 existing.extend(new_params);
832 } else {
833 op.parameters = Some(new_params);
834 }
835 }
836}
837
838impl<T> OperationModifier for State<T> {
840 fn update_operation(_op: &mut Operation) {}
841}
842
843impl OperationModifier for Body {
845 fn update_operation(op: &mut Operation) {
846 let mut content = HashMap::new();
847 content.insert(
848 "application/octet-stream".to_string(),
849 MediaType {
850 schema: SchemaRef::Inline(
851 serde_json::json!({ "type": "string", "format": "binary" }),
852 ),
853 },
854 );
855
856 op.request_body = Some(RequestBody {
857 required: true,
858 content,
859 });
860 }
861}
862
863impl<T: for<'a> Schema<'a>> ResponseModifier for Json<T> {
867 fn update_response(op: &mut Operation) {
868 let (name, _) = T::schema();
869
870 let schema_ref = SchemaRef::Ref {
871 reference: format!("#/components/schemas/{}", name),
872 };
873
874 op.responses.insert(
875 "200".to_string(),
876 ResponseSpec {
877 description: "Successful response".to_string(),
878 content: {
879 let mut map = HashMap::new();
880 map.insert(
881 "application/json".to_string(),
882 MediaType { schema: schema_ref },
883 );
884 Some(map)
885 },
886 },
887 );
888 }
889}
890
// Unit and property tests for the extractors. Test requests are built from
// `http::Request` parts with an empty body, empty state, and no path params.
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::{Extensions, Method};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds a `Request` for `method`/`path` carrying `headers`, with an empty
    /// body, empty state, and no path parameters.
    fn create_test_request_with_headers(
        method: Method,
        path: &str,
        headers: Vec<(&str, &str)>,
    ) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let mut builder = http::Request::builder().method(method).uri(uri);

        for (name, value) in headers {
            builder = builder.header(name, value);
        }

        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        Request::new(
            parts,
            Bytes::new(),
            Arc::new(Extensions::new()),
            HashMap::new(),
        )
    }

    /// Builds a header-less `Request` whose `http` extensions contain `extension`.
    fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
        method: Method,
        path: &str,
        extension: T,
    ) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        let req = builder.body(()).unwrap();
        let (mut parts, _) = req.into_parts();
        parts.extensions.insert(extension);

        Request::new(
            parts,
            Bytes::new(),
            Arc::new(Extensions::new()),
            HashMap::new(),
        )
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Every header placed on the request must be visible through `Headers`,
        // including repeated names (checked via `get_all`).
        #[test]
        fn prop_headers_extractor_completeness(
            headers in prop::collection::vec(
                ("[a-z][a-z0-9-]{0,20}", "[a-zA-Z0-9 ]{1,50}"),
                0..10
            )
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let header_tuples: Vec<(&str, &str)> = headers
                    .iter()
                    .map(|(k, v)| (k.as_str(), v.as_str()))
                    .collect();

                let request =
                    create_test_request_with_headers(Method::GET, "/test", header_tuples);

                let extracted = Headers::from_request_parts(&request)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;

                for (name, value) in &headers {
                    let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
                    prop_assert!(!all_values.is_empty(), "Header '{}' not found", name);

                    let value_found = all_values
                        .iter()
                        .any(|v| v.to_str().map(|s| s == value.as_str()).unwrap_or(false));
                    prop_assert!(
                        value_found,
                        "Header '{}' value '{}' not found in extracted values",
                        name,
                        value
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // `HeaderValue::extract` returns the exact value when the named header
        // is present, and fails when only an unrelated header is present.
        #[test]
        fn prop_header_value_extractor_correctness(
            header_value in "[a-zA-Z0-9 ]{1,50}",
            has_header in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // `extract` needs a `&'static str` name, so the target name is
                // fixed. The "absent" case still carries a header, under another
                // name, to prove that extraction is keyed on the name.
                let test_header = "x-test-header";
                let present_name = if has_header { test_header } else { "x-other-header" };
                let request = create_test_request_with_headers(
                    Method::GET,
                    "/test",
                    vec![(present_name, header_value.as_str())],
                );

                let result = HeaderValue::extract(&request, test_header);

                if has_header {
                    let extracted = result.map_err(|e| {
                        TestCaseError::fail(format!("Expected header to be found: {}", e))
                    })?;
                    prop_assert_eq!(
                        extracted.value(),
                        header_value.as_str(),
                        "Header value mismatch"
                    );
                } else {
                    prop_assert!(result.is_err(), "Expected error when header is missing");
                }

                Ok(())
            })();
            result?;
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // With `trust_proxy` and an `X-Forwarded-For` header, the forwarded IP
        // wins; otherwise the socket address from the extensions is used.
        #[test]
        fn prop_client_ip_extractor_with_forwarding(
            forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
            socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
            has_forwarded_header in prop::bool::ANY,
            trust_proxy in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let headers = if has_forwarded_header {
                    vec![("x-forwarded-for", forwarded_ip.as_str())]
                } else {
                    vec![]
                };

                let uri: http::Uri = "/test".parse().unwrap();
                let mut builder = http::Request::builder().method(Method::GET).uri(uri);
                for (name, value) in &headers {
                    builder = builder.header(*name, *value);
                }
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
                parts.extensions.insert(socket_addr);

                let request = Request::new(
                    parts,
                    Bytes::new(),
                    Arc::new(Extensions::new()),
                    HashMap::new(),
                );

                let extracted = ClientIp::extract_with_config(&request, trust_proxy)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;

                if trust_proxy && has_forwarded_header {
                    let expected_ip: std::net::IpAddr = forwarded_ip
                        .parse()
                        .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        expected_ip,
                        "Should use X-Forwarded-For IP when trust_proxy is enabled"
                    );
                } else {
                    prop_assert_eq!(
                        extracted.0,
                        socket_ip,
                        "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // `Extension<T>` returns the inserted value, and errors when absent.
        #[test]
        fn prop_extension_extractor_retrieval(
            value in any::<i64>(),
            has_extension in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                #[derive(Clone, Debug, PartialEq)]
                struct TestExtension(i64);

                let uri: http::Uri = "/test".parse().unwrap();
                let builder = http::Request::builder().method(Method::GET).uri(uri);
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                if has_extension {
                    parts.extensions.insert(TestExtension(value));
                }

                let request = Request::new(
                    parts,
                    Bytes::new(),
                    Arc::new(Extensions::new()),
                    HashMap::new(),
                );

                let result = Extension::<TestExtension>::from_request_parts(&request);

                if has_extension {
                    let extracted = result.map_err(|e| {
                        TestCaseError::fail(format!("Expected extension to be found: {}", e))
                    })?;
                    prop_assert_eq!(extracted.0, TestExtension(value), "Extension value mismatch");
                } else {
                    prop_assert!(result.is_err(), "Expected error when extension is missing");
                }

                Ok(())
            })();
            result?;
        }
    }

    #[test]
    fn test_headers_extractor_basic() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("content-type", "application/json"), ("accept", "text/html")],
        );

        let headers = Headers::from_request_parts(&request).unwrap();

        assert!(headers.contains("content-type"));
        assert!(headers.contains("accept"));
        assert!(!headers.contains("x-custom"));
        assert_eq!(headers.len(), 2);
    }

    #[test]
    fn test_header_value_extractor_present() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("authorization", "Bearer token123")],
        );

        let result = HeaderValue::extract(&request, "authorization");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value(), "Bearer token123");
    }

    #[test]
    fn test_header_value_extractor_missing() {
        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = HeaderValue::extract(&request, "authorization");
        assert!(result.is_err());
    }

    #[test]
    fn test_client_ip_from_forwarded_header() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
        );

        let ip = ClientIp::extract_with_config(&request, true).unwrap();
        assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_client_ip_ignores_forwarded_when_not_trusted() {
        let uri: http::Uri = "/test".parse().unwrap();
        let builder = http::Request::builder()
            .method(Method::GET)
            .uri(uri)
            .header("x-forwarded-for", "192.168.1.100");
        let req = builder.body(()).unwrap();
        let (mut parts, _) = req.into_parts();

        let socket_addr = std::net::SocketAddr::new(
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
            8080,
        );
        parts.extensions.insert(socket_addr);

        let request = Request::new(
            parts,
            Bytes::new(),
            Arc::new(Extensions::new()),
            HashMap::new(),
        );

        let ip = ClientIp::extract_with_config(&request, false).unwrap();
        assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_extension_extractor_present() {
        #[derive(Clone, Debug, PartialEq)]
        struct MyData(String);

        let request =
            create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));

        let result = Extension::<MyData>::from_request_parts(&request);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().0, MyData("hello".to_string()));
    }

    #[test]
    fn test_extension_extractor_missing() {
        #[derive(Clone, Debug)]
        #[allow(dead_code)] // the field is never read; only the type is used
        struct MyData(String);

        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = Extension::<MyData>::from_request_parts(&request);
        assert!(result.is_err());
    }

    #[cfg(feature = "cookies")]
    mod cookies_tests {
        use super::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            // Every `name=value` pair in the `Cookie` header is extracted. For
            // a repeated name, the last value wins.
            #[test]
            fn prop_cookies_extractor_parsing(
                cookies in prop::collection::vec(
                    ("[a-zA-Z][a-zA-Z0-9_]{0,15}", "[a-zA-Z0-9]{1,30}"),
                    0..5
                )
            ) {
                let result: Result<(), TestCaseError> = (|| {
                    let cookie_header = cookies
                        .iter()
                        .map(|(name, value)| format!("{}={}", name, value))
                        .collect::<Vec<_>>()
                        .join("; ");

                    let headers = if !cookies.is_empty() {
                        vec![("cookie", cookie_header.as_str())]
                    } else {
                        vec![]
                    };

                    let request = create_test_request_with_headers(Method::GET, "/test", headers);

                    let extracted = Cookies::from_request_parts(&request)
                        .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;

                    let mut expected_cookies: std::collections::HashMap<&str, &str> =
                        std::collections::HashMap::new();
                    for (name, value) in &cookies {
                        expected_cookies.insert(name.as_str(), value.as_str());
                    }

                    for (name, expected_value) in &expected_cookies {
                        let cookie = extracted
                            .get(name)
                            .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;

                        prop_assert_eq!(
                            cookie.value(),
                            *expected_value,
                            "Cookie '{}' value mismatch",
                            name
                        );
                    }

                    let extracted_count = extracted.iter().count();
                    prop_assert_eq!(
                        extracted_count,
                        expected_cookies.len(),
                        "Expected {} unique cookies, got {}",
                        expected_cookies.len(),
                        extracted_count
                    );

                    Ok(())
                })();
                result?;
            }
        }

        #[test]
        fn test_cookies_extractor_basic() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "session=abc123; user=john")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();

            assert!(cookies.contains("session"));
            assert!(cookies.contains("user"));
            assert!(!cookies.contains("other"));

            assert_eq!(cookies.get("session").unwrap().value(), "abc123");
            assert_eq!(cookies.get("user").unwrap().value(), "john");
        }

        #[test]
        fn test_cookies_extractor_empty() {
            let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 0);
        }

        #[test]
        fn test_cookies_extractor_single() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "token=xyz789")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 1);
            assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
        }
    }
}