1use crate::error::{ApiError, Result};
58use crate::json;
59use crate::request::Request;
60use crate::response::IntoResponse;
61use crate::stream::{StreamingBody, StreamingConfig};
62use bytes::Bytes;
63use http::{header, StatusCode};
64use http_body_util::Full;
65use serde::de::DeserializeOwned;
66use serde::Serialize;
67use std::future::Future;
68use std::ops::{Deref, DerefMut};
69use std::str::FromStr;
70
/// Extractor that is built from a shared (read-only) view of the request.
///
/// Implementors receive `&Request` and therefore cannot consume the body.
pub trait FromRequestParts: Sized {
    /// Builds `Self` from the request, or returns an error describing why
    /// extraction failed.
    fn from_request_parts(req: &Request) -> Result<Self>;
}
78
/// Extractor that is built from a mutable view of the request.
///
/// Implementors receive `&mut Request`, so they may take the body or stream
/// out of it. The returned future must be `Send`.
pub trait FromRequest: Sized {
    /// Asynchronously builds `Self` from the request, or returns an error
    /// describing why extraction failed.
    fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
}
86
87impl<T: FromRequestParts> FromRequest for T {
89 async fn from_request(req: &mut Request) -> Result<Self> {
90 T::from_request_parts(req)
91 }
92}
93
/// Wrapper marking a value as JSON: used both as a request-body extractor
/// and as a response type. The wrapped value is the public field `0`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Json<T>(pub T);
114
115impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
116 async fn from_request(req: &mut Request) -> Result<Self> {
117 req.load_body().await?;
118 let body = req
119 .take_body()
120 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
121
122 let value: T = json::from_slice(&body)?;
124 Ok(Json(value))
125 }
126}
127
128impl<T> Deref for Json<T> {
129 type Target = T;
130
131 fn deref(&self) -> &Self::Target {
132 &self.0
133 }
134}
135
136impl<T> DerefMut for Json<T> {
137 fn deref_mut(&mut self) -> &mut Self::Target {
138 &mut self.0
139 }
140}
141
142impl<T> From<T> for Json<T> {
143 fn from(value: T) -> Self {
144 Json(value)
145 }
146}
147
148const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256;
151
152impl<T: Serialize> IntoResponse for Json<T> {
154 fn into_response(self) -> crate::response::Response {
155 match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
157 Ok(body) => http::Response::builder()
158 .status(StatusCode::OK)
159 .header(header::CONTENT_TYPE, "application/json")
160 .body(Full::new(Bytes::from(body)))
161 .unwrap(),
162 Err(err) => {
163 ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
164 }
165 }
166 }
167}
168
/// JSON wrapper whose extractor additionally runs
/// `rustapi_validate::Validate` on the deserialized value. The wrapped value
/// is the public field `0`.
#[derive(Clone, Copy, Debug, Default)]
pub struct ValidatedJson<T>(pub T);
196
197impl<T> ValidatedJson<T> {
198 pub fn new(value: T) -> Self {
200 Self(value)
201 }
202
203 pub fn into_inner(self) -> T {
205 self.0
206 }
207}
208
209impl<T: DeserializeOwned + rustapi_validate::Validate + Send> FromRequest for ValidatedJson<T> {
210 async fn from_request(req: &mut Request) -> Result<Self> {
211 req.load_body().await?;
212 let body = req
214 .take_body()
215 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
216
217 let value: T = json::from_slice(&body)?;
218
219 if let Err(validation_error) = rustapi_validate::Validate::validate(&value) {
221 return Err(validation_error.into());
223 }
224
225 Ok(ValidatedJson(value))
226 }
227}
228
229impl<T> Deref for ValidatedJson<T> {
230 type Target = T;
231
232 fn deref(&self) -> &Self::Target {
233 &self.0
234 }
235}
236
237impl<T> DerefMut for ValidatedJson<T> {
238 fn deref_mut(&mut self) -> &mut Self::Target {
239 &mut self.0
240 }
241}
242
243impl<T> From<T> for ValidatedJson<T> {
244 fn from(value: T) -> Self {
245 ValidatedJson(value)
246 }
247}
248
249impl<T: Serialize> IntoResponse for ValidatedJson<T> {
250 fn into_response(self) -> crate::response::Response {
251 Json(self.0).into_response()
252 }
253}
254
/// Wrapper for a value deserialized from the URL query string. The wrapped
/// value is the public field `0`.
#[derive(Clone, Debug)]
pub struct Query<T>(pub T);
274
275impl<T: DeserializeOwned> FromRequestParts for Query<T> {
276 fn from_request_parts(req: &Request) -> Result<Self> {
277 let query = req.query_string().unwrap_or("");
278 let value: T = serde_urlencoded::from_str(query)
279 .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
280 Ok(Query(value))
281 }
282}
283
284impl<T> Deref for Query<T> {
285 type Target = T;
286
287 fn deref(&self) -> &Self::Target {
288 &self.0
289 }
290}
291
/// Wrapper for a value parsed from a path parameter. The wrapped value is
/// the public field `0`.
#[derive(Clone, Debug)]
pub struct Path<T>(pub T);
315
316impl<T: FromStr> FromRequestParts for Path<T>
317where
318 T::Err: std::fmt::Display,
319{
320 fn from_request_parts(req: &Request) -> Result<Self> {
321 let params = req.path_params();
322
323 if let Some((_, value)) = params.iter().next() {
325 let parsed = value
326 .parse::<T>()
327 .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
328 return Ok(Path(parsed));
329 }
330
331 Err(ApiError::internal("Missing path parameter"))
332 }
333}
334
335impl<T> Deref for Path<T> {
336 type Target = T;
337
338 fn deref(&self) -> &Self::Target {
339 &self.0
340 }
341}
342
/// Wrapper for a value cloned out of the request's state container. The
/// wrapped value is the public field `0`.
#[derive(Clone, Debug)]
pub struct State<T>(pub T);
361
362impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
363 fn from_request_parts(req: &Request) -> Result<Self> {
364 req.state().get::<T>().cloned().map(State).ok_or_else(|| {
365 ApiError::internal(format!(
366 "State of type `{}` not found. Did you forget to call .state()?",
367 std::any::type_name::<T>()
368 ))
369 })
370 }
371}
372
373impl<T> Deref for State<T> {
374 type Target = T;
375
376 fn deref(&self) -> &Self::Target {
377 &self.0
378 }
379}
380
381#[derive(Debug, Clone)]
383pub struct Body(pub Bytes);
384
385impl FromRequest for Body {
386 async fn from_request(req: &mut Request) -> Result<Self> {
387 req.load_body().await?;
388 let body = req
389 .take_body()
390 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
391 Ok(Body(body))
392 }
393}
394
395impl Deref for Body {
396 type Target = Bytes;
397
398 fn deref(&self) -> &Self::Target {
399 &self.0
400 }
401}
402
/// Wrapper for a request body exposed as a `StreamingBody`. The stream is
/// the public field `0`.
pub struct BodyStream(pub StreamingBody);
405
406impl FromRequest for BodyStream {
407 async fn from_request(req: &mut Request) -> Result<Self> {
408 let config = StreamingConfig::default();
409
410 if let Some(stream) = req.take_stream() {
411 Ok(BodyStream(StreamingBody::new(stream, config.max_body_size)))
412 } else if let Some(bytes) = req.take_body() {
413 let stream = futures_util::stream::once(async move { Ok(bytes) });
415 Ok(BodyStream(StreamingBody::from_stream(
416 stream,
417 config.max_body_size,
418 )))
419 } else {
420 Err(ApiError::internal("Body already consumed"))
421 }
422 }
423}
424
425impl Deref for BodyStream {
426 type Target = StreamingBody;
427
428 fn deref(&self) -> &Self::Target {
429 &self.0
430 }
431}
432
433impl DerefMut for BodyStream {
434 fn deref_mut(&mut self) -> &mut Self::Target {
435 &mut self.0
436 }
437}
438
439impl futures_util::Stream for BodyStream {
441 type Item = Result<Bytes, ApiError>;
442
443 fn poll_next(
444 mut self: std::pin::Pin<&mut Self>,
445 cx: &mut std::task::Context<'_>,
446 ) -> std::task::Poll<Option<Self::Item>> {
447 std::pin::Pin::new(&mut self.0).poll_next(cx)
448 }
449}
450
451impl<T: FromRequestParts> FromRequestParts for Option<T> {
455 fn from_request_parts(req: &Request) -> Result<Self> {
456 Ok(T::from_request_parts(req).ok())
457 }
458}
459
460#[derive(Debug, Clone)]
478pub struct Headers(pub http::HeaderMap);
479
480impl Headers {
481 pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
483 self.0.get(name)
484 }
485
486 pub fn contains(&self, name: &str) -> bool {
488 self.0.contains_key(name)
489 }
490
491 pub fn len(&self) -> usize {
493 self.0.len()
494 }
495
496 pub fn is_empty(&self) -> bool {
498 self.0.is_empty()
499 }
500
501 pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
503 self.0.iter()
504 }
505}
506
507impl FromRequestParts for Headers {
508 fn from_request_parts(req: &Request) -> Result<Self> {
509 Ok(Headers(req.headers().clone()))
510 }
511}
512
513impl Deref for Headers {
514 type Target = http::HeaderMap;
515
516 fn deref(&self) -> &Self::Target {
517 &self.0
518 }
519}
520
/// A single header extracted as text. Field `0` is the header value and
/// field `1` is the header name it was read from.
#[derive(Clone, Debug)]
pub struct HeaderValue(pub String, pub &'static str);
541
542impl HeaderValue {
543 pub fn new(name: &'static str, value: String) -> Self {
545 Self(value, name)
546 }
547
548 pub fn value(&self) -> &str {
550 &self.0
551 }
552
553 pub fn name(&self) -> &'static str {
555 self.1
556 }
557
558 pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
560 req.headers()
561 .get(name)
562 .and_then(|v| v.to_str().ok())
563 .map(|s| HeaderValue(s.to_string(), name))
564 .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
565 }
566}
567
568impl Deref for HeaderValue {
569 type Target = String;
570
571 fn deref(&self) -> &Self::Target {
572 &self.0
573 }
574}
575
/// Wrapper for a value cloned out of the request's extensions. The wrapped
/// value is the public field `0`.
#[derive(Clone, Debug)]
pub struct Extension<T>(pub T);
595
596impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
597 fn from_request_parts(req: &Request) -> Result<Self> {
598 req.extensions()
599 .get::<T>()
600 .cloned()
601 .map(Extension)
602 .ok_or_else(|| {
603 ApiError::internal(format!(
604 "Extension of type `{}` not found. Did middleware insert it?",
605 std::any::type_name::<T>()
606 ))
607 })
608 }
609}
610
611impl<T> Deref for Extension<T> {
612 type Target = T;
613
614 fn deref(&self) -> &Self::Target {
615 &self.0
616 }
617}
618
619impl<T> DerefMut for Extension<T> {
620 fn deref_mut(&mut self) -> &mut Self::Target {
621 &mut self.0
622 }
623}
624
/// Wrapper for the IP address attributed to the client. The address is the
/// public field `0`.
#[derive(Clone, Debug)]
pub struct ClientIp(pub std::net::IpAddr);
641
642impl ClientIp {
643 pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
645 if trust_proxy {
646 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
648 if let Ok(forwarded_str) = forwarded.to_str() {
649 if let Some(first_ip) = forwarded_str.split(',').next() {
651 if let Ok(ip) = first_ip.trim().parse() {
652 return Ok(ClientIp(ip));
653 }
654 }
655 }
656 }
657 }
658
659 if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
661 return Ok(ClientIp(addr.ip()));
662 }
663
664 Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
666 127, 0, 0, 1,
667 ))))
668 }
669}
670
671impl FromRequestParts for ClientIp {
672 fn from_request_parts(req: &Request) -> Result<Self> {
673 Self::extract_with_config(req, true)
675 }
676}
677
/// Wrapper for the cookies parsed from a request. The jar is the public
/// field `0`. Only available with the `cookies` feature.
#[cfg(feature = "cookies")]
#[derive(Clone, Debug)]
pub struct Cookies(pub cookie::CookieJar);
698
#[cfg(feature = "cookies")]
impl Cookies {
    /// Returns the cookie called `name`, if the jar contains one.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.get(name)
    }

    /// Iterates over the cookies in the jar.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.iter()
    }

    /// Returns `true` when the jar contains a cookie called `name`.
    pub fn contains(&self, name: &str) -> bool {
        let Cookies(jar) = self;
        jar.get(name).is_some()
    }
}
716
#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    /// Parses the `Cookie` request header into a jar. This extraction never
    /// fails.
    ///
    /// The header is split on `;`; each non-empty, trimmed piece is parsed
    /// on its own and added with `add_original`. A missing or unreadable
    /// header, and any piece that fails to parse, are skipped silently.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();

        let header_text = req
            .headers()
            .get(header::COOKIE)
            .and_then(|raw| raw.to_str().ok());

        if let Some(text) = header_text {
            let pieces = text.split(';').map(str::trim).filter(|p| !p.is_empty());
            for piece in pieces {
                if let Ok(parsed) = cookie::Cookie::parse(piece.to_string()) {
                    jar.add_original(parsed.into_owned());
                }
            }
        }

        Ok(Cookies(jar))
    }
}
739
#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    /// Borrows the wrapped cookie jar.
    fn deref(&self) -> &cookie::CookieJar {
        let Cookies(jar) = self;
        jar
    }
}
748
749macro_rules! impl_from_request_parts_for_primitives {
751 ($($ty:ty),*) => {
752 $(
753 impl FromRequestParts for $ty {
754 fn from_request_parts(req: &Request) -> Result<Self> {
755 let Path(value) = Path::<$ty>::from_request_parts(req)?;
756 Ok(value)
757 }
758 }
759 )*
760 };
761}
762
763impl_from_request_parts_for_primitives!(
764 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
765);
766
767use rustapi_openapi::utoipa_types::openapi;
770use rustapi_openapi::{
771 IntoParams, MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier,
772 ResponseSpec, Schema, SchemaRef,
773};
774use std::collections::HashMap;
775
776impl<T: for<'a> Schema<'a>> OperationModifier for ValidatedJson<T> {
778 fn update_operation(op: &mut Operation) {
779 let (name, _) = T::schema();
780
781 let schema_ref = SchemaRef::Ref {
782 reference: format!("#/components/schemas/{}", name),
783 };
784
785 let mut content = HashMap::new();
786 content.insert(
787 "application/json".to_string(),
788 MediaType { schema: schema_ref },
789 );
790
791 op.request_body = Some(RequestBody {
792 required: true,
793 content,
794 });
795
796 op.responses.insert(
798 "422".to_string(),
799 ResponseSpec {
800 description: "Validation Error".to_string(),
801 content: {
802 let mut map = HashMap::new();
803 map.insert(
804 "application/json".to_string(),
805 MediaType {
806 schema: SchemaRef::Ref {
807 reference: "#/components/schemas/ValidationErrorSchema".to_string(),
808 },
809 },
810 );
811 Some(map)
812 },
813 },
814 );
815 }
816}
817
818impl<T: for<'a> Schema<'a>> OperationModifier for Json<T> {
820 fn update_operation(op: &mut Operation) {
821 let (name, _) = T::schema();
822
823 let schema_ref = SchemaRef::Ref {
824 reference: format!("#/components/schemas/{}", name),
825 };
826
827 let mut content = HashMap::new();
828 content.insert(
829 "application/json".to_string(),
830 MediaType { schema: schema_ref },
831 );
832
833 op.request_body = Some(RequestBody {
834 required: true,
835 content,
836 });
837 }
838}
839
840impl<T> OperationModifier for Path<T> {
844 fn update_operation(_op: &mut Operation) {
845 }
852}
853
854impl<T: IntoParams> OperationModifier for Query<T> {
856 fn update_operation(op: &mut Operation) {
857 let params = T::into_params(|| Some(openapi::path::ParameterIn::Query));
858
859 let new_params: Vec<Parameter> = params
860 .into_iter()
861 .map(|p| {
862 let schema = match p.schema {
863 Some(schema) => match schema {
864 openapi::RefOr::Ref(r) => SchemaRef::Ref {
865 reference: r.ref_location,
866 },
867 openapi::RefOr::T(s) => {
868 let value = serde_json::to_value(s).unwrap_or(serde_json::Value::Null);
869 SchemaRef::Inline(value)
870 }
871 },
872 None => SchemaRef::Inline(serde_json::Value::Null),
873 };
874
875 let required = match p.required {
876 openapi::Required::True => true,
877 openapi::Required::False => false,
878 };
879
880 Parameter {
881 name: p.name,
882 location: "query".to_string(), required,
884 description: p.description,
885 schema,
886 }
887 })
888 .collect();
889
890 if let Some(existing) = &mut op.parameters {
891 existing.extend(new_params);
892 } else {
893 op.parameters = Some(new_params);
894 }
895 }
896}
897
898impl<T> OperationModifier for State<T> {
900 fn update_operation(_op: &mut Operation) {}
901}
902
903impl OperationModifier for Body {
905 fn update_operation(op: &mut Operation) {
906 let mut content = HashMap::new();
907 content.insert(
908 "application/octet-stream".to_string(),
909 MediaType {
910 schema: SchemaRef::Inline(
911 serde_json::json!({ "type": "string", "format": "binary" }),
912 ),
913 },
914 );
915
916 op.request_body = Some(RequestBody {
917 required: true,
918 content,
919 });
920 }
921}
922
923impl OperationModifier for BodyStream {
925 fn update_operation(op: &mut Operation) {
926 let mut content = HashMap::new();
927 content.insert(
928 "application/octet-stream".to_string(),
929 MediaType {
930 schema: SchemaRef::Inline(
931 serde_json::json!({ "type": "string", "format": "binary" }),
932 ),
933 },
934 );
935
936 op.request_body = Some(RequestBody {
937 required: true,
938 content,
939 });
940 }
941}
942
943impl<T: for<'a> Schema<'a>> ResponseModifier for Json<T> {
947 fn update_response(op: &mut Operation) {
948 let (name, _) = T::schema();
949
950 let schema_ref = SchemaRef::Ref {
951 reference: format!("#/components/schemas/{}", name),
952 };
953
954 op.responses.insert(
955 "200".to_string(),
956 ResponseSpec {
957 description: "Successful response".to_string(),
958 content: {
959 let mut map = HashMap::new();
960 map.insert(
961 "application/json".to_string(),
962 MediaType { schema: schema_ref },
963 );
964 Some(map)
965 },
966 },
967 );
968 }
969}
970
971#[cfg(test)]
972mod tests {
973 use super::*;
974 use crate::path_params::PathParams;
975 use bytes::Bytes;
976 use http::{Extensions, Method};
977 use proptest::prelude::*;
978 use proptest::test_runner::TestCaseError;
979 use std::sync::Arc;
980
981 fn create_test_request_with_headers(
983 method: Method,
984 path: &str,
985 headers: Vec<(&str, &str)>,
986 ) -> Request {
987 let uri: http::Uri = path.parse().unwrap();
988 let mut builder = http::Request::builder().method(method).uri(uri);
989
990 for (name, value) in headers {
991 builder = builder.header(name, value);
992 }
993
994 let req = builder.body(()).unwrap();
995 let (parts, _) = req.into_parts();
996
997 Request::new(
998 parts,
999 crate::request::BodyVariant::Buffered(Bytes::new()),
1000 Arc::new(Extensions::new()),
1001 PathParams::new(),
1002 )
1003 }
1004
1005 fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
1007 method: Method,
1008 path: &str,
1009 extension: T,
1010 ) -> Request {
1011 let uri: http::Uri = path.parse().unwrap();
1012 let builder = http::Request::builder().method(method).uri(uri);
1013
1014 let req = builder.body(()).unwrap();
1015 let (mut parts, _) = req.into_parts();
1016 parts.extensions.insert(extension);
1017
1018 Request::new(
1019 parts,
1020 crate::request::BodyVariant::Buffered(Bytes::new()),
1021 Arc::new(Extensions::new()),
1022 PathParams::new(),
1023 )
1024 }
1025
1026 proptest! {
1033 #![proptest_config(ProptestConfig::with_cases(100))]
1034
1035 #[test]
1036 fn prop_headers_extractor_completeness(
1037 headers in prop::collection::vec(
1040 (
1041 "[a-z][a-z0-9-]{0,20}", "[a-zA-Z0-9 ]{1,50}" ),
1044 0..10
1045 )
1046 ) {
1047 let result: Result<(), TestCaseError> = (|| {
1048 let header_tuples: Vec<(&str, &str)> = headers
1050 .iter()
1051 .map(|(k, v)| (k.as_str(), v.as_str()))
1052 .collect();
1053
1054 let request = create_test_request_with_headers(
1056 Method::GET,
1057 "/test",
1058 header_tuples.clone(),
1059 );
1060
1061 let extracted = Headers::from_request_parts(&request)
1063 .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;
1064
1065 for (name, value) in &headers {
1068 let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
1070 prop_assert!(
1071 !all_values.is_empty(),
1072 "Header '{}' not found",
1073 name
1074 );
1075
1076 let value_found = all_values.iter().any(|v| {
1078 v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
1079 });
1080
1081 prop_assert!(
1082 value_found,
1083 "Header '{}' value '{}' not found in extracted values",
1084 name,
1085 value
1086 );
1087 }
1088
1089 Ok(())
1090 })();
1091 result?;
1092 }
1093 }
1094
1095 proptest! {
1102 #![proptest_config(ProptestConfig::with_cases(100))]
1103
1104 #[test]
1105 fn prop_header_value_extractor_correctness(
1106 header_name in "[a-z][a-z0-9-]{0,20}",
1107 header_value in "[a-zA-Z0-9 ]{1,50}",
1108 has_header in prop::bool::ANY,
1109 ) {
1110 let result: Result<(), TestCaseError> = (|| {
1111 let headers = if has_header {
1112 vec![(header_name.as_str(), header_value.as_str())]
1113 } else {
1114 vec![]
1115 };
1116
1117 let _request = create_test_request_with_headers(Method::GET, "/test", headers);
1118
1119 let test_header = "x-test-header";
1122 let request_with_known_header = if has_header {
1123 create_test_request_with_headers(
1124 Method::GET,
1125 "/test",
1126 vec![(test_header, header_value.as_str())],
1127 )
1128 } else {
1129 create_test_request_with_headers(Method::GET, "/test", vec![])
1130 };
1131
1132 let result = HeaderValue::extract(&request_with_known_header, test_header);
1133
1134 if has_header {
1135 let extracted = result
1136 .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
1137 prop_assert_eq!(
1138 extracted.value(),
1139 header_value.as_str(),
1140 "Header value mismatch"
1141 );
1142 } else {
1143 prop_assert!(
1144 result.is_err(),
1145 "Expected error when header is missing"
1146 );
1147 }
1148
1149 Ok(())
1150 })();
1151 result?;
1152 }
1153 }
1154
1155 proptest! {
1162 #![proptest_config(ProptestConfig::with_cases(100))]
1163
1164 #[test]
1165 fn prop_client_ip_extractor_with_forwarding(
1166 forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1168 .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
1169 socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1170 .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
1171 has_forwarded_header in prop::bool::ANY,
1172 trust_proxy in prop::bool::ANY,
1173 ) {
1174 let result: Result<(), TestCaseError> = (|| {
1175 let headers = if has_forwarded_header {
1176 vec![("x-forwarded-for", forwarded_ip.as_str())]
1177 } else {
1178 vec![]
1179 };
1180
1181 let uri: http::Uri = "/test".parse().unwrap();
1183 let mut builder = http::Request::builder().method(Method::GET).uri(uri);
1184 for (name, value) in &headers {
1185 builder = builder.header(*name, *value);
1186 }
1187 let req = builder.body(()).unwrap();
1188 let (mut parts, _) = req.into_parts();
1189
1190 let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
1192 parts.extensions.insert(socket_addr);
1193
1194 let request = Request::new(
1195 parts,
1196 crate::request::BodyVariant::Buffered(Bytes::new()),
1197 Arc::new(Extensions::new()),
1198 PathParams::new(),
1199 );
1200
1201 let extracted = ClientIp::extract_with_config(&request, trust_proxy)
1202 .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;
1203
1204 if trust_proxy && has_forwarded_header {
1205 let expected_ip: std::net::IpAddr = forwarded_ip.parse()
1207 .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
1208 prop_assert_eq!(
1209 extracted.0,
1210 expected_ip,
1211 "Should use X-Forwarded-For IP when trust_proxy is enabled"
1212 );
1213 } else {
1214 prop_assert_eq!(
1216 extracted.0,
1217 socket_ip,
1218 "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
1219 );
1220 }
1221
1222 Ok(())
1223 })();
1224 result?;
1225 }
1226 }
1227
1228 proptest! {
1235 #![proptest_config(ProptestConfig::with_cases(100))]
1236
1237 #[test]
1238 fn prop_extension_extractor_retrieval(
1239 value in any::<i64>(),
1240 has_extension in prop::bool::ANY,
1241 ) {
1242 let result: Result<(), TestCaseError> = (|| {
1243 #[derive(Clone, Debug, PartialEq)]
1245 struct TestExtension(i64);
1246
1247 let uri: http::Uri = "/test".parse().unwrap();
1248 let builder = http::Request::builder().method(Method::GET).uri(uri);
1249 let req = builder.body(()).unwrap();
1250 let (mut parts, _) = req.into_parts();
1251
1252 if has_extension {
1253 parts.extensions.insert(TestExtension(value));
1254 }
1255
1256 let request = Request::new(
1257 parts,
1258 crate::request::BodyVariant::Buffered(Bytes::new()),
1259 Arc::new(Extensions::new()),
1260 PathParams::new(),
1261 );
1262
1263 let result = Extension::<TestExtension>::from_request_parts(&request);
1264
1265 if has_extension {
1266 let extracted = result
1267 .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
1268 prop_assert_eq!(
1269 extracted.0,
1270 TestExtension(value),
1271 "Extension value mismatch"
1272 );
1273 } else {
1274 prop_assert!(
1275 result.is_err(),
1276 "Expected error when extension is missing"
1277 );
1278 }
1279
1280 Ok(())
1281 })();
1282 result?;
1283 }
1284 }
1285
1286 #[test]
1289 fn test_headers_extractor_basic() {
1290 let request = create_test_request_with_headers(
1291 Method::GET,
1292 "/test",
1293 vec![
1294 ("content-type", "application/json"),
1295 ("accept", "text/html"),
1296 ],
1297 );
1298
1299 let headers = Headers::from_request_parts(&request).unwrap();
1300
1301 assert!(headers.contains("content-type"));
1302 assert!(headers.contains("accept"));
1303 assert!(!headers.contains("x-custom"));
1304 assert_eq!(headers.len(), 2);
1305 }
1306
1307 #[test]
1308 fn test_header_value_extractor_present() {
1309 let request = create_test_request_with_headers(
1310 Method::GET,
1311 "/test",
1312 vec![("authorization", "Bearer token123")],
1313 );
1314
1315 let result = HeaderValue::extract(&request, "authorization");
1316 assert!(result.is_ok());
1317 assert_eq!(result.unwrap().value(), "Bearer token123");
1318 }
1319
1320 #[test]
1321 fn test_header_value_extractor_missing() {
1322 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1323
1324 let result = HeaderValue::extract(&request, "authorization");
1325 assert!(result.is_err());
1326 }
1327
1328 #[test]
1329 fn test_client_ip_from_forwarded_header() {
1330 let request = create_test_request_with_headers(
1331 Method::GET,
1332 "/test",
1333 vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
1334 );
1335
1336 let ip = ClientIp::extract_with_config(&request, true).unwrap();
1337 assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
1338 }
1339
1340 #[test]
1341 fn test_client_ip_ignores_forwarded_when_not_trusted() {
1342 let uri: http::Uri = "/test".parse().unwrap();
1343 let builder = http::Request::builder()
1344 .method(Method::GET)
1345 .uri(uri)
1346 .header("x-forwarded-for", "192.168.1.100");
1347 let req = builder.body(()).unwrap();
1348 let (mut parts, _) = req.into_parts();
1349
1350 let socket_addr = std::net::SocketAddr::new(
1351 std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
1352 8080,
1353 );
1354 parts.extensions.insert(socket_addr);
1355
1356 let request = Request::new(
1357 parts,
1358 crate::request::BodyVariant::Buffered(Bytes::new()),
1359 Arc::new(Extensions::new()),
1360 PathParams::new(),
1361 );
1362
1363 let ip = ClientIp::extract_with_config(&request, false).unwrap();
1364 assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
1365 }
1366
1367 #[test]
1368 fn test_extension_extractor_present() {
1369 #[derive(Clone, Debug, PartialEq)]
1370 struct MyData(String);
1371
1372 let request =
1373 create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));
1374
1375 let result = Extension::<MyData>::from_request_parts(&request);
1376 assert!(result.is_ok());
1377 assert_eq!(result.unwrap().0, MyData("hello".to_string()));
1378 }
1379
1380 #[test]
1381 fn test_extension_extractor_missing() {
1382 #[derive(Clone, Debug)]
1383 #[allow(dead_code)]
1384 struct MyData(String);
1385
1386 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1387
1388 let result = Extension::<MyData>::from_request_parts(&request);
1389 assert!(result.is_err());
1390 }
1391
1392 #[cfg(feature = "cookies")]
1394 mod cookies_tests {
1395 use super::*;
1396
1397 proptest! {
1405 #![proptest_config(ProptestConfig::with_cases(100))]
1406
1407 #[test]
1408 fn prop_cookies_extractor_parsing(
1409 cookies in prop::collection::vec(
1412 (
1413 "[a-zA-Z][a-zA-Z0-9_]{0,15}", "[a-zA-Z0-9]{1,30}" ),
1416 0..5
1417 )
1418 ) {
1419 let result: Result<(), TestCaseError> = (|| {
1420 let cookie_header = cookies
1422 .iter()
1423 .map(|(name, value)| format!("{}={}", name, value))
1424 .collect::<Vec<_>>()
1425 .join("; ");
1426
1427 let headers = if !cookies.is_empty() {
1428 vec![("cookie", cookie_header.as_str())]
1429 } else {
1430 vec![]
1431 };
1432
1433 let request = create_test_request_with_headers(Method::GET, "/test", headers);
1434
1435 let extracted = Cookies::from_request_parts(&request)
1437 .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;
1438
1439 let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
1441 for (name, value) in &cookies {
1442 expected_cookies.insert(name.as_str(), value.as_str());
1443 }
1444
1445 for (name, expected_value) in &expected_cookies {
1447 let cookie = extracted.get(name)
1448 .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;
1449
1450 prop_assert_eq!(
1451 cookie.value(),
1452 *expected_value,
1453 "Cookie '{}' value mismatch",
1454 name
1455 );
1456 }
1457
1458 let extracted_count = extracted.iter().count();
1460 prop_assert_eq!(
1461 extracted_count,
1462 expected_cookies.len(),
1463 "Expected {} unique cookies, got {}",
1464 expected_cookies.len(),
1465 extracted_count
1466 );
1467
1468 Ok(())
1469 })();
1470 result?;
1471 }
1472 }
1473
1474 #[test]
1475 fn test_cookies_extractor_basic() {
1476 let request = create_test_request_with_headers(
1477 Method::GET,
1478 "/test",
1479 vec![("cookie", "session=abc123; user=john")],
1480 );
1481
1482 let cookies = Cookies::from_request_parts(&request).unwrap();
1483
1484 assert!(cookies.contains("session"));
1485 assert!(cookies.contains("user"));
1486 assert!(!cookies.contains("other"));
1487
1488 assert_eq!(cookies.get("session").unwrap().value(), "abc123");
1489 assert_eq!(cookies.get("user").unwrap().value(), "john");
1490 }
1491
1492 #[test]
1493 fn test_cookies_extractor_empty() {
1494 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1495
1496 let cookies = Cookies::from_request_parts(&request).unwrap();
1497 assert_eq!(cookies.iter().count(), 0);
1498 }
1499
1500 #[test]
1501 fn test_cookies_extractor_single() {
1502 let request = create_test_request_with_headers(
1503 Method::GET,
1504 "/test",
1505 vec![("cookie", "token=xyz789")],
1506 );
1507
1508 let cookies = Cookies::from_request_parts(&request).unwrap();
1509 assert_eq!(cookies.iter().count(), 1);
1510 assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
1511 }
1512 }
1513}