1use crate::error::{ApiError, Result};
58use crate::json;
59use crate::request::Request;
60use crate::response::IntoResponse;
61use crate::stream::{StreamingBody, StreamingConfig};
62use bytes::Bytes;
63use http::{header, StatusCode};
64use http_body_util::Full;
65use serde::de::DeserializeOwned;
66use serde::Serialize;
67use std::future::Future;
68use std::ops::{Deref, DerefMut};
69use std::str::FromStr;
70
/// Extractor that is built from a shared borrow of the request.
///
/// Because it only receives `&Request`, an implementor is meant for data that does not
/// require taking the body (headers, URI/query, path parameters, extensions, state).
/// Every implementor is also a [`FromRequest`] through the blanket impl in this module.
pub trait FromRequestParts: Sized {
    /// Builds `Self` from the request.
    ///
    /// # Errors
    /// Returns an [`ApiError`] when the value cannot be extracted; each implementation
    /// chooses the error kind (bad request, internal, ...).
    fn from_request_parts(req: &Request) -> Result<Self>;
}
78
/// Extractor that gets exclusive, asynchronous access to the request.
///
/// Used by extractors that take the body, such as [`Json`], [`Body`] and [`BodyStream`].
pub trait FromRequest: Sized {
    /// Asynchronously builds `Self` from the request; the returned future is `Send`.
    ///
    /// # Errors
    /// Returns an [`ApiError`] when extraction fails.
    fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
}
86
87impl<T: FromRequestParts> FromRequest for T {
89 async fn from_request(req: &mut Request) -> Result<Self> {
90 T::from_request_parts(req)
91 }
92}
93
/// JSON body extractor and JSON response type.
///
/// As an extractor it buffers the request body and deserializes it into `T` with the
/// crate's `json` helpers. As a response it serializes `T` and replies `200 OK` with
/// `Content-Type: application/json`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);
114
115impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
116 async fn from_request(req: &mut Request) -> Result<Self> {
117 req.load_body().await?;
118 let body = req
119 .take_body()
120 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
121
122 let value: T = json::from_slice(&body)?;
124 Ok(Json(value))
125 }
126}
127
128impl<T> Deref for Json<T> {
129 type Target = T;
130
131 fn deref(&self) -> &Self::Target {
132 &self.0
133 }
134}
135
136impl<T> DerefMut for Json<T> {
137 fn deref_mut(&mut self) -> &mut Self::Target {
138 &mut self.0
139 }
140}
141
142impl<T> From<T> for Json<T> {
143 fn from(value: T) -> Self {
144 Json(value)
145 }
146}
147
/// Initial buffer capacity (in bytes) passed to `json::to_vec_with_capacity` when a
/// [`Json`] value is serialized into a response body.
const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256;
151
152impl<T: Serialize> IntoResponse for Json<T> {
154 fn into_response(self) -> crate::response::Response {
155 match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
157 Ok(body) => http::Response::builder()
158 .status(StatusCode::OK)
159 .header(header::CONTENT_TYPE, "application/json")
160 .body(Full::new(Bytes::from(body)))
161 .unwrap(),
162 Err(err) => {
163 ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
164 }
165 }
166 }
167}
168
/// JSON body extractor that also validates the payload.
///
/// Behaves like [`Json`], then runs `rustapi_validate::Validate::validate` on the
/// deserialized value; a validation failure is converted into an [`ApiError`]. Its
/// OpenAPI description additionally advertises a `422` validation-error response.
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);
196
197impl<T> ValidatedJson<T> {
198 pub fn new(value: T) -> Self {
200 Self(value)
201 }
202
203 pub fn into_inner(self) -> T {
205 self.0
206 }
207}
208
209impl<T: DeserializeOwned + rustapi_validate::Validate + Send> FromRequest for ValidatedJson<T> {
210 async fn from_request(req: &mut Request) -> Result<Self> {
211 req.load_body().await?;
212 let body = req
214 .take_body()
215 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
216
217 let value: T = json::from_slice(&body)?;
218
219 if let Err(validation_error) = rustapi_validate::Validate::validate(&value) {
221 return Err(validation_error.into());
223 }
224
225 Ok(ValidatedJson(value))
226 }
227}
228
229impl<T> Deref for ValidatedJson<T> {
230 type Target = T;
231
232 fn deref(&self) -> &Self::Target {
233 &self.0
234 }
235}
236
237impl<T> DerefMut for ValidatedJson<T> {
238 fn deref_mut(&mut self) -> &mut Self::Target {
239 &mut self.0
240 }
241}
242
243impl<T> From<T> for ValidatedJson<T> {
244 fn from(value: T) -> Self {
245 ValidatedJson(value)
246 }
247}
248
249impl<T: Serialize> IntoResponse for ValidatedJson<T> {
250 fn into_response(self) -> crate::response::Response {
251 Json(self.0).into_response()
252 }
253}
254
/// Extractor that deserializes the URL query string into `T` using `serde_urlencoded`.
///
/// For OpenAPI, the parameters are described through `T`'s `IntoParams` implementation.
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);
274
275impl<T: DeserializeOwned> FromRequestParts for Query<T> {
276 fn from_request_parts(req: &Request) -> Result<Self> {
277 let query = req.query_string().unwrap_or("");
278 let value: T = serde_urlencoded::from_str(query)
279 .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
280 Ok(Query(value))
281 }
282}
283
284impl<T> Deref for Query<T> {
285 type Target = T;
286
287 fn deref(&self) -> &Self::Target {
288 &self.0
289 }
290}
291
/// Extractor for a single path parameter, parsed with [`FromStr`].
///
/// Adds nothing to the OpenAPI operation by itself.
#[derive(Debug, Clone)]
pub struct Path<T>(pub T);
315
316impl<T: FromStr> FromRequestParts for Path<T>
317where
318 T::Err: std::fmt::Display,
319{
320 fn from_request_parts(req: &Request) -> Result<Self> {
321 let params = req.path_params();
322
323 if let Some((_, value)) = params.iter().next() {
325 let parsed = value
326 .parse::<T>()
327 .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
328 return Ok(Path(parsed));
329 }
330
331 Err(ApiError::internal("Missing path parameter"))
332 }
333}
334
335impl<T> Deref for Path<T> {
336 type Target = T;
337
338 fn deref(&self) -> &Self::Target {
339 &self.0
340 }
341}
342
/// Extractor that deserializes all path parameters, keyed by their names, into `T`.
///
/// `T` is typically a struct whose field names match the route's parameter names.
/// Adds nothing to the OpenAPI operation by itself.
#[derive(Debug, Clone)]
pub struct Typed<T>(pub T);
364
365impl<T: DeserializeOwned + Send> FromRequestParts for Typed<T> {
366 fn from_request_parts(req: &Request) -> Result<Self> {
367 let params = req.path_params();
368 let mut map = serde_json::Map::new();
369 for (k, v) in params.iter() {
370 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
371 }
372 let value = serde_json::Value::Object(map);
373 let parsed: T = serde_json::from_value(value)
374 .map_err(|e| ApiError::bad_request(format!("Invalid path parameters: {}", e)))?;
375 Ok(Typed(parsed))
376 }
377}
378
379impl<T> Deref for Typed<T> {
380 type Target = T;
381
382 fn deref(&self) -> &Self::Target {
383 &self.0
384 }
385}
386
/// Extractor that clones a value of type `T` out of the state attached to the request.
///
/// Adds nothing to the OpenAPI operation.
#[derive(Debug, Clone)]
pub struct State<T>(pub T);
405
406impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
407 fn from_request_parts(req: &Request) -> Result<Self> {
408 req.state().get::<T>().cloned().map(State).ok_or_else(|| {
409 ApiError::internal(format!(
410 "State of type `{}` not found. Did you forget to call .state()?",
411 std::any::type_name::<T>()
412 ))
413 })
414 }
415}
416
417impl<T> Deref for State<T> {
418 type Target = T;
419
420 fn deref(&self) -> &Self::Target {
421 &self.0
422 }
423}
424
/// Extractor for the raw request body, fully buffered into [`Bytes`].
///
/// Documented in OpenAPI as a required `application/octet-stream` binary body.
#[derive(Debug, Clone)]
pub struct Body(pub Bytes);
428
429impl FromRequest for Body {
430 async fn from_request(req: &mut Request) -> Result<Self> {
431 req.load_body().await?;
432 let body = req
433 .take_body()
434 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
435 Ok(Body(body))
436 }
437}
438
439impl Deref for Body {
440 type Target = Bytes;
441
442 fn deref(&self) -> &Self::Target {
443 &self.0
444 }
445}
446
/// Extractor that exposes the request body as a [`StreamingBody`] instead of buffering it.
///
/// Also implements `Stream` itself, so it can be polled directly.
pub struct BodyStream(pub StreamingBody);
449
450impl FromRequest for BodyStream {
451 async fn from_request(req: &mut Request) -> Result<Self> {
452 let config = StreamingConfig::default();
453
454 if let Some(stream) = req.take_stream() {
455 Ok(BodyStream(StreamingBody::new(stream, config.max_body_size)))
456 } else if let Some(bytes) = req.take_body() {
457 let stream = futures_util::stream::once(async move { Ok(bytes) });
459 Ok(BodyStream(StreamingBody::from_stream(
460 stream,
461 config.max_body_size,
462 )))
463 } else {
464 Err(ApiError::internal("Body already consumed"))
465 }
466 }
467}
468
469impl Deref for BodyStream {
470 type Target = StreamingBody;
471
472 fn deref(&self) -> &Self::Target {
473 &self.0
474 }
475}
476
477impl DerefMut for BodyStream {
478 fn deref_mut(&mut self) -> &mut Self::Target {
479 &mut self.0
480 }
481}
482
/// Lets a `BodyStream` be polled directly as a stream of body chunks by forwarding to
/// the wrapped [`StreamingBody`].
impl futures_util::Stream for BodyStream {
    // Presumably the same item type the inner `StreamingBody` yields — confirm if either changes.
    type Item = Result<Bytes, ApiError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // `Pin::new` requires the inner body to be `Unpin`, so no pin projection is needed.
        std::pin::Pin::new(&mut self.0).poll_next(cx)
    }
}
494
/// Makes any [`FromRequestParts`] extractor optional.
///
/// A successful extraction yields `Some(value)`. *Any* error from `T`, not only a
/// "missing" one, is discarded and yields `None`. This impl never fails.
impl<T: FromRequestParts> FromRequestParts for Option<T> {
    fn from_request_parts(req: &Request) -> Result<Self> {
        Ok(T::from_request_parts(req).ok())
    }
}
503
/// Extractor holding a clone of all request headers.
///
/// Also derefs to [`http::HeaderMap`] for the full map API (e.g. `get_all`).
#[derive(Debug, Clone)]
pub struct Headers(pub http::HeaderMap);
523
524impl Headers {
525 pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
527 self.0.get(name)
528 }
529
530 pub fn contains(&self, name: &str) -> bool {
532 self.0.contains_key(name)
533 }
534
535 pub fn len(&self) -> usize {
537 self.0.len()
538 }
539
540 pub fn is_empty(&self) -> bool {
542 self.0.is_empty()
543 }
544
545 pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
547 self.0.iter()
548 }
549}
550
551impl FromRequestParts for Headers {
552 fn from_request_parts(req: &Request) -> Result<Self> {
553 Ok(Headers(req.headers().clone()))
554 }
555}
556
557impl Deref for Headers {
558 type Target = http::HeaderMap;
559
560 fn deref(&self) -> &Self::Target {
561 &self.0
562 }
563}
564
/// A single header value together with the header's name.
///
/// Field `0` is the value and field `1` is the name; obtain one with
/// [`HeaderValue::extract`]. Note that inside this module this type shadows
/// `http::HeaderValue`, which is therefore always written fully qualified.
#[derive(Debug, Clone)]
pub struct HeaderValue(pub String, pub &'static str);
585
586impl HeaderValue {
587 pub fn new(name: &'static str, value: String) -> Self {
589 Self(value, name)
590 }
591
592 pub fn value(&self) -> &str {
594 &self.0
595 }
596
597 pub fn name(&self) -> &'static str {
599 self.1
600 }
601
602 pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
604 req.headers()
605 .get(name)
606 .and_then(|v| v.to_str().ok())
607 .map(|s| HeaderValue(s.to_string(), name))
608 .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
609 }
610}
611
612impl Deref for HeaderValue {
613 type Target = String;
614
615 fn deref(&self) -> &Self::Target {
616 &self.0
617 }
618}
619
/// Extractor that clones a value of type `T` from the request's extensions.
///
/// Such values are typically inserted by middleware.
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);
639
640impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
641 fn from_request_parts(req: &Request) -> Result<Self> {
642 req.extensions()
643 .get::<T>()
644 .cloned()
645 .map(Extension)
646 .ok_or_else(|| {
647 ApiError::internal(format!(
648 "Extension of type `{}` not found. Did middleware insert it?",
649 std::any::type_name::<T>()
650 ))
651 })
652 }
653}
654
655impl<T> Deref for Extension<T> {
656 type Target = T;
657
658 fn deref(&self) -> &Self::Target {
659 &self.0
660 }
661}
662
663impl<T> DerefMut for Extension<T> {
664 fn deref_mut(&mut self) -> &mut Self::Target {
665 &mut self.0
666 }
667}
668
/// Extractor for the client's IP address.
///
/// See [`ClientIp::extract_with_config`] for how the address is resolved.
#[derive(Debug, Clone)]
pub struct ClientIp(pub std::net::IpAddr);
685
686impl ClientIp {
687 pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
689 if trust_proxy {
690 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
692 if let Ok(forwarded_str) = forwarded.to_str() {
693 if let Some(first_ip) = forwarded_str.split(',').next() {
695 if let Ok(ip) = first_ip.trim().parse() {
696 return Ok(ClientIp(ip));
697 }
698 }
699 }
700 }
701 }
702
703 if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
705 return Ok(ClientIp(addr.ip()));
706 }
707
708 Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
710 127, 0, 0, 1,
711 ))))
712 }
713}
714
715impl FromRequestParts for ClientIp {
716 fn from_request_parts(req: &Request) -> Result<Self> {
717 Self::extract_with_config(req, true)
719 }
720}
721
#[cfg(feature = "cookies")]
/// Extractor for the cookies sent in the request's `Cookie` header, exposed as a
/// [`cookie::CookieJar`].
///
/// Only available with the `cookies` feature.
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);
742
#[cfg(feature = "cookies")]
impl Cookies {
    /// Looks up the cookie called `name`.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.get(name)
    }

    /// Iterates over the cookies in the jar.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.iter()
    }

    /// Returns `true` if a cookie called `name` is present.
    pub fn contains(&self, name: &str) -> bool {
        self.get(name).is_some()
    }
}
760
#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    /// Parses the first `Cookie` request header into a jar.
    ///
    /// Pairs are split on `;` and trimmed. Empty pairs and pairs that fail to parse are
    /// skipped, as is a header value that is not visible ASCII. A later cookie with the
    /// same name replaces an earlier one, and a missing header yields an empty jar.
    /// Never fails.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let raw_header = req
            .headers()
            .get(header::COOKIE)
            .and_then(|value| value.to_str().ok())
            .unwrap_or("");

        let parsed = raw_header
            .split(';')
            .map(str::trim)
            .filter(|pair| !pair.is_empty())
            .filter_map(|pair| cookie::Cookie::parse(pair.to_string()).ok());

        let mut jar = cookie::CookieJar::new();
        for parsed_cookie in parsed {
            jar.add_original(parsed_cookie.into_owned());
        }

        Ok(Cookies(jar))
    }
}
783
#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    /// Borrows the underlying cookie jar.
    fn deref(&self) -> &cookie::CookieJar {
        let Cookies(jar) = self;
        jar
    }
}
792
// Implements `FromRequestParts` for each listed type by delegating to `Path<T>`, so a
// handler can take a single parsed path parameter without the `Path` wrapper.
macro_rules! impl_from_request_parts_for_primitives {
    ($($ty:ty),*) => {
        $(
            impl FromRequestParts for $ty {
                fn from_request_parts(req: &Request) -> Result<Self> {
                    Path::<$ty>::from_request_parts(req).map(|Path(value)| value)
                }
            }
        )*
    };
}
806
// Lets handlers take a single path parameter as a bare primitive or `String`;
// each generated impl parses the value exactly as `Path<T>` does.
impl_from_request_parts_for_primitives!(
    i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
);
810
811use rustapi_openapi::utoipa_types::openapi;
814use rustapi_openapi::{
815 IntoParams, MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier,
816 ResponseSpec, Schema, SchemaRef,
817};
818use std::collections::HashMap;
819
820impl<T: for<'a> Schema<'a>> OperationModifier for ValidatedJson<T> {
822 fn update_operation(op: &mut Operation) {
823 let (name, _) = T::schema();
824
825 let schema_ref = SchemaRef::Ref {
826 reference: format!("#/components/schemas/{}", name),
827 };
828
829 let mut content = HashMap::new();
830 content.insert(
831 "application/json".to_string(),
832 MediaType { schema: schema_ref },
833 );
834
835 op.request_body = Some(RequestBody {
836 required: true,
837 content,
838 });
839
840 op.responses.insert(
842 "422".to_string(),
843 ResponseSpec {
844 description: "Validation Error".to_string(),
845 content: {
846 let mut map = HashMap::new();
847 map.insert(
848 "application/json".to_string(),
849 MediaType {
850 schema: SchemaRef::Ref {
851 reference: "#/components/schemas/ValidationErrorSchema".to_string(),
852 },
853 },
854 );
855 Some(map)
856 },
857 },
858 );
859 }
860}
861
862impl<T: for<'a> Schema<'a>> OperationModifier for Json<T> {
864 fn update_operation(op: &mut Operation) {
865 let (name, _) = T::schema();
866
867 let schema_ref = SchemaRef::Ref {
868 reference: format!("#/components/schemas/{}", name),
869 };
870
871 let mut content = HashMap::new();
872 content.insert(
873 "application/json".to_string(),
874 MediaType { schema: schema_ref },
875 );
876
877 op.request_body = Some(RequestBody {
878 required: true,
879 content,
880 });
881 }
882}
883
impl<T> OperationModifier for Path<T> {
    // Intentionally leaves the operation unchanged.
    // NOTE(review): presumably path parameters are documented elsewhere (e.g. from the
    // route pattern); confirm, otherwise `Path<T>` parameters are missing from the spec.
    fn update_operation(_op: &mut Operation) {
    }
}
897
impl<T> OperationModifier for Typed<T> {
    // Intentionally leaves the operation unchanged.
    // NOTE(review): presumably path parameters are documented elsewhere (e.g. from the
    // route pattern); confirm, otherwise `Typed<T>` parameters are missing from the spec.
    fn update_operation(_op: &mut Operation) {
    }
}
904
905impl<T: IntoParams> OperationModifier for Query<T> {
907 fn update_operation(op: &mut Operation) {
908 let params = T::into_params(|| Some(openapi::path::ParameterIn::Query));
909
910 let new_params: Vec<Parameter> = params
911 .into_iter()
912 .map(|p| {
913 let schema = match p.schema {
914 Some(schema) => match schema {
915 openapi::RefOr::Ref(r) => SchemaRef::Ref {
916 reference: r.ref_location,
917 },
918 openapi::RefOr::T(s) => {
919 let value = serde_json::to_value(s).unwrap_or(serde_json::Value::Null);
920 SchemaRef::Inline(value)
921 }
922 },
923 None => SchemaRef::Inline(serde_json::Value::Null),
924 };
925
926 let required = match p.required {
927 openapi::Required::True => true,
928 openapi::Required::False => false,
929 };
930
931 Parameter {
932 name: p.name,
933 location: "query".to_string(), required,
935 description: p.description,
936 schema,
937 }
938 })
939 .collect();
940
941 if let Some(existing) = &mut op.parameters {
942 existing.extend(new_params);
943 } else {
944 op.parameters = Some(new_params);
945 }
946 }
947}
948
/// Application state is server-side only, so it contributes nothing to the OpenAPI
/// operation.
impl<T> OperationModifier for State<T> {
    fn update_operation(_op: &mut Operation) {}
}
953
954impl OperationModifier for Body {
956 fn update_operation(op: &mut Operation) {
957 let mut content = HashMap::new();
958 content.insert(
959 "application/octet-stream".to_string(),
960 MediaType {
961 schema: SchemaRef::Inline(
962 serde_json::json!({ "type": "string", "format": "binary" }),
963 ),
964 },
965 );
966
967 op.request_body = Some(RequestBody {
968 required: true,
969 content,
970 });
971 }
972}
973
974impl OperationModifier for BodyStream {
976 fn update_operation(op: &mut Operation) {
977 let mut content = HashMap::new();
978 content.insert(
979 "application/octet-stream".to_string(),
980 MediaType {
981 schema: SchemaRef::Inline(
982 serde_json::json!({ "type": "string", "format": "binary" }),
983 ),
984 },
985 );
986
987 op.request_body = Some(RequestBody {
988 required: true,
989 content,
990 });
991 }
992}
993
994impl<T: for<'a> Schema<'a>> ResponseModifier for Json<T> {
998 fn update_response(op: &mut Operation) {
999 let (name, _) = T::schema();
1000
1001 let schema_ref = SchemaRef::Ref {
1002 reference: format!("#/components/schemas/{}", name),
1003 };
1004
1005 op.responses.insert(
1006 "200".to_string(),
1007 ResponseSpec {
1008 description: "Successful response".to_string(),
1009 content: {
1010 let mut map = HashMap::new();
1011 map.insert(
1012 "application/json".to_string(),
1013 MediaType { schema: schema_ref },
1014 );
1015 Some(map)
1016 },
1017 },
1018 );
1019 }
1020}
1021
#[cfg(test)]
mod tests {
    use super::*;
    use crate::path_params::PathParams;
    use bytes::Bytes;
    use http::{Extensions, Method};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::sync::Arc;

    /// Builds a test `Request` with an empty buffered body, empty shared state, no path
    /// parameters and the given headers.
    ///
    /// Headers are appended in order, so repeated names keep every value.
    fn create_test_request_with_headers(
        method: Method,
        path: &str,
        headers: Vec<(&str, &str)>,
    ) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let mut builder = http::Request::builder().method(method).uri(uri);

        for (name, value) in headers {
            builder = builder.header(name, value);
        }

        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    /// Builds a header-less test `Request` whose per-request extensions contain `extension`.
    fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
        method: Method,
        path: &str,
        extension: T,
    ) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        let req = builder.body(()).unwrap();
        let (mut parts, _) = req.into_parts();
        parts.extensions.insert(extension);

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Every generated header, including repeated names, is visible through `Headers`,
        // and its exact value is among the values stored under that name.
        #[test]
        fn prop_headers_extractor_completeness(
            headers in prop::collection::vec(
                ("[a-z][a-z0-9-]{0,20}", "[a-zA-Z0-9 ]{1,50}"),
                0..10
            )
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let header_tuples: Vec<(&str, &str)> = headers
                    .iter()
                    .map(|(k, v)| (k.as_str(), v.as_str()))
                    .collect();

                let request =
                    create_test_request_with_headers(Method::GET, "/test", header_tuples);

                let extracted = Headers::from_request_parts(&request)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;

                for (name, value) in &headers {
                    // Repeated names store several values; search all of them.
                    let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
                    prop_assert!(!all_values.is_empty(), "Header '{}' not found", name);

                    let value_found = all_values
                        .iter()
                        .any(|v| v.to_str().map(|s| s == value.as_str()).unwrap_or(false));

                    prop_assert!(
                        value_found,
                        "Header '{}' value '{}' not found in extracted values",
                        name,
                        value
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // `HeaderValue::extract` returns the exact value of a present header and an error
        // for an absent one, for arbitrary (valid) header names.
        #[test]
        fn prop_header_value_extractor_correctness(
            header_name in "[a-z][a-z0-9-]{0,20}",
            header_value in "[a-zA-Z0-9 ]{1,50}",
            has_header in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // `HeaderValue::extract` takes a `&'static str` name. Leaking one short
                // string per case is bounded (100 cases) and acceptable in a test.
                let name: &'static str = Box::leak(header_name.clone().into_boxed_str());

                let headers = if has_header {
                    vec![(name, header_value.as_str())]
                } else {
                    vec![]
                };
                let request = create_test_request_with_headers(Method::GET, "/test", headers);

                let result = HeaderValue::extract(&request, name);

                if has_header {
                    let extracted = result.map_err(|e| {
                        TestCaseError::fail(format!("Expected header to be found: {}", e))
                    })?;
                    prop_assert_eq!(
                        extracted.value(),
                        header_value.as_str(),
                        "Header value mismatch"
                    );
                } else {
                    prop_assert!(result.is_err(), "Expected error when header is missing");
                }

                Ok(())
            })();
            result?;
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // `X-Forwarded-For` wins only when the proxy is trusted and the header is present;
        // otherwise the socket address is used.
        #[test]
        fn prop_client_ip_extractor_with_forwarding(
            forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
            socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
            has_forwarded_header in prop::bool::ANY,
            trust_proxy in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let headers = if has_forwarded_header {
                    vec![("x-forwarded-for", forwarded_ip.as_str())]
                } else {
                    vec![]
                };

                let uri: http::Uri = "/test".parse().unwrap();
                let mut builder = http::Request::builder().method(Method::GET).uri(uri);
                for (name, value) in &headers {
                    builder = builder.header(*name, *value);
                }
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
                parts.extensions.insert(socket_addr);

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let extracted = ClientIp::extract_with_config(&request, trust_proxy)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;

                if trust_proxy && has_forwarded_header {
                    let expected_ip: std::net::IpAddr = forwarded_ip
                        .parse()
                        .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        expected_ip,
                        "Should use X-Forwarded-For IP when trust_proxy is enabled"
                    );
                } else {
                    prop_assert_eq!(
                        extracted.0,
                        socket_ip,
                        "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // `Extension<T>` returns the inserted value when present and an error otherwise.
        #[test]
        fn prop_extension_extractor_retrieval(
            value in any::<i64>(),
            has_extension in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                #[derive(Clone, Debug, PartialEq)]
                struct TestExtension(i64);

                let uri: http::Uri = "/test".parse().unwrap();
                let builder = http::Request::builder().method(Method::GET).uri(uri);
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                if has_extension {
                    parts.extensions.insert(TestExtension(value));
                }

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let result = Extension::<TestExtension>::from_request_parts(&request);

                if has_extension {
                    let extracted = result.map_err(|e| {
                        TestCaseError::fail(format!("Expected extension to be found: {}", e))
                    })?;
                    prop_assert_eq!(extracted.0, TestExtension(value), "Extension value mismatch");
                } else {
                    prop_assert!(result.is_err(), "Expected error when extension is missing");
                }

                Ok(())
            })();
            result?;
        }
    }

    #[test]
    fn test_headers_extractor_basic() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![
                ("content-type", "application/json"),
                ("accept", "text/html"),
            ],
        );

        let headers = Headers::from_request_parts(&request).unwrap();

        assert!(headers.contains("content-type"));
        assert!(headers.contains("accept"));
        assert!(!headers.contains("x-custom"));
        assert_eq!(headers.len(), 2);
    }

    #[test]
    fn test_header_value_extractor_present() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("authorization", "Bearer token123")],
        );

        let result = HeaderValue::extract(&request, "authorization");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value(), "Bearer token123");
    }

    #[test]
    fn test_header_value_extractor_missing() {
        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = HeaderValue::extract(&request, "authorization");
        assert!(result.is_err());
    }

    #[test]
    fn test_client_ip_from_forwarded_header() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
        );

        let ip = ClientIp::extract_with_config(&request, true).unwrap();
        assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_client_ip_ignores_forwarded_when_not_trusted() {
        let uri: http::Uri = "/test".parse().unwrap();
        let builder = http::Request::builder()
            .method(Method::GET)
            .uri(uri)
            .header("x-forwarded-for", "192.168.1.100");
        let req = builder.body(()).unwrap();
        let (mut parts, _) = req.into_parts();

        let socket_addr = std::net::SocketAddr::new(
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
            8080,
        );
        parts.extensions.insert(socket_addr);

        let request = Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        );

        let ip = ClientIp::extract_with_config(&request, false).unwrap();
        assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_extension_extractor_present() {
        #[derive(Clone, Debug, PartialEq)]
        struct MyData(String);

        let request =
            create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));

        let result = Extension::<MyData>::from_request_parts(&request);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().0, MyData("hello".to_string()));
    }

    #[test]
    fn test_extension_extractor_missing() {
        #[derive(Clone, Debug)]
        // The field is never read: the type only serves as the extension lookup key.
        #[allow(dead_code)]
        struct MyData(String);

        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = Extension::<MyData>::from_request_parts(&request);
        assert!(result.is_err());
    }

    #[cfg(feature = "cookies")]
    mod cookies_tests {
        use super::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            // Every generated cookie is found with its value (the last one wins for a
            // repeated name), and the jar holds exactly one cookie per distinct name.
            #[test]
            fn prop_cookies_extractor_parsing(
                cookies in prop::collection::vec(
                    ("[a-zA-Z][a-zA-Z0-9_]{0,15}", "[a-zA-Z0-9]{1,30}"),
                    0..5
                )
            ) {
                let result: Result<(), TestCaseError> = (|| {
                    let cookie_header = cookies
                        .iter()
                        .map(|(name, value)| format!("{}={}", name, value))
                        .collect::<Vec<_>>()
                        .join("; ");

                    let headers = if !cookies.is_empty() {
                        vec![("cookie", cookie_header.as_str())]
                    } else {
                        vec![]
                    };

                    let request = create_test_request_with_headers(Method::GET, "/test", headers);

                    let extracted = Cookies::from_request_parts(&request)
                        .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;

                    let mut expected_cookies: std::collections::HashMap<&str, &str> =
                        std::collections::HashMap::new();
                    for (name, value) in &cookies {
                        expected_cookies.insert(name.as_str(), value.as_str());
                    }

                    for (name, expected_value) in &expected_cookies {
                        let cookie = extracted
                            .get(name)
                            .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;

                        prop_assert_eq!(
                            cookie.value(),
                            *expected_value,
                            "Cookie '{}' value mismatch",
                            name
                        );
                    }

                    let extracted_count = extracted.iter().count();
                    prop_assert_eq!(
                        extracted_count,
                        expected_cookies.len(),
                        "Expected {} unique cookies, got {}",
                        expected_cookies.len(),
                        extracted_count
                    );

                    Ok(())
                })();
                result?;
            }
        }

        #[test]
        fn test_cookies_extractor_basic() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "session=abc123; user=john")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();

            assert!(cookies.contains("session"));
            assert!(cookies.contains("user"));
            assert!(!cookies.contains("other"));

            assert_eq!(cookies.get("session").unwrap().value(), "abc123");
            assert_eq!(cookies.get("user").unwrap().value(), "john");
        }

        #[test]
        fn test_cookies_extractor_empty() {
            let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 0);
        }

        #[test]
        fn test_cookies_extractor_single() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "token=xyz789")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 1);
            assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
        }
    }
}