1use crate::error::{ApiError, Result};
58use crate::request::Request;
59use crate::response::IntoResponse;
60use bytes::Bytes;
61use http::{header, StatusCode};
62use http_body_util::Full;
63use serde::de::DeserializeOwned;
64use serde::Serialize;
65use std::future::Future;
66use std::ops::{Deref, DerefMut};
67use std::str::FromStr;
68
69pub trait FromRequestParts: Sized {
73 fn from_request_parts(req: &Request) -> Result<Self>;
75}
76
77pub trait FromRequest: Sized {
81 fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
83}
84
85impl<T: FromRequestParts> FromRequest for T {
87 async fn from_request(req: &mut Request) -> Result<Self> {
88 T::from_request_parts(req)
89 }
90}
91
/// JSON extractor and response wrapper.
///
/// As an extractor it takes the buffered request body and deserializes it
/// with `serde_json`. As a response it serializes the wrapped value and
/// replies `200 OK` with `Content-Type: application/json`.
#[derive(Copy, Clone, Debug, Default)]
pub struct Json<T>(pub T);
112
113impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
114 async fn from_request(req: &mut Request) -> Result<Self> {
115 let body = req
116 .take_body()
117 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
118
119 let value: T = serde_json::from_slice(&body)?;
120 Ok(Json(value))
121 }
122}
123
124impl<T> Deref for Json<T> {
125 type Target = T;
126
127 fn deref(&self) -> &Self::Target {
128 &self.0
129 }
130}
131
132impl<T> DerefMut for Json<T> {
133 fn deref_mut(&mut self) -> &mut Self::Target {
134 &mut self.0
135 }
136}
137
138impl<T> From<T> for Json<T> {
139 fn from(value: T) -> Self {
140 Json(value)
141 }
142}
143
144impl<T: Serialize> IntoResponse for Json<T> {
146 fn into_response(self) -> crate::response::Response {
147 match serde_json::to_vec(&self.0) {
148 Ok(body) => http::Response::builder()
149 .status(StatusCode::OK)
150 .header(header::CONTENT_TYPE, "application/json")
151 .body(Full::new(Bytes::from(body)))
152 .unwrap(),
153 Err(err) => {
154 ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
155 }
156 }
157 }
158}
159
/// JSON extractor that also runs `rustapi_validate::Validate` on the decoded value.
///
/// When returned from a handler it responds exactly like `Json` does.
#[derive(Copy, Clone, Debug, Default)]
pub struct ValidatedJson<T>(pub T);
187
188impl<T> ValidatedJson<T> {
189 pub fn new(value: T) -> Self {
191 Self(value)
192 }
193
194 pub fn into_inner(self) -> T {
196 self.0
197 }
198}
199
200impl<T: DeserializeOwned + rustapi_validate::Validate + Send> FromRequest for ValidatedJson<T> {
201 async fn from_request(req: &mut Request) -> Result<Self> {
202 let body = req
204 .take_body()
205 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
206
207 let value: T = serde_json::from_slice(&body)?;
208
209 if let Err(validation_error) = rustapi_validate::Validate::validate(&value) {
211 return Err(validation_error.into());
213 }
214
215 Ok(ValidatedJson(value))
216 }
217}
218
219impl<T> Deref for ValidatedJson<T> {
220 type Target = T;
221
222 fn deref(&self) -> &Self::Target {
223 &self.0
224 }
225}
226
227impl<T> DerefMut for ValidatedJson<T> {
228 fn deref_mut(&mut self) -> &mut Self::Target {
229 &mut self.0
230 }
231}
232
233impl<T> From<T> for ValidatedJson<T> {
234 fn from(value: T) -> Self {
235 ValidatedJson(value)
236 }
237}
238
239impl<T: Serialize> IntoResponse for ValidatedJson<T> {
240 fn into_response(self) -> crate::response::Response {
241 Json(self.0).into_response()
242 }
243}
244
/// Extractor that deserializes the URL query string into `T` via `serde_urlencoded`.
#[derive(Clone, Debug)]
pub struct Query<T>(pub T);
264
265impl<T: DeserializeOwned> FromRequestParts for Query<T> {
266 fn from_request_parts(req: &Request) -> Result<Self> {
267 let query = req.query_string().unwrap_or("");
268 let value: T = serde_urlencoded::from_str(query)
269 .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
270 Ok(Query(value))
271 }
272}
273
274impl<T> Deref for Query<T> {
275 type Target = T;
276
277 fn deref(&self) -> &Self::Target {
278 &self.0
279 }
280}
281
/// Extractor that parses a route path parameter into `T` through `FromStr`.
#[derive(Clone, Debug)]
pub struct Path<T>(pub T);
305
306impl<T: FromStr> FromRequestParts for Path<T>
307where
308 T::Err: std::fmt::Display,
309{
310 fn from_request_parts(req: &Request) -> Result<Self> {
311 let params = req.path_params();
312
313 if let Some((_, value)) = params.iter().next() {
315 let parsed = value
316 .parse::<T>()
317 .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
318 return Ok(Path(parsed));
319 }
320
321 Err(ApiError::internal("Missing path parameter"))
322 }
323}
324
325impl<T> Deref for Path<T> {
326 type Target = T;
327
328 fn deref(&self) -> &Self::Target {
329 &self.0
330 }
331}
332
/// Extractor that clones a `T` registered as application state.
#[derive(Clone, Debug)]
pub struct State<T>(pub T);
351
352impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
353 fn from_request_parts(req: &Request) -> Result<Self> {
354 req.state().get::<T>().cloned().map(State).ok_or_else(|| {
355 ApiError::internal(format!(
356 "State of type `{}` not found. Did you forget to call .state()?",
357 std::any::type_name::<T>()
358 ))
359 })
360 }
361}
362
363impl<T> Deref for State<T> {
364 type Target = T;
365
366 fn deref(&self) -> &Self::Target {
367 &self.0
368 }
369}
370
371#[derive(Debug, Clone)]
373pub struct Body(pub Bytes);
374
375impl FromRequest for Body {
376 async fn from_request(req: &mut Request) -> Result<Self> {
377 let body = req
378 .take_body()
379 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
380 Ok(Body(body))
381 }
382}
383
384impl Deref for Body {
385 type Target = Bytes;
386
387 fn deref(&self) -> &Self::Target {
388 &self.0
389 }
390}
391
392impl<T: FromRequestParts> FromRequestParts for Option<T> {
396 fn from_request_parts(req: &Request) -> Result<Self> {
397 Ok(T::from_request_parts(req).ok())
398 }
399}
400
401#[derive(Debug, Clone)]
419pub struct Headers(pub http::HeaderMap);
420
421impl Headers {
422 pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
424 self.0.get(name)
425 }
426
427 pub fn contains(&self, name: &str) -> bool {
429 self.0.contains_key(name)
430 }
431
432 pub fn len(&self) -> usize {
434 self.0.len()
435 }
436
437 pub fn is_empty(&self) -> bool {
439 self.0.is_empty()
440 }
441
442 pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
444 self.0.iter()
445 }
446}
447
448impl FromRequestParts for Headers {
449 fn from_request_parts(req: &Request) -> Result<Self> {
450 Ok(Headers(req.headers().clone()))
451 }
452}
453
454impl Deref for Headers {
455 type Target = http::HeaderMap;
456
457 fn deref(&self) -> &Self::Target {
458 &self.0
459 }
460}
461
/// The string value of a single header, paired with the header name it came from.
///
/// Field `0` is the value and field `1` is the static header name.
#[derive(Clone, Debug)]
pub struct HeaderValue(pub String, pub &'static str);
482
483impl HeaderValue {
484 pub fn new(name: &'static str, value: String) -> Self {
486 Self(value, name)
487 }
488
489 pub fn value(&self) -> &str {
491 &self.0
492 }
493
494 pub fn name(&self) -> &'static str {
496 self.1
497 }
498
499 pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
501 req.headers()
502 .get(name)
503 .and_then(|v| v.to_str().ok())
504 .map(|s| HeaderValue(s.to_string(), name))
505 .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
506 }
507}
508
509impl Deref for HeaderValue {
510 type Target = String;
511
512 fn deref(&self) -> &Self::Target {
513 &self.0
514 }
515}
516
/// Extractor that clones a `T` from the request's extensions, for example a
/// value inserted by middleware.
#[derive(Clone, Debug)]
pub struct Extension<T>(pub T);
536
537impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
538 fn from_request_parts(req: &Request) -> Result<Self> {
539 req.extensions()
540 .get::<T>()
541 .cloned()
542 .map(Extension)
543 .ok_or_else(|| {
544 ApiError::internal(format!(
545 "Extension of type `{}` not found. Did middleware insert it?",
546 std::any::type_name::<T>()
547 ))
548 })
549 }
550}
551
552impl<T> Deref for Extension<T> {
553 type Target = T;
554
555 fn deref(&self) -> &Self::Target {
556 &self.0
557 }
558}
559
560impl<T> DerefMut for Extension<T> {
561 fn deref_mut(&mut self) -> &mut Self::Target {
562 &mut self.0
563 }
564}
565
/// The client's IP address, resolved from proxy headers or the peer socket address.
#[derive(Clone, Debug)]
pub struct ClientIp(pub std::net::IpAddr);
582
583impl ClientIp {
584 pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
586 if trust_proxy {
587 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
589 if let Ok(forwarded_str) = forwarded.to_str() {
590 if let Some(first_ip) = forwarded_str.split(',').next() {
592 if let Ok(ip) = first_ip.trim().parse() {
593 return Ok(ClientIp(ip));
594 }
595 }
596 }
597 }
598 }
599
600 if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
602 return Ok(ClientIp(addr.ip()));
603 }
604
605 Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
607 127, 0, 0, 1,
608 ))))
609 }
610}
611
612impl FromRequestParts for ClientIp {
613 fn from_request_parts(req: &Request) -> Result<Self> {
614 Self::extract_with_config(req, true)
616 }
617}
618
/// Extractor exposing the request's cookies as a [`cookie::CookieJar`].
#[cfg(feature = "cookies")]
#[derive(Clone, Debug)]
pub struct Cookies(pub cookie::CookieJar);
639
#[cfg(feature = "cookies")]
impl Cookies {
    /// Looks up a cookie by name.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.get(name)
    }

    /// Iterates over every cookie in the jar.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.iter()
    }

    /// Reports whether a cookie named `name` is present.
    pub fn contains(&self, name: &str) -> bool {
        self.get(name).is_some()
    }
}
657
#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    /// Parses every `Cookie` header on the request into a [`cookie::CookieJar`].
    ///
    /// HTTP/2 clients may split cookies across several `cookie` header fields,
    /// so all occurrences are read rather than only the first. The following
    /// are skipped silently:
    /// * header values that cannot be viewed as `&str`,
    /// * empty `;`-separated segments,
    /// * segments that `cookie::Cookie::parse` rejects.
    ///
    /// Later segments are added after earlier ones, so a repeated cookie name
    /// is resolved by `CookieJar::add_original`. Never fails.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();

        let header_values = req
            .headers()
            .get_all(header::COOKIE)
            .iter()
            .filter_map(|value| value.to_str().ok());

        for cookie_str in header_values {
            for cookie_part in cookie_str.split(';') {
                let trimmed = cookie_part.trim();
                if trimmed.is_empty() {
                    continue;
                }
                if let Ok(cookie) = cookie::Cookie::parse(trimmed.to_string()) {
                    jar.add_original(cookie.into_owned());
                }
            }
        }

        Ok(Cookies(jar))
    }
}
680
#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    /// Exposes the full `CookieJar` API on the wrapper.
    fn deref(&self) -> &cookie::CookieJar {
        let Cookies(jar) = self;
        jar
    }
}
689
690macro_rules! impl_from_request_parts_for_primitives {
692 ($($ty:ty),*) => {
693 $(
694 impl FromRequestParts for $ty {
695 fn from_request_parts(req: &Request) -> Result<Self> {
696 let Path(value) = Path::<$ty>::from_request_parts(req)?;
697 Ok(value)
698 }
699 }
700 )*
701 };
702}
703
704impl_from_request_parts_for_primitives!(
705 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
706);
707
708use rustapi_openapi::utoipa_types::openapi;
711use rustapi_openapi::{
712 IntoParams, MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier,
713 ResponseSpec, Schema, SchemaRef,
714};
715use std::collections::HashMap;
716
717impl<T: for<'a> Schema<'a>> OperationModifier for ValidatedJson<T> {
719 fn update_operation(op: &mut Operation) {
720 let (name, _) = T::schema();
721
722 let schema_ref = SchemaRef::Ref {
723 reference: format!("#/components/schemas/{}", name),
724 };
725
726 let mut content = HashMap::new();
727 content.insert(
728 "application/json".to_string(),
729 MediaType { schema: schema_ref },
730 );
731
732 op.request_body = Some(RequestBody {
733 required: true,
734 content,
735 });
736
737 op.responses.insert(
739 "422".to_string(),
740 ResponseSpec {
741 description: "Validation Error".to_string(),
742 content: {
743 let mut map = HashMap::new();
744 map.insert(
745 "application/json".to_string(),
746 MediaType {
747 schema: SchemaRef::Ref {
748 reference: "#/components/schemas/ValidationErrorSchema".to_string(),
749 },
750 },
751 );
752 Some(map)
753 },
754 },
755 );
756 }
757}
758
759impl<T: for<'a> Schema<'a>> OperationModifier for Json<T> {
761 fn update_operation(op: &mut Operation) {
762 let (name, _) = T::schema();
763
764 let schema_ref = SchemaRef::Ref {
765 reference: format!("#/components/schemas/{}", name),
766 };
767
768 let mut content = HashMap::new();
769 content.insert(
770 "application/json".to_string(),
771 MediaType { schema: schema_ref },
772 );
773
774 op.request_body = Some(RequestBody {
775 required: true,
776 content,
777 });
778 }
779}
780
781impl<T> OperationModifier for Path<T> {
783 fn update_operation(_op: &mut Operation) {
784 }
786}
787
788impl<T: IntoParams> OperationModifier for Query<T> {
790 fn update_operation(op: &mut Operation) {
791 let params = T::into_params(|| Some(openapi::path::ParameterIn::Query));
792
793 let new_params: Vec<Parameter> = params
794 .into_iter()
795 .map(|p| {
796 let schema = match p.schema {
797 Some(schema) => match schema {
798 openapi::RefOr::Ref(r) => SchemaRef::Ref {
799 reference: r.ref_location,
800 },
801 openapi::RefOr::T(s) => {
802 let value = serde_json::to_value(s).unwrap_or(serde_json::Value::Null);
803 SchemaRef::Inline(value)
804 }
805 },
806 None => SchemaRef::Inline(serde_json::Value::Null),
807 };
808
809 let required = match p.required {
810 openapi::Required::True => true,
811 openapi::Required::False => false,
812 };
813
814 Parameter {
815 name: p.name,
816 location: "query".to_string(), required,
818 description: p.description,
819 schema,
820 }
821 })
822 .collect();
823
824 if let Some(existing) = &mut op.parameters {
825 existing.extend(new_params);
826 } else {
827 op.parameters = Some(new_params);
828 }
829 }
830}
831
832impl<T> OperationModifier for State<T> {
834 fn update_operation(_op: &mut Operation) {}
835}
836
837impl OperationModifier for Body {
839 fn update_operation(op: &mut Operation) {
840 let mut content = HashMap::new();
841 content.insert(
842 "application/octet-stream".to_string(),
843 MediaType {
844 schema: SchemaRef::Inline(
845 serde_json::json!({ "type": "string", "format": "binary" }),
846 ),
847 },
848 );
849
850 op.request_body = Some(RequestBody {
851 required: true,
852 content,
853 });
854 }
855}
856
857impl<T: for<'a> Schema<'a>> ResponseModifier for Json<T> {
861 fn update_response(op: &mut Operation) {
862 let (name, _) = T::schema();
863
864 let schema_ref = SchemaRef::Ref {
865 reference: format!("#/components/schemas/{}", name),
866 };
867
868 op.responses.insert(
869 "200".to_string(),
870 ResponseSpec {
871 description: "Successful response".to_string(),
872 content: {
873 let mut map = HashMap::new();
874 map.insert(
875 "application/json".to_string(),
876 MediaType { schema: schema_ref },
877 );
878 Some(map)
879 },
880 },
881 );
882 }
883}
884
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::{Extensions, Method};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds a body-less `Request` for `method` and `path`, appending `headers`
    /// in order so that repeated names keep every value.
    ///
    /// The remaining `Request::new` arguments are an empty shared `Extensions`
    /// and an empty map. These are presumably "no app state" and "no path
    /// parameters"; confirm against `Request::new`.
    fn create_test_request_with_headers(
        method: Method,
        path: &str,
        headers: Vec<(&str, &str)>,
    ) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let mut builder = http::Request::builder().method(method).uri(uri);

        for (name, value) in headers {
            builder = builder.header(name, value);
        }

        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        Request::new(
            parts,
            Bytes::new(),
            Arc::new(Extensions::new()),
            HashMap::new(),
        )
    }

    /// Builds a header-less `Request` with `extension` inserted into the
    /// request parts' extensions.
    fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
        method: Method,
        path: &str,
        extension: T,
    ) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        let req = builder.body(()).unwrap();
        let (mut parts, _) = req.into_parts();
        parts.extensions.insert(extension);

        Request::new(
            parts,
            Bytes::new(),
            Arc::new(Extensions::new()),
            HashMap::new(),
        )
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Every header sent on the request is visible through `Headers`,
        // including each value of a name that was sent more than once.
        #[test]
        fn prop_headers_extractor_completeness(
            headers in prop::collection::vec(
                ("[a-z][a-z0-9-]{0,20}", "[a-zA-Z0-9 ]{1,50}"),
                0..10
            )
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let header_tuples: Vec<(&str, &str)> = headers
                    .iter()
                    .map(|(k, v)| (k.as_str(), v.as_str()))
                    .collect();

                let request = create_test_request_with_headers(
                    Method::GET,
                    "/test",
                    header_tuples.clone(),
                );

                let extracted = Headers::from_request_parts(&request)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;

                for (name, value) in &headers {
                    let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
                    prop_assert!(
                        !all_values.is_empty(),
                        "Header '{}' not found",
                        name
                    );

                    let value_found = all_values.iter().any(|v| {
                        v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
                    });

                    prop_assert!(
                        value_found,
                        "Header '{}' value '{}' not found in extracted values",
                        name,
                        value
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // `HeaderValue::extract` returns exactly the requested header's value
        // when it is present and errors when it is absent. It is not confused
        // by an unrelated header with an arbitrary generated name.
        #[test]
        fn prop_header_value_extractor_correctness(
            header_name in "[a-z][a-z0-9-]{0,20}",
            header_value in "[a-zA-Z0-9 ]{1,50}",
            has_header in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // `extract` needs a `&'static str` name, so the looked-up header
                // is fixed and the generated name serves as a distractor.
                let test_header = "x-test-header";
                prop_assume!(header_name != test_header);

                let mut headers = vec![(header_name.as_str(), "distractor")];
                if has_header {
                    headers.push((test_header, header_value.as_str()));
                }

                let request = create_test_request_with_headers(Method::GET, "/test", headers);

                let result = HeaderValue::extract(&request, test_header);

                if has_header {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.value(),
                        header_value.as_str(),
                        "Header value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when header is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // `X-Forwarded-For` wins only when proxies are trusted; otherwise the
        // socket address from the extensions is used.
        #[test]
        fn prop_client_ip_extractor_with_forwarding(
            forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
            socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
            has_forwarded_header in prop::bool::ANY,
            trust_proxy in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let headers = if has_forwarded_header {
                    vec![("x-forwarded-for", forwarded_ip.as_str())]
                } else {
                    vec![]
                };

                let uri: http::Uri = "/test".parse().unwrap();
                let mut builder = http::Request::builder().method(Method::GET).uri(uri);
                for (name, value) in &headers {
                    builder = builder.header(*name, *value);
                }
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
                parts.extensions.insert(socket_addr);

                let request = Request::new(
                    parts,
                    Bytes::new(),
                    Arc::new(Extensions::new()),
                    HashMap::new(),
                );

                let extracted = ClientIp::extract_with_config(&request, trust_proxy)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;

                if trust_proxy && has_forwarded_header {
                    let expected_ip: std::net::IpAddr = forwarded_ip.parse()
                        .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        expected_ip,
                        "Should use X-Forwarded-For IP when trust_proxy is enabled"
                    );
                } else {
                    prop_assert_eq!(
                        extracted.0,
                        socket_ip,
                        "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // `Extension<T>` returns the inserted value and errors when it is absent.
        #[test]
        fn prop_extension_extractor_retrieval(
            value in any::<i64>(),
            has_extension in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                #[derive(Clone, Debug, PartialEq)]
                struct TestExtension(i64);

                let uri: http::Uri = "/test".parse().unwrap();
                let builder = http::Request::builder().method(Method::GET).uri(uri);
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                if has_extension {
                    parts.extensions.insert(TestExtension(value));
                }

                let request = Request::new(
                    parts,
                    Bytes::new(),
                    Arc::new(Extensions::new()),
                    HashMap::new(),
                );

                let result = Extension::<TestExtension>::from_request_parts(&request);

                if has_extension {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        TestExtension(value),
                        "Extension value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when extension is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    #[test]
    fn test_headers_extractor_basic() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![
                ("content-type", "application/json"),
                ("accept", "text/html"),
            ],
        );

        let headers = Headers::from_request_parts(&request).unwrap();

        assert!(headers.contains("content-type"));
        assert!(headers.contains("accept"));
        assert!(!headers.contains("x-custom"));
        assert_eq!(headers.len(), 2);
    }

    #[test]
    fn test_header_value_extractor_present() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("authorization", "Bearer token123")],
        );

        let result = HeaderValue::extract(&request, "authorization");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value(), "Bearer token123");
    }

    #[test]
    fn test_header_value_extractor_missing() {
        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = HeaderValue::extract(&request, "authorization");
        assert!(result.is_err());
    }

    #[test]
    fn test_client_ip_from_forwarded_header() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
        );

        let ip = ClientIp::extract_with_config(&request, true).unwrap();
        assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_client_ip_ignores_forwarded_when_not_trusted() {
        let uri: http::Uri = "/test".parse().unwrap();
        let builder = http::Request::builder()
            .method(Method::GET)
            .uri(uri)
            .header("x-forwarded-for", "192.168.1.100");
        let req = builder.body(()).unwrap();
        let (mut parts, _) = req.into_parts();

        let socket_addr = std::net::SocketAddr::new(
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
            8080,
        );
        parts.extensions.insert(socket_addr);

        let request = Request::new(
            parts,
            Bytes::new(),
            Arc::new(Extensions::new()),
            HashMap::new(),
        );

        let ip = ClientIp::extract_with_config(&request, false).unwrap();
        assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_extension_extractor_present() {
        #[derive(Clone, Debug, PartialEq)]
        struct MyData(String);

        let request =
            create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));

        let result = Extension::<MyData>::from_request_parts(&request);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().0, MyData("hello".to_string()));
    }

    #[test]
    fn test_extension_extractor_missing() {
        #[derive(Clone, Debug)]
        struct MyData(String);

        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = Extension::<MyData>::from_request_parts(&request);
        assert!(result.is_err());
    }

    #[cfg(feature = "cookies")]
    mod cookies_tests {
        use super::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            // Every `name=value` pair sent in the Cookie header is extracted.
            // When a name repeats, the last value is the one expected.
            #[test]
            fn prop_cookies_extractor_parsing(
                cookies in prop::collection::vec(
                    ("[a-zA-Z][a-zA-Z0-9_]{0,15}", "[a-zA-Z0-9]{1,30}"),
                    0..5
                )
            ) {
                let result: Result<(), TestCaseError> = (|| {
                    let cookie_header = cookies
                        .iter()
                        .map(|(name, value)| format!("{}={}", name, value))
                        .collect::<Vec<_>>()
                        .join("; ");

                    let headers = if !cookies.is_empty() {
                        vec![("cookie", cookie_header.as_str())]
                    } else {
                        vec![]
                    };

                    let request = create_test_request_with_headers(Method::GET, "/test", headers);

                    let extracted = Cookies::from_request_parts(&request)
                        .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;

                    let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
                    for (name, value) in &cookies {
                        expected_cookies.insert(name.as_str(), value.as_str());
                    }

                    for (name, expected_value) in &expected_cookies {
                        let cookie = extracted.get(*name)
                            .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;

                        prop_assert_eq!(
                            cookie.value(),
                            *expected_value,
                            "Cookie '{}' value mismatch",
                            name
                        );
                    }

                    let extracted_count = extracted.iter().count();
                    prop_assert_eq!(
                        extracted_count,
                        expected_cookies.len(),
                        "Expected {} unique cookies, got {}",
                        expected_cookies.len(),
                        extracted_count
                    );

                    Ok(())
                })();
                result?;
            }
        }

        #[test]
        fn test_cookies_extractor_basic() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "session=abc123; user=john")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();

            assert!(cookies.contains("session"));
            assert!(cookies.contains("user"));
            assert!(!cookies.contains("other"));

            assert_eq!(cookies.get("session").unwrap().value(), "abc123");
            assert_eq!(cookies.get("user").unwrap().value(), "john");
        }

        #[test]
        fn test_cookies_extractor_empty() {
            let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 0);
        }

        #[test]
        fn test_cookies_extractor_single() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "token=xyz789")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 1);
            assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
        }
    }
}