1use crate::error::{ApiError, Result};
58use crate::json;
59use crate::request::Request;
60use crate::response::IntoResponse;
61use crate::stream::{StreamingBody, StreamingConfig};
62use bytes::Bytes;
63use http::{header, StatusCode};
64
65use serde::de::DeserializeOwned;
66use serde::Serialize;
67use std::future::Future;
68use std::ops::{Deref, DerefMut};
69use std::str::FromStr;
70
/// Synchronous extractor for values that can be built from a shared
/// reference to the request.
pub trait FromRequestParts: Sized {
    /// Builds `Self` from `req`, or returns an error explaining why the
    /// extraction failed.
    fn from_request_parts(req: &Request) -> Result<Self>;
}
99
/// Asynchronous extractor that receives mutable access to the request.
///
/// The body-based extractors in this file use that access to load and take
/// the request body.
pub trait FromRequest: Sized {
    /// Builds `Self` from `req`. The returned future must be `Send`.
    fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
}
133
134impl<T: FromRequestParts> FromRequest for T {
136 async fn from_request(req: &mut Request) -> Result<Self> {
137 T::from_request_parts(req)
138 }
139}
140
/// Newtype wrapper around a value of type `T` that is read from a JSON
/// request body or written out as a JSON response.
///
/// The wrapped value is public and reachable as `.0`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);
161
162impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
163 async fn from_request(req: &mut Request) -> Result<Self> {
164 req.load_body().await?;
165 let body = req
166 .take_body()
167 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
168
169 let value: T = json::from_slice(&body)?;
171 Ok(Json(value))
172 }
173}
174
175impl<T> Deref for Json<T> {
176 type Target = T;
177
178 fn deref(&self) -> &Self::Target {
179 &self.0
180 }
181}
182
183impl<T> DerefMut for Json<T> {
184 fn deref_mut(&mut self) -> &mut Self::Target {
185 &mut self.0
186 }
187}
188
189impl<T> From<T> for Json<T> {
190 fn from(value: T) -> Self {
191 Json(value)
192 }
193}
194
/// Initial buffer capacity handed to `json::to_vec_with_capacity` when a
/// [`Json`] value is serialized into a response body.
const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256;
198
199impl<T: Serialize> IntoResponse for Json<T> {
201 fn into_response(self) -> crate::response::Response {
202 match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
204 Ok(body) => http::Response::builder()
205 .status(StatusCode::OK)
206 .header(header::CONTENT_TYPE, "application/json")
207 .body(crate::response::Body::from(body))
208 .unwrap(),
209 Err(err) => {
210 ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
211 }
212 }
213 }
214}
215
/// Newtype wrapper around a JSON payload of type `T` whose extractor also
/// runs `rustapi_validate::Validate::validate` on the parsed value.
///
/// The wrapped value is public and reachable as `.0`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);
243
244impl<T> ValidatedJson<T> {
245 pub fn new(value: T) -> Self {
247 Self(value)
248 }
249
250 pub fn into_inner(self) -> T {
252 self.0
253 }
254}
255
256impl<T: DeserializeOwned + rustapi_validate::Validate + Send> FromRequest for ValidatedJson<T> {
257 async fn from_request(req: &mut Request) -> Result<Self> {
258 req.load_body().await?;
259 let body = req
261 .take_body()
262 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
263
264 let value: T = json::from_slice(&body)?;
265
266 if let Err(validation_error) = rustapi_validate::Validate::validate(&value) {
268 return Err(validation_error.into());
270 }
271
272 Ok(ValidatedJson(value))
273 }
274}
275
276impl<T> Deref for ValidatedJson<T> {
277 type Target = T;
278
279 fn deref(&self) -> &Self::Target {
280 &self.0
281 }
282}
283
284impl<T> DerefMut for ValidatedJson<T> {
285 fn deref_mut(&mut self) -> &mut Self::Target {
286 &mut self.0
287 }
288}
289
290impl<T> From<T> for ValidatedJson<T> {
291 fn from(value: T) -> Self {
292 ValidatedJson(value)
293 }
294}
295
296impl<T: Serialize> IntoResponse for ValidatedJson<T> {
297 fn into_response(self) -> crate::response::Response {
298 Json(self.0).into_response()
299 }
300}
301
/// Newtype wrapper around a value of type `T` deserialized from the
/// request's query string.
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);
321
322impl<T: DeserializeOwned> FromRequestParts for Query<T> {
323 fn from_request_parts(req: &Request) -> Result<Self> {
324 let query = req.query_string().unwrap_or("");
325 let value: T = serde_urlencoded::from_str(query)
326 .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
327 Ok(Query(value))
328 }
329}
330
331impl<T> Deref for Query<T> {
332 type Target = T;
333
334 fn deref(&self) -> &Self::Target {
335 &self.0
336 }
337}
338
/// Newtype wrapper around a single path parameter parsed into `T` via
/// `FromStr`.
#[derive(Debug, Clone)]
pub struct Path<T>(pub T);
362
363impl<T: FromStr> FromRequestParts for Path<T>
364where
365 T::Err: std::fmt::Display,
366{
367 fn from_request_parts(req: &Request) -> Result<Self> {
368 let params = req.path_params();
369
370 if let Some((_, value)) = params.iter().next() {
372 let parsed = value
373 .parse::<T>()
374 .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
375 return Ok(Path(parsed));
376 }
377
378 Err(ApiError::internal("Missing path parameter"))
379 }
380}
381
382impl<T> Deref for Path<T> {
383 type Target = T;
384
385 fn deref(&self) -> &Self::Target {
386 &self.0
387 }
388}
389
/// Newtype wrapper around a value of type `T` deserialized from all path
/// parameters of the request at once.
#[derive(Debug, Clone)]
pub struct Typed<T>(pub T);
411
412impl<T: DeserializeOwned + Send> FromRequestParts for Typed<T> {
413 fn from_request_parts(req: &Request) -> Result<Self> {
414 let params = req.path_params();
415 let mut map = serde_json::Map::new();
416 for (k, v) in params.iter() {
417 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
418 }
419 let value = serde_json::Value::Object(map);
420 let parsed: T = serde_json::from_value(value)
421 .map_err(|e| ApiError::bad_request(format!("Invalid path parameters: {}", e)))?;
422 Ok(Typed(parsed))
423 }
424}
425
426impl<T> Deref for Typed<T> {
427 type Target = T;
428
429 fn deref(&self) -> &Self::Target {
430 &self.0
431 }
432}
433
/// Newtype wrapper around a clone of the state value of type `T` that is
/// looked up by type on the request.
#[derive(Debug, Clone)]
pub struct State<T>(pub T);
452
453impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
454 fn from_request_parts(req: &Request) -> Result<Self> {
455 req.state().get::<T>().cloned().map(State).ok_or_else(|| {
456 ApiError::internal(format!(
457 "State of type `{}` not found. Did you forget to call .state()?",
458 std::any::type_name::<T>()
459 ))
460 })
461 }
462}
463
464impl<T> Deref for State<T> {
465 type Target = T;
466
467 fn deref(&self) -> &Self::Target {
468 &self.0
469 }
470}
471
/// Newtype wrapper around the raw request body as [`Bytes`].
#[derive(Debug, Clone)]
pub struct Body(pub Bytes);
475
476impl FromRequest for Body {
477 async fn from_request(req: &mut Request) -> Result<Self> {
478 req.load_body().await?;
479 let body = req
480 .take_body()
481 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
482 Ok(Body(body))
483 }
484}
485
486impl Deref for Body {
487 type Target = Bytes;
488
489 fn deref(&self) -> &Self::Target {
490 &self.0
491 }
492}
493
/// Newtype wrapper exposing the request body as a [`StreamingBody`].
pub struct BodyStream(pub StreamingBody);
496
497impl FromRequest for BodyStream {
498 async fn from_request(req: &mut Request) -> Result<Self> {
499 let config = StreamingConfig::default();
500
501 if let Some(stream) = req.take_stream() {
502 Ok(BodyStream(StreamingBody::new(stream, config.max_body_size)))
503 } else if let Some(bytes) = req.take_body() {
504 let stream = futures_util::stream::once(async move { Ok(bytes) });
506 Ok(BodyStream(StreamingBody::from_stream(
507 stream,
508 config.max_body_size,
509 )))
510 } else {
511 Err(ApiError::internal("Body already consumed"))
512 }
513 }
514}
515
516impl Deref for BodyStream {
517 type Target = StreamingBody;
518
519 fn deref(&self) -> &Self::Target {
520 &self.0
521 }
522}
523
524impl DerefMut for BodyStream {
525 fn deref_mut(&mut self) -> &mut Self::Target {
526 &mut self.0
527 }
528}
529
530impl futures_util::Stream for BodyStream {
532 type Item = Result<Bytes, ApiError>;
533
534 fn poll_next(
535 mut self: std::pin::Pin<&mut Self>,
536 cx: &mut std::task::Context<'_>,
537 ) -> std::task::Poll<Option<Self::Item>> {
538 std::pin::Pin::new(&mut self.0).poll_next(cx)
539 }
540}
541
542impl<T: FromRequestParts> FromRequestParts for Option<T> {
546 fn from_request_parts(req: &Request) -> Result<Self> {
547 Ok(T::from_request_parts(req).ok())
548 }
549}
550
/// Newtype wrapper around an owned `http::HeaderMap` holding the request's
/// headers.
#[derive(Debug, Clone)]
pub struct Headers(pub http::HeaderMap);
570
571impl Headers {
572 pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
574 self.0.get(name)
575 }
576
577 pub fn contains(&self, name: &str) -> bool {
579 self.0.contains_key(name)
580 }
581
582 pub fn len(&self) -> usize {
584 self.0.len()
585 }
586
587 pub fn is_empty(&self) -> bool {
589 self.0.is_empty()
590 }
591
592 pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
594 self.0.iter()
595 }
596}
597
598impl FromRequestParts for Headers {
599 fn from_request_parts(req: &Request) -> Result<Self> {
600 Ok(Headers(req.headers().clone()))
601 }
602}
603
604impl Deref for Headers {
605 type Target = http::HeaderMap;
606
607 fn deref(&self) -> &Self::Target {
608 &self.0
609 }
610}
611
/// A single header's value as an owned `String` (field `0`) together with
/// the static header name it was read for (field `1`).
///
/// Note the field order is `(value, name)`, whereas `HeaderValue::new` takes
/// `(name, value)`.
#[derive(Debug, Clone)]
pub struct HeaderValue(pub String, pub &'static str);
632
633impl HeaderValue {
634 pub fn new(name: &'static str, value: String) -> Self {
636 Self(value, name)
637 }
638
639 pub fn value(&self) -> &str {
641 &self.0
642 }
643
644 pub fn name(&self) -> &'static str {
646 self.1
647 }
648
649 pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
651 req.headers()
652 .get(name)
653 .and_then(|v| v.to_str().ok())
654 .map(|s| HeaderValue(s.to_string(), name))
655 .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
656 }
657}
658
659impl Deref for HeaderValue {
660 type Target = String;
661
662 fn deref(&self) -> &Self::Target {
663 &self.0
664 }
665}
666
/// Newtype wrapper around a clone of the value of type `T` stored in the
/// request's extensions.
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);
686
687impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
688 fn from_request_parts(req: &Request) -> Result<Self> {
689 req.extensions()
690 .get::<T>()
691 .cloned()
692 .map(Extension)
693 .ok_or_else(|| {
694 ApiError::internal(format!(
695 "Extension of type `{}` not found. Did middleware insert it?",
696 std::any::type_name::<T>()
697 ))
698 })
699 }
700}
701
702impl<T> Deref for Extension<T> {
703 type Target = T;
704
705 fn deref(&self) -> &Self::Target {
706 &self.0
707 }
708}
709
710impl<T> DerefMut for Extension<T> {
711 fn deref_mut(&mut self) -> &mut Self::Target {
712 &mut self.0
713 }
714}
715
/// Newtype wrapper around the IP address determined for the client of a
/// request.
#[derive(Debug, Clone)]
pub struct ClientIp(pub std::net::IpAddr);
732
733impl ClientIp {
734 pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
736 if trust_proxy {
737 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
739 if let Ok(forwarded_str) = forwarded.to_str() {
740 if let Some(first_ip) = forwarded_str.split(',').next() {
742 if let Ok(ip) = first_ip.trim().parse() {
743 return Ok(ClientIp(ip));
744 }
745 }
746 }
747 }
748 }
749
750 if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
752 return Ok(ClientIp(addr.ip()));
753 }
754
755 Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
757 127, 0, 0, 1,
758 ))))
759 }
760}
761
762impl FromRequestParts for ClientIp {
763 fn from_request_parts(req: &Request) -> Result<Self> {
764 Self::extract_with_config(req, true)
766 }
767}
768
/// Newtype wrapper around a `cookie::CookieJar` holding the cookies parsed
/// from a request. Only compiled with the `cookies` feature.
#[cfg(feature = "cookies")]
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);
789
#[cfg(feature = "cookies")]
impl Cookies {
    /// Returns the cookie named `name`, if the jar holds one.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.get(name)
    }

    /// Iterates over the cookies in the jar.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.iter()
    }

    /// Returns `true` when the jar holds a cookie named `name`.
    pub fn contains(&self, name: &str) -> bool {
        let Cookies(jar) = self;
        jar.get(name).is_some()
    }
}
807
#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    /// Parses the request's `Cookie` header(s) into a cookie jar.
    ///
    /// Each header value is split on `;`, and every non-empty, trimmed piece
    /// that parses as a cookie is added to the jar as an original cookie.
    /// Header values that are not valid strings and pieces that fail to parse
    /// are skipped, so this extractor never fails.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();

        // Walk *all* `Cookie` header lines, not just the first one. HTTP/2
        // clients are allowed to split their cookies across several `cookie`
        // fields, and reading only the first would silently drop the rest.
        for raw in req.headers().get_all(header::COOKIE).iter() {
            let Ok(line) = raw.to_str() else {
                continue;
            };

            let pairs = line
                .split(';')
                .map(str::trim)
                .filter(|pair| !pair.is_empty());

            for pair in pairs {
                if let Ok(parsed) = cookie::Cookie::parse(pair.to_string()) {
                    jar.add_original(parsed.into_owned());
                }
            }
        }

        Ok(Cookies(jar))
    }
}
830
#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    /// Borrows the wrapped cookie jar.
    fn deref(&self) -> &cookie::CookieJar {
        let Cookies(jar) = self;
        jar
    }
}
839
/// Implements [`FromRequestParts`] for each listed type by delegating to
/// `Path::<$ty>` and unwrapping the result, so the bare type can be used
/// wherever a `Path<$ty>` extractor would be.
macro_rules! impl_from_request_parts_for_primitives {
    ($($ty:ty),*) => {
        $(
            impl FromRequestParts for $ty {
                fn from_request_parts(req: &Request) -> Result<Self> {
                    Path::<$ty>::from_request_parts(req).map(|Path(value)| value)
                }
            }
        )*
    };
}
853
// Integer, float, `bool` and `String` types that can be extracted directly
// from a path parameter without wrapping them in `Path`.
impl_from_request_parts_for_primitives!(
    i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
);
857
858use rustapi_openapi::utoipa_types::openapi;
861use rustapi_openapi::{
862 IntoParams, MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier,
863 ResponseSpec, Schema, SchemaRef,
864};
865use std::collections::HashMap;
866
867impl<T: for<'a> Schema<'a>> OperationModifier for ValidatedJson<T> {
869 fn update_operation(op: &mut Operation) {
870 let (name, _) = T::schema();
871
872 let schema_ref = SchemaRef::Ref {
873 reference: format!("#/components/schemas/{}", name),
874 };
875
876 let mut content = HashMap::new();
877 content.insert(
878 "application/json".to_string(),
879 MediaType { schema: schema_ref },
880 );
881
882 op.request_body = Some(RequestBody {
883 required: true,
884 content,
885 });
886
887 op.responses.insert(
889 "422".to_string(),
890 ResponseSpec {
891 description: "Validation Error".to_string(),
892 content: {
893 let mut map = HashMap::new();
894 map.insert(
895 "application/json".to_string(),
896 MediaType {
897 schema: SchemaRef::Ref {
898 reference: "#/components/schemas/ValidationErrorSchema".to_string(),
899 },
900 },
901 );
902 Some(map)
903 },
904 },
905 );
906 }
907}
908
909impl<T: for<'a> Schema<'a>> OperationModifier for Json<T> {
911 fn update_operation(op: &mut Operation) {
912 let (name, _) = T::schema();
913
914 let schema_ref = SchemaRef::Ref {
915 reference: format!("#/components/schemas/{}", name),
916 };
917
918 let mut content = HashMap::new();
919 content.insert(
920 "application/json".to_string(),
921 MediaType { schema: schema_ref },
922 );
923
924 op.request_body = Some(RequestBody {
925 required: true,
926 content,
927 });
928 }
929}
930
impl<T> OperationModifier for Path<T> {
    /// No-op: this impl leaves the operation untouched.
    fn update_operation(_op: &mut Operation) {
    }
}
944
impl<T> OperationModifier for Typed<T> {
    /// No-op: this impl leaves the operation untouched.
    fn update_operation(_op: &mut Operation) {
    }
}
951
952impl<T: IntoParams> OperationModifier for Query<T> {
954 fn update_operation(op: &mut Operation) {
955 let params = T::into_params(|| Some(openapi::path::ParameterIn::Query));
956
957 let new_params: Vec<Parameter> = params
958 .into_iter()
959 .map(|p| {
960 let schema = match p.schema {
961 Some(schema) => match schema {
962 openapi::RefOr::Ref(r) => SchemaRef::Ref {
963 reference: r.ref_location,
964 },
965 openapi::RefOr::T(s) => {
966 let value = serde_json::to_value(s).unwrap_or(serde_json::Value::Null);
967 SchemaRef::Inline(value)
968 }
969 },
970 None => SchemaRef::Inline(serde_json::Value::Null),
971 };
972
973 let required = match p.required {
974 openapi::Required::True => true,
975 openapi::Required::False => false,
976 };
977
978 Parameter {
979 name: p.name,
980 location: "query".to_string(), required,
982 description: p.description,
983 schema,
984 }
985 })
986 .collect();
987
988 if let Some(existing) = &mut op.parameters {
989 existing.extend(new_params);
990 } else {
991 op.parameters = Some(new_params);
992 }
993 }
994}
995
impl<T> OperationModifier for State<T> {
    /// No-op: this impl leaves the operation untouched.
    fn update_operation(_op: &mut Operation) {}
}
1000
1001impl OperationModifier for Body {
1003 fn update_operation(op: &mut Operation) {
1004 let mut content = HashMap::new();
1005 content.insert(
1006 "application/octet-stream".to_string(),
1007 MediaType {
1008 schema: SchemaRef::Inline(
1009 serde_json::json!({ "type": "string", "format": "binary" }),
1010 ),
1011 },
1012 );
1013
1014 op.request_body = Some(RequestBody {
1015 required: true,
1016 content,
1017 });
1018 }
1019}
1020
1021impl OperationModifier for BodyStream {
1023 fn update_operation(op: &mut Operation) {
1024 let mut content = HashMap::new();
1025 content.insert(
1026 "application/octet-stream".to_string(),
1027 MediaType {
1028 schema: SchemaRef::Inline(
1029 serde_json::json!({ "type": "string", "format": "binary" }),
1030 ),
1031 },
1032 );
1033
1034 op.request_body = Some(RequestBody {
1035 required: true,
1036 content,
1037 });
1038 }
1039}
1040
1041impl<T: for<'a> Schema<'a>> ResponseModifier for Json<T> {
1045 fn update_response(op: &mut Operation) {
1046 let (name, _) = T::schema();
1047
1048 let schema_ref = SchemaRef::Ref {
1049 reference: format!("#/components/schemas/{}", name),
1050 };
1051
1052 op.responses.insert(
1053 "200".to_string(),
1054 ResponseSpec {
1055 description: "Successful response".to_string(),
1056 content: {
1057 let mut map = HashMap::new();
1058 map.insert(
1059 "application/json".to_string(),
1060 MediaType { schema: schema_ref },
1061 );
1062 Some(map)
1063 },
1064 },
1065 );
1066 }
1067}
1068
1069#[cfg(test)]
1070mod tests {
1071 use super::*;
1072 use crate::path_params::PathParams;
1073 use bytes::Bytes;
1074 use http::{Extensions, Method};
1075 use proptest::prelude::*;
1076 use proptest::test_runner::TestCaseError;
1077 use std::sync::Arc;
1078
1079 fn create_test_request_with_headers(
1081 method: Method,
1082 path: &str,
1083 headers: Vec<(&str, &str)>,
1084 ) -> Request {
1085 let uri: http::Uri = path.parse().unwrap();
1086 let mut builder = http::Request::builder().method(method).uri(uri);
1087
1088 for (name, value) in headers {
1089 builder = builder.header(name, value);
1090 }
1091
1092 let req = builder.body(()).unwrap();
1093 let (parts, _) = req.into_parts();
1094
1095 Request::new(
1096 parts,
1097 crate::request::BodyVariant::Buffered(Bytes::new()),
1098 Arc::new(Extensions::new()),
1099 PathParams::new(),
1100 )
1101 }
1102
1103 fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
1105 method: Method,
1106 path: &str,
1107 extension: T,
1108 ) -> Request {
1109 let uri: http::Uri = path.parse().unwrap();
1110 let builder = http::Request::builder().method(method).uri(uri);
1111
1112 let req = builder.body(()).unwrap();
1113 let (mut parts, _) = req.into_parts();
1114 parts.extensions.insert(extension);
1115
1116 Request::new(
1117 parts,
1118 crate::request::BodyVariant::Buffered(Bytes::new()),
1119 Arc::new(Extensions::new()),
1120 PathParams::new(),
1121 )
1122 }
1123
    // Property: every header put on a request can be found, with its value,
    // through the `Headers` extractor.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_headers_extractor_completeness(
            // Up to nine (name, value) pairs generated from these regexes.
            headers in prop::collection::vec(
                (
                    "[a-z][a-z0-9-]{0,20}",
                    "[a-zA-Z0-9 ]{1,50}"
                ),
                0..10
            )
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // Borrowed view of the generated pairs for the request helper.
                let header_tuples: Vec<(&str, &str)> = headers
                    .iter()
                    .map(|(k, v)| (k.as_str(), v.as_str()))
                    .collect();

                let request = create_test_request_with_headers(
                    Method::GET,
                    "/test",
                    header_tuples.clone(),
                );

                let extracted = Headers::from_request_parts(&request)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;

                for (name, value) in &headers {
                    // The same name may have been generated more than once, so
                    // every value stored under it is considered.
                    let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
                    prop_assert!(
                        !all_values.is_empty(),
                        "Header '{}' not found",
                        name
                    );

                    let value_found = all_values.iter().any(|v| {
                        v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
                    });

                    prop_assert!(
                        value_found,
                        "Header '{}' value '{}' not found in extracted values",
                        name,
                        value
                    );
                }

                Ok(())
            })();
            result?;
        }
    }
1192
    // Property: `HeaderValue::extract` returns the header's value when the
    // header is present and an error when it is absent.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_header_value_extractor_correctness(
            header_name in "[a-z][a-z0-9-]{0,20}",
            header_value in "[a-zA-Z0-9 ]{1,50}",
            has_header in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let headers = if has_header {
                    vec![(header_name.as_str(), header_value.as_str())]
                } else {
                    vec![]
                };

                // NOTE(review): this request, built from the generated header
                // name, is never inspected afterwards. The assertions below
                // only use `request_with_known_header`.
                let _request = create_test_request_with_headers(Method::GET, "/test", headers);

                // `HeaderValue::extract` takes a `&'static str` name, so a
                // fixed literal is used for the request under test.
                let test_header = "x-test-header";
                let request_with_known_header = if has_header {
                    create_test_request_with_headers(
                        Method::GET,
                        "/test",
                        vec![(test_header, header_value.as_str())],
                    )
                } else {
                    create_test_request_with_headers(Method::GET, "/test", vec![])
                };

                let result = HeaderValue::extract(&request_with_known_header, test_header);

                if has_header {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.value(),
                        header_value.as_str(),
                        "Header value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when header is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }
1252
    // Property: `ClientIp::extract_with_config` uses the `x-forwarded-for`
    // address only when `trust_proxy` is set and the header is present;
    // otherwise it uses the socket address stored in the extensions.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_client_ip_extractor_with_forwarding(
            // Dotted-quad IPv4 text for the forwarded header.
            forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
            socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
            has_forwarded_header in prop::bool::ANY,
            trust_proxy in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let headers = if has_forwarded_header {
                    vec![("x-forwarded-for", forwarded_ip.as_str())]
                } else {
                    vec![]
                };

                // Built by hand (rather than through the helpers) because the
                // request needs both headers and an extension.
                let uri: http::Uri = "/test".parse().unwrap();
                let mut builder = http::Request::builder().method(Method::GET).uri(uri);
                for (name, value) in &headers {
                    builder = builder.header(*name, *value);
                }
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                // The socket address is what the extractor falls back to.
                let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
                parts.extensions.insert(socket_addr);

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let extracted = ClientIp::extract_with_config(&request, trust_proxy)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;

                if trust_proxy && has_forwarded_header {
                    let expected_ip: std::net::IpAddr = forwarded_ip.parse()
                        .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        expected_ip,
                        "Should use X-Forwarded-For IP when trust_proxy is enabled"
                    );
                } else {
                    prop_assert_eq!(
                        extracted.0,
                        socket_ip,
                        "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }
1325
    // Property: the `Extension` extractor returns the inserted value when an
    // extension of the requested type exists and an error when it does not.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_extension_extractor_retrieval(
            value in any::<i64>(),
            has_extension in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // Local type, so no other extension can share its type.
                #[derive(Clone, Debug, PartialEq)]
                struct TestExtension(i64);

                let uri: http::Uri = "/test".parse().unwrap();
                let builder = http::Request::builder().method(Method::GET).uri(uri);
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                if has_extension {
                    parts.extensions.insert(TestExtension(value));
                }

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let result = Extension::<TestExtension>::from_request_parts(&request);

                if has_extension {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        TestExtension(value),
                        "Extension value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when extension is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }
1383
1384 #[test]
1387 fn test_headers_extractor_basic() {
1388 let request = create_test_request_with_headers(
1389 Method::GET,
1390 "/test",
1391 vec![
1392 ("content-type", "application/json"),
1393 ("accept", "text/html"),
1394 ],
1395 );
1396
1397 let headers = Headers::from_request_parts(&request).unwrap();
1398
1399 assert!(headers.contains("content-type"));
1400 assert!(headers.contains("accept"));
1401 assert!(!headers.contains("x-custom"));
1402 assert_eq!(headers.len(), 2);
1403 }
1404
1405 #[test]
1406 fn test_header_value_extractor_present() {
1407 let request = create_test_request_with_headers(
1408 Method::GET,
1409 "/test",
1410 vec![("authorization", "Bearer token123")],
1411 );
1412
1413 let result = HeaderValue::extract(&request, "authorization");
1414 assert!(result.is_ok());
1415 assert_eq!(result.unwrap().value(), "Bearer token123");
1416 }
1417
1418 #[test]
1419 fn test_header_value_extractor_missing() {
1420 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1421
1422 let result = HeaderValue::extract(&request, "authorization");
1423 assert!(result.is_err());
1424 }
1425
1426 #[test]
1427 fn test_client_ip_from_forwarded_header() {
1428 let request = create_test_request_with_headers(
1429 Method::GET,
1430 "/test",
1431 vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
1432 );
1433
1434 let ip = ClientIp::extract_with_config(&request, true).unwrap();
1435 assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
1436 }
1437
1438 #[test]
1439 fn test_client_ip_ignores_forwarded_when_not_trusted() {
1440 let uri: http::Uri = "/test".parse().unwrap();
1441 let builder = http::Request::builder()
1442 .method(Method::GET)
1443 .uri(uri)
1444 .header("x-forwarded-for", "192.168.1.100");
1445 let req = builder.body(()).unwrap();
1446 let (mut parts, _) = req.into_parts();
1447
1448 let socket_addr = std::net::SocketAddr::new(
1449 std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
1450 8080,
1451 );
1452 parts.extensions.insert(socket_addr);
1453
1454 let request = Request::new(
1455 parts,
1456 crate::request::BodyVariant::Buffered(Bytes::new()),
1457 Arc::new(Extensions::new()),
1458 PathParams::new(),
1459 );
1460
1461 let ip = ClientIp::extract_with_config(&request, false).unwrap();
1462 assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
1463 }
1464
1465 #[test]
1466 fn test_extension_extractor_present() {
1467 #[derive(Clone, Debug, PartialEq)]
1468 struct MyData(String);
1469
1470 let request =
1471 create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));
1472
1473 let result = Extension::<MyData>::from_request_parts(&request);
1474 assert!(result.is_ok());
1475 assert_eq!(result.unwrap().0, MyData("hello".to_string()));
1476 }
1477
1478 #[test]
1479 fn test_extension_extractor_missing() {
1480 #[derive(Clone, Debug)]
1481 #[allow(dead_code)]
1482 struct MyData(String);
1483
1484 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1485
1486 let result = Extension::<MyData>::from_request_parts(&request);
1487 assert!(result.is_err());
1488 }
1489
1490 #[cfg(feature = "cookies")]
1492 mod cookies_tests {
1493 use super::*;
1494
        // Property: every cookie placed in the `cookie` header is found by
        // the `Cookies` extractor with its value, and no extra cookies
        // appear.
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            #[test]
            fn prop_cookies_extractor_parsing(
                // Up to four (name, value) pairs generated from these regexes.
                cookies in prop::collection::vec(
                    (
                        "[a-zA-Z][a-zA-Z0-9_]{0,15}",
                        "[a-zA-Z0-9]{1,30}"
                    ),
                    0..5
                )
            ) {
                let result: Result<(), TestCaseError> = (|| {
                    // Render the pairs as `name=value; name=value`.
                    let cookie_header = cookies
                        .iter()
                        .map(|(name, value)| format!("{}={}", name, value))
                        .collect::<Vec<_>>()
                        .join("; ");

                    let headers = if !cookies.is_empty() {
                        vec![("cookie", cookie_header.as_str())]
                    } else {
                        vec![]
                    };

                    let request = create_test_request_with_headers(Method::GET, "/test", headers);

                    let extracted = Cookies::from_request_parts(&request)
                        .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;

                    // A name generated more than once keeps only its last
                    // value here, because later inserts overwrite earlier ones.
                    let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
                    for (name, value) in &cookies {
                        expected_cookies.insert(name.as_str(), value.as_str());
                    }

                    for (name, expected_value) in &expected_cookies {
                        let cookie = extracted.get(name)
                            .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;

                        prop_assert_eq!(
                            cookie.value(),
                            *expected_value,
                            "Cookie '{}' value mismatch",
                            name
                        );
                    }

                    // The jar must hold exactly one cookie per unique name.
                    let extracted_count = extracted.iter().count();
                    prop_assert_eq!(
                        extracted_count,
                        expected_cookies.len(),
                        "Expected {} unique cookies, got {}",
                        expected_cookies.len(),
                        extracted_count
                    );

                    Ok(())
                })();
                result?;
            }
        }
1571
1572 #[test]
1573 fn test_cookies_extractor_basic() {
1574 let request = create_test_request_with_headers(
1575 Method::GET,
1576 "/test",
1577 vec![("cookie", "session=abc123; user=john")],
1578 );
1579
1580 let cookies = Cookies::from_request_parts(&request).unwrap();
1581
1582 assert!(cookies.contains("session"));
1583 assert!(cookies.contains("user"));
1584 assert!(!cookies.contains("other"));
1585
1586 assert_eq!(cookies.get("session").unwrap().value(), "abc123");
1587 assert_eq!(cookies.get("user").unwrap().value(), "john");
1588 }
1589
1590 #[test]
1591 fn test_cookies_extractor_empty() {
1592 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1593
1594 let cookies = Cookies::from_request_parts(&request).unwrap();
1595 assert_eq!(cookies.iter().count(), 0);
1596 }
1597
1598 #[test]
1599 fn test_cookies_extractor_single() {
1600 let request = create_test_request_with_headers(
1601 Method::GET,
1602 "/test",
1603 vec![("cookie", "token=xyz789")],
1604 );
1605
1606 let cookies = Cookies::from_request_parts(&request).unwrap();
1607 assert_eq!(cookies.iter().count(), 1);
1608 assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
1609 }
1610 }
1611}