1use crate::error::{ApiError, Result};
58use crate::json;
59use crate::request::Request;
60use crate::response::IntoResponse;
61use crate::stream::{StreamingBody, StreamingConfig};
62use crate::validation::Validatable;
63use bytes::Bytes;
64use http::{header, StatusCode};
65use rustapi_validate::v2::{AsyncValidate, ValidationContext};
66
67use serde::de::DeserializeOwned;
68use serde::Serialize;
69use std::future::Future;
70use std::ops::{Deref, DerefMut};
71use std::str::FromStr;
72
/// Extracts a value using only shared (`&Request`) access to the request.
///
/// Intended for extractors that read headers, the URI/query string, path
/// parameters, state or extensions. Every implementor is also usable as a
/// [`FromRequest`] through the blanket implementation in this module.
pub trait FromRequestParts: Sized {
    /// Builds `Self` from `req`.
    ///
    /// # Errors
    /// Returns an `ApiError` describing why extraction failed.
    fn from_request_parts(req: &Request) -> Result<Self>;
}
101
/// Asynchronously extracts a value from the request, with mutable access.
///
/// Mutable access lets implementors load and take the request body (as
/// `Json`, `Body` and `BodyStream` in this module do). The returned future
/// must be `Send`.
pub trait FromRequest: Sized {
    /// Builds `Self` from `req`.
    ///
    /// # Errors
    /// Returns an `ApiError` describing why extraction failed.
    fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
}
135
136impl<T: FromRequestParts> FromRequest for T {
138 async fn from_request(req: &mut Request) -> Result<Self> {
139 T::from_request_parts(req)
140 }
141}
142
143#[derive(Debug, Clone, Copy, Default)]
162pub struct Json<T>(pub T);
163
164impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
165 async fn from_request(req: &mut Request) -> Result<Self> {
166 req.load_body().await?;
167 let body = req
168 .take_body()
169 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
170
171 let value: T = json::from_slice(&body)?;
173 Ok(Json(value))
174 }
175}
176
177impl<T> Deref for Json<T> {
178 type Target = T;
179
180 fn deref(&self) -> &Self::Target {
181 &self.0
182 }
183}
184
185impl<T> DerefMut for Json<T> {
186 fn deref_mut(&mut self) -> &mut Self::Target {
187 &mut self.0
188 }
189}
190
191impl<T> From<T> for Json<T> {
192 fn from(value: T) -> Self {
193 Json(value)
194 }
195}
196
197const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256;
200
201impl<T: Serialize> IntoResponse for Json<T> {
203 fn into_response(self) -> crate::response::Response {
204 match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
206 Ok(body) => http::Response::builder()
207 .status(StatusCode::OK)
208 .header(header::CONTENT_TYPE, "application/json")
209 .body(crate::response::Body::from(body))
210 .unwrap(),
211 Err(err) => {
212 ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
213 }
214 }
215 }
216}
217
/// JSON body extractor that also runs synchronous validation (`Validatable`)
/// on the deserialized value.
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);

impl<T> ValidatedJson<T> {
    /// Wraps `value` (no validation is performed here).
    pub fn new(value: T) -> Self {
        ValidatedJson(value)
    }

    /// Unwraps and returns the inner value.
    pub fn into_inner(self) -> T {
        let ValidatedJson(inner) = self;
        inner
    }
}
257
258impl<T: DeserializeOwned + Validatable + Send> FromRequest for ValidatedJson<T> {
259 async fn from_request(req: &mut Request) -> Result<Self> {
260 req.load_body().await?;
261 let body = req
263 .take_body()
264 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
265
266 let value: T = json::from_slice(&body)?;
267
268 if let Err(e) = value.do_validate() {
270 return Err(e);
271 }
272
273 Ok(ValidatedJson(value))
274 }
275}
276
277impl<T> Deref for ValidatedJson<T> {
278 type Target = T;
279
280 fn deref(&self) -> &Self::Target {
281 &self.0
282 }
283}
284
285impl<T> DerefMut for ValidatedJson<T> {
286 fn deref_mut(&mut self) -> &mut Self::Target {
287 &mut self.0
288 }
289}
290
291impl<T> From<T> for ValidatedJson<T> {
292 fn from(value: T) -> Self {
293 ValidatedJson(value)
294 }
295}
296
297impl<T: Serialize> IntoResponse for ValidatedJson<T> {
298 fn into_response(self) -> crate::response::Response {
299 Json(self.0).into_response()
300 }
301}
302
/// JSON body extractor whose value is checked with the async validation
/// pipeline (`AsyncValidate::validate_full`).
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncValidatedJson<T>(pub T);

impl<T> AsyncValidatedJson<T> {
    /// Wraps `value` (no validation is performed here).
    pub fn new(value: T) -> Self {
        AsyncValidatedJson(value)
    }

    /// Unwraps and returns the inner value.
    pub fn into_inner(self) -> T {
        let AsyncValidatedJson(inner) = self;
        inner
    }
}

impl<T> Deref for AsyncValidatedJson<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &T {
        let AsyncValidatedJson(inner) = self;
        inner
    }
}

impl<T> DerefMut for AsyncValidatedJson<T> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut T {
        let AsyncValidatedJson(inner) = self;
        inner
    }
}

impl<T> From<T> for AsyncValidatedJson<T> {
    /// Wraps `value` without validating it.
    fn from(value: T) -> Self {
        Self(value)
    }
}
363
364impl<T: Serialize> IntoResponse for AsyncValidatedJson<T> {
365 fn into_response(self) -> crate::response::Response {
366 Json(self.0).into_response()
367 }
368}
369
370impl<T: DeserializeOwned + AsyncValidate + Send + Sync> FromRequest for AsyncValidatedJson<T> {
371 async fn from_request(req: &mut Request) -> Result<Self> {
372 req.load_body().await?;
373
374 let body = req
375 .take_body()
376 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
377
378 let value: T = json::from_slice(&body)?;
379
380 let ctx = ValidationContext::default();
383
384 if let Err(errors) = value.validate_full(&ctx).await {
386 let field_errors: Vec<crate::error::FieldError> = errors
388 .fields
389 .iter()
390 .flat_map(|(field, errs)| {
391 let field_name = field.to_string();
392 errs.iter().map(move |e| crate::error::FieldError {
393 field: field_name.clone(),
394 code: e.code.to_string(),
395 message: e.message.clone(),
396 })
397 })
398 .collect();
399
400 return Err(ApiError::validation(field_errors));
401 }
402
403 Ok(AsyncValidatedJson(value))
404 }
405}
406
407#[derive(Debug, Clone)]
425pub struct Query<T>(pub T);
426
427impl<T: DeserializeOwned> FromRequestParts for Query<T> {
428 fn from_request_parts(req: &Request) -> Result<Self> {
429 let query = req.query_string().unwrap_or("");
430 let value: T = serde_urlencoded::from_str(query)
431 .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
432 Ok(Query(value))
433 }
434}
435
436impl<T> Deref for Query<T> {
437 type Target = T;
438
439 fn deref(&self) -> &Self::Target {
440 &self.0
441 }
442}
443
444#[derive(Debug, Clone)]
466pub struct Path<T>(pub T);
467
468impl<T: FromStr> FromRequestParts for Path<T>
469where
470 T::Err: std::fmt::Display,
471{
472 fn from_request_parts(req: &Request) -> Result<Self> {
473 let params = req.path_params();
474
475 if let Some((_, value)) = params.iter().next() {
477 let parsed = value
478 .parse::<T>()
479 .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
480 return Ok(Path(parsed));
481 }
482
483 Err(ApiError::internal("Missing path parameter"))
484 }
485}
486
487impl<T> Deref for Path<T> {
488 type Target = T;
489
490 fn deref(&self) -> &Self::Target {
491 &self.0
492 }
493}
494
495#[derive(Debug, Clone)]
515pub struct Typed<T>(pub T);
516
517impl<T: DeserializeOwned + Send> FromRequestParts for Typed<T> {
518 fn from_request_parts(req: &Request) -> Result<Self> {
519 let params = req.path_params();
520 let mut map = serde_json::Map::new();
521 for (k, v) in params.iter() {
522 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
523 }
524 let value = serde_json::Value::Object(map);
525 let parsed: T = serde_json::from_value(value)
526 .map_err(|e| ApiError::bad_request(format!("Invalid path parameters: {}", e)))?;
527 Ok(Typed(parsed))
528 }
529}
530
531impl<T> Deref for Typed<T> {
532 type Target = T;
533
534 fn deref(&self) -> &Self::Target {
535 &self.0
536 }
537}
538
539#[derive(Debug, Clone)]
556pub struct State<T>(pub T);
557
558impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
559 fn from_request_parts(req: &Request) -> Result<Self> {
560 req.state().get::<T>().cloned().map(State).ok_or_else(|| {
561 ApiError::internal(format!(
562 "State of type `{}` not found. Did you forget to call .state()?",
563 std::any::type_name::<T>()
564 ))
565 })
566 }
567}
568
569impl<T> Deref for State<T> {
570 type Target = T;
571
572 fn deref(&self) -> &Self::Target {
573 &self.0
574 }
575}
576
577#[derive(Debug, Clone)]
579pub struct Body(pub Bytes);
580
581impl FromRequest for Body {
582 async fn from_request(req: &mut Request) -> Result<Self> {
583 req.load_body().await?;
584 let body = req
585 .take_body()
586 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
587 Ok(Body(body))
588 }
589}
590
591impl Deref for Body {
592 type Target = Bytes;
593
594 fn deref(&self) -> &Self::Target {
595 &self.0
596 }
597}
598
599pub struct BodyStream(pub StreamingBody);
601
602impl FromRequest for BodyStream {
603 async fn from_request(req: &mut Request) -> Result<Self> {
604 let config = StreamingConfig::default();
605
606 if let Some(stream) = req.take_stream() {
607 Ok(BodyStream(StreamingBody::new(stream, config.max_body_size)))
608 } else if let Some(bytes) = req.take_body() {
609 let stream = futures_util::stream::once(async move { Ok(bytes) });
611 Ok(BodyStream(StreamingBody::from_stream(
612 stream,
613 config.max_body_size,
614 )))
615 } else {
616 Err(ApiError::internal("Body already consumed"))
617 }
618 }
619}
620
621impl Deref for BodyStream {
622 type Target = StreamingBody;
623
624 fn deref(&self) -> &Self::Target {
625 &self.0
626 }
627}
628
629impl DerefMut for BodyStream {
630 fn deref_mut(&mut self) -> &mut Self::Target {
631 &mut self.0
632 }
633}
634
635impl futures_util::Stream for BodyStream {
637 type Item = Result<Bytes, ApiError>;
638
639 fn poll_next(
640 mut self: std::pin::Pin<&mut Self>,
641 cx: &mut std::task::Context<'_>,
642 ) -> std::task::Poll<Option<Self::Item>> {
643 std::pin::Pin::new(&mut self.0).poll_next(cx)
644 }
645}
646
647impl<T: FromRequestParts> FromRequestParts for Option<T> {
651 fn from_request_parts(req: &Request) -> Result<Self> {
652 Ok(T::from_request_parts(req).ok())
653 }
654}
655
656#[derive(Debug, Clone)]
674pub struct Headers(pub http::HeaderMap);
675
676impl Headers {
677 pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
679 self.0.get(name)
680 }
681
682 pub fn contains(&self, name: &str) -> bool {
684 self.0.contains_key(name)
685 }
686
687 pub fn len(&self) -> usize {
689 self.0.len()
690 }
691
692 pub fn is_empty(&self) -> bool {
694 self.0.is_empty()
695 }
696
697 pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
699 self.0.iter()
700 }
701}
702
703impl FromRequestParts for Headers {
704 fn from_request_parts(req: &Request) -> Result<Self> {
705 Ok(Headers(req.headers().clone()))
706 }
707}
708
709impl Deref for Headers {
710 type Target = http::HeaderMap;
711
712 fn deref(&self) -> &Self::Target {
713 &self.0
714 }
715}
716
717#[derive(Debug, Clone)]
736pub struct HeaderValue(pub String, pub &'static str);
737
738impl HeaderValue {
739 pub fn new(name: &'static str, value: String) -> Self {
741 Self(value, name)
742 }
743
744 pub fn value(&self) -> &str {
746 &self.0
747 }
748
749 pub fn name(&self) -> &'static str {
751 self.1
752 }
753
754 pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
756 req.headers()
757 .get(name)
758 .and_then(|v| v.to_str().ok())
759 .map(|s| HeaderValue(s.to_string(), name))
760 .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
761 }
762}
763
764impl Deref for HeaderValue {
765 type Target = String;
766
767 fn deref(&self) -> &Self::Target {
768 &self.0
769 }
770}
771
772#[derive(Debug, Clone)]
790pub struct Extension<T>(pub T);
791
792impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
793 fn from_request_parts(req: &Request) -> Result<Self> {
794 req.extensions()
795 .get::<T>()
796 .cloned()
797 .map(Extension)
798 .ok_or_else(|| {
799 ApiError::internal(format!(
800 "Extension of type `{}` not found. Did middleware insert it?",
801 std::any::type_name::<T>()
802 ))
803 })
804 }
805}
806
807impl<T> Deref for Extension<T> {
808 type Target = T;
809
810 fn deref(&self) -> &Self::Target {
811 &self.0
812 }
813}
814
815impl<T> DerefMut for Extension<T> {
816 fn deref_mut(&mut self) -> &mut Self::Target {
817 &mut self.0
818 }
819}
820
821#[derive(Debug, Clone)]
836pub struct ClientIp(pub std::net::IpAddr);
837
838impl ClientIp {
839 pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
841 if trust_proxy {
842 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
844 if let Ok(forwarded_str) = forwarded.to_str() {
845 if let Some(first_ip) = forwarded_str.split(',').next() {
847 if let Ok(ip) = first_ip.trim().parse() {
848 return Ok(ClientIp(ip));
849 }
850 }
851 }
852 }
853 }
854
855 if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
857 return Ok(ClientIp(addr.ip()));
858 }
859
860 Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
862 127, 0, 0, 1,
863 ))))
864 }
865}
866
867impl FromRequestParts for ClientIp {
868 fn from_request_parts(req: &Request) -> Result<Self> {
869 Self::extract_with_config(req, true)
871 }
872}
873
/// Cookies sent with the request, parsed into a `cookie::CookieJar`
/// (requires the `cookies` feature).
#[cfg(feature = "cookies")]
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);

#[cfg(feature = "cookies")]
impl Cookies {
    /// Looks up a cookie by name.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.get(name)
    }

    /// Iterates over every cookie in the jar.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.iter()
    }

    /// Whether a cookie named `name` is present.
    pub fn contains(&self, name: &str) -> bool {
        self.get(name).is_some()
    }
}
912
#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    /// Parses every `Cookie` header field into a jar. Never fails.
    ///
    /// HTTP/2 clients may split the cookie list across several `Cookie`
    /// header fields (RFC 9113 §8.2.3), so all fields are read, not just the
    /// first. Header values that are not visible ASCII, and pairs that fail
    /// to parse, are skipped (best effort). If a name repeats, the cookie
    /// seen last wins, because `add_original` replaces earlier entries.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();

        for field in req.headers().get_all(header::COOKIE) {
            let Ok(cookie_str) = field.to_str() else {
                continue;
            };

            for pair in cookie_str
                .split(';')
                .map(str::trim)
                .filter(|pair| !pair.is_empty())
            {
                if let Ok(cookie) = cookie::Cookie::parse(pair.to_string()) {
                    jar.add_original(cookie.into_owned());
                }
            }
        }

        Ok(Cookies(jar))
    }
}
935
#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    /// Exposes the full `CookieJar` API.
    fn deref(&self) -> &cookie::CookieJar {
        let Cookies(jar) = self;
        jar
    }
}
944
945macro_rules! impl_from_request_parts_for_primitives {
947 ($($ty:ty),*) => {
948 $(
949 impl FromRequestParts for $ty {
950 fn from_request_parts(req: &Request) -> Result<Self> {
951 let Path(value) = Path::<$ty>::from_request_parts(req)?;
952 Ok(value)
953 }
954 }
955 )*
956 };
957}
958
959impl_from_request_parts_for_primitives!(
960 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
961);
962
963use rustapi_openapi::utoipa_types::openapi;
966use rustapi_openapi::{
967 IntoParams, MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier,
968 ResponseSpec, Schema, SchemaRef,
969};
970use std::collections::HashMap;
971
972impl<T: for<'a> Schema<'a>> OperationModifier for ValidatedJson<T> {
974 fn update_operation(op: &mut Operation) {
975 let (name, _) = T::schema();
976
977 let schema_ref = SchemaRef::Ref {
978 reference: format!("#/components/schemas/{}", name),
979 };
980
981 let mut content = HashMap::new();
982 content.insert(
983 "application/json".to_string(),
984 MediaType { schema: schema_ref },
985 );
986
987 op.request_body = Some(RequestBody {
988 required: true,
989 content,
990 });
991
992 op.responses.insert(
994 "422".to_string(),
995 ResponseSpec {
996 description: "Validation Error".to_string(),
997 content: {
998 let mut map = HashMap::new();
999 map.insert(
1000 "application/json".to_string(),
1001 MediaType {
1002 schema: SchemaRef::Ref {
1003 reference: "#/components/schemas/ValidationErrorSchema".to_string(),
1004 },
1005 },
1006 );
1007 Some(map)
1008 },
1009 },
1010 );
1011 }
1012}
1013
1014impl<T: for<'a> Schema<'a>> OperationModifier for Json<T> {
1016 fn update_operation(op: &mut Operation) {
1017 let (name, _) = T::schema();
1018
1019 let schema_ref = SchemaRef::Ref {
1020 reference: format!("#/components/schemas/{}", name),
1021 };
1022
1023 let mut content = HashMap::new();
1024 content.insert(
1025 "application/json".to_string(),
1026 MediaType { schema: schema_ref },
1027 );
1028
1029 op.request_body = Some(RequestBody {
1030 required: true,
1031 content,
1032 });
1033 }
1034}
1035
1036impl<T> OperationModifier for Path<T> {
1040 fn update_operation(_op: &mut Operation) {
1041 }
1048}
1049
1050impl<T> OperationModifier for Typed<T> {
1052 fn update_operation(_op: &mut Operation) {
1053 }
1055}
1056
1057impl<T: IntoParams> OperationModifier for Query<T> {
1059 fn update_operation(op: &mut Operation) {
1060 let params = T::into_params(|| Some(openapi::path::ParameterIn::Query));
1061
1062 let new_params: Vec<Parameter> = params
1063 .into_iter()
1064 .map(|p| {
1065 let schema = match p.schema {
1066 Some(schema) => match schema {
1067 openapi::RefOr::Ref(r) => SchemaRef::Ref {
1068 reference: r.ref_location,
1069 },
1070 openapi::RefOr::T(s) => {
1071 let value = serde_json::to_value(s).unwrap_or(serde_json::Value::Null);
1072 SchemaRef::Inline(value)
1073 }
1074 },
1075 None => SchemaRef::Inline(serde_json::Value::Null),
1076 };
1077
1078 let required = match p.required {
1079 openapi::Required::True => true,
1080 openapi::Required::False => false,
1081 };
1082
1083 Parameter {
1084 name: p.name,
1085 location: "query".to_string(), required,
1087 description: p.description,
1088 schema,
1089 }
1090 })
1091 .collect();
1092
1093 if let Some(existing) = &mut op.parameters {
1094 existing.extend(new_params);
1095 } else {
1096 op.parameters = Some(new_params);
1097 }
1098 }
1099}
1100
1101impl<T> OperationModifier for State<T> {
1103 fn update_operation(_op: &mut Operation) {}
1104}
1105
1106impl OperationModifier for Body {
1108 fn update_operation(op: &mut Operation) {
1109 let mut content = HashMap::new();
1110 content.insert(
1111 "application/octet-stream".to_string(),
1112 MediaType {
1113 schema: SchemaRef::Inline(
1114 serde_json::json!({ "type": "string", "format": "binary" }),
1115 ),
1116 },
1117 );
1118
1119 op.request_body = Some(RequestBody {
1120 required: true,
1121 content,
1122 });
1123 }
1124}
1125
1126impl OperationModifier for BodyStream {
1128 fn update_operation(op: &mut Operation) {
1129 let mut content = HashMap::new();
1130 content.insert(
1131 "application/octet-stream".to_string(),
1132 MediaType {
1133 schema: SchemaRef::Inline(
1134 serde_json::json!({ "type": "string", "format": "binary" }),
1135 ),
1136 },
1137 );
1138
1139 op.request_body = Some(RequestBody {
1140 required: true,
1141 content,
1142 });
1143 }
1144}
1145
1146impl<T: for<'a> Schema<'a>> ResponseModifier for Json<T> {
1150 fn update_response(op: &mut Operation) {
1151 let (name, _) = T::schema();
1152
1153 let schema_ref = SchemaRef::Ref {
1154 reference: format!("#/components/schemas/{}", name),
1155 };
1156
1157 op.responses.insert(
1158 "200".to_string(),
1159 ResponseSpec {
1160 description: "Successful response".to_string(),
1161 content: {
1162 let mut map = HashMap::new();
1163 map.insert(
1164 "application/json".to_string(),
1165 MediaType { schema: schema_ref },
1166 );
1167 Some(map)
1168 },
1169 },
1170 );
1171 }
1172}
1173
// Unit and property tests for the parts extractors
// (Headers, HeaderValue, ClientIp, Extension, Cookies).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::path_params::PathParams;
    use bytes::Bytes;
    use http::{Extensions, Method};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::sync::Arc;

    /// Builds a `Request` with the given method, path and headers, an empty
    /// buffered body, a fresh empty `Extensions` as the third `Request::new`
    /// argument, and no path parameters.
    fn create_test_request_with_headers(
        method: Method,
        path: &str,
        headers: Vec<(&str, &str)>,
    ) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let mut builder = http::Request::builder().method(method).uri(uri);

        for (name, value) in headers {
            builder = builder.header(name, value);
        }

        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    /// Like `create_test_request_with_headers` (no headers), but inserts
    /// `extension` into the request parts' extensions.
    fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
        method: Method,
        path: &str,
        extension: T,
    ) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        let req = builder.body(()).unwrap();
        let (mut parts, _) = req.into_parts();
        parts.extensions.insert(extension);

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    proptest! {
        // Property: every header placed on the request (including repeated
        // names) is visible through the `Headers` extractor.
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_headers_extractor_completeness(
            headers in prop::collection::vec(
                (
                    "[a-z][a-z0-9-]{0,20}", // header name
                    "[a-zA-Z0-9 ]{1,50}" // header value
                ),
                0..10
            )
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let header_tuples: Vec<(&str, &str)> = headers
                    .iter()
                    .map(|(k, v)| (k.as_str(), v.as_str()))
                    .collect();

                let request = create_test_request_with_headers(
                    Method::GET,
                    "/test",
                    header_tuples.clone(),
                );

                let extracted = Headers::from_request_parts(&request)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;

                for (name, value) in &headers {
                    // `get_all` (via Deref to HeaderMap) so duplicate names are all checked.
                    let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
                    prop_assert!(
                        !all_values.is_empty(),
                        "Header '{}' not found",
                        name
                    );

                    let value_found = all_values.iter().any(|v| {
                        v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
                    });

                    prop_assert!(
                        value_found,
                        "Header '{}' value '{}' not found in extracted values",
                        name,
                        value
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    proptest! {
        // Property: `HeaderValue::extract` returns the exact value when the
        // header is present and an error when it is absent.
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_header_value_extractor_correctness(
            header_name in "[a-z][a-z0-9-]{0,20}",
            header_value in "[a-zA-Z0-9 ]{1,50}",
            has_header in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let headers = if has_header {
                    vec![(header_name.as_str(), header_value.as_str())]
                } else {
                    vec![]
                };

                // NOTE(review): `_request` (and therefore `header_name`) is never
                // used below; extraction is exercised with the fixed
                // `x-test-header` name because `HeaderValue::extract` needs a
                // `&'static str`.
                let _request = create_test_request_with_headers(Method::GET, "/test", headers);

                let test_header = "x-test-header";
                let request_with_known_header = if has_header {
                    create_test_request_with_headers(
                        Method::GET,
                        "/test",
                        vec![(test_header, header_value.as_str())],
                    )
                } else {
                    create_test_request_with_headers(Method::GET, "/test", vec![])
                };

                let result = HeaderValue::extract(&request_with_known_header, test_header);

                if has_header {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.value(),
                        header_value.as_str(),
                        "Header value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when header is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    proptest! {
        // Property: with `trust_proxy` and an `X-Forwarded-For` header the
        // forwarded IP wins; otherwise the socket address from the extensions is used.
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_client_ip_extractor_with_forwarding(
            forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
            socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
            has_forwarded_header in prop::bool::ANY,
            trust_proxy in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let headers = if has_forwarded_header {
                    vec![("x-forwarded-for", forwarded_ip.as_str())]
                } else {
                    vec![]
                };

                // Built by hand (not via the helper) so a SocketAddr can be
                // inserted into the parts' extensions.
                let uri: http::Uri = "/test".parse().unwrap();
                let mut builder = http::Request::builder().method(Method::GET).uri(uri);
                for (name, value) in &headers {
                    builder = builder.header(*name, *value);
                }
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
                parts.extensions.insert(socket_addr);

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let extracted = ClientIp::extract_with_config(&request, trust_proxy)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;

                if trust_proxy && has_forwarded_header {
                    let expected_ip: std::net::IpAddr = forwarded_ip.parse()
                        .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        expected_ip,
                        "Should use X-Forwarded-For IP when trust_proxy is enabled"
                    );
                } else {
                    prop_assert_eq!(
                        extracted.0,
                        socket_ip,
                        "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    proptest! {
        // Property: `Extension<T>` returns the inserted value, and errors when
        // no value of type T was inserted.
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_extension_extractor_retrieval(
            value in any::<i64>(),
            has_extension in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                #[derive(Clone, Debug, PartialEq)]
                struct TestExtension(i64);

                let uri: http::Uri = "/test".parse().unwrap();
                let builder = http::Request::builder().method(Method::GET).uri(uri);
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                if has_extension {
                    parts.extensions.insert(TestExtension(value));
                }

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let result = Extension::<TestExtension>::from_request_parts(&request);

                if has_extension {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        TestExtension(value),
                        "Extension value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when extension is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    // `Headers` exposes exactly the headers that were set.
    #[test]
    fn test_headers_extractor_basic() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![
                ("content-type", "application/json"),
                ("accept", "text/html"),
            ],
        );

        let headers = Headers::from_request_parts(&request).unwrap();

        assert!(headers.contains("content-type"));
        assert!(headers.contains("accept"));
        assert!(!headers.contains("x-custom"));
        assert_eq!(headers.len(), 2);
    }

    #[test]
    fn test_header_value_extractor_present() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("authorization", "Bearer token123")],
        );

        let result = HeaderValue::extract(&request, "authorization");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value(), "Bearer token123");
    }

    #[test]
    fn test_header_value_extractor_missing() {
        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = HeaderValue::extract(&request, "authorization");
        assert!(result.is_err());
    }

    // With several X-Forwarded-For entries, the left-most one is used.
    #[test]
    fn test_client_ip_from_forwarded_header() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
        );

        let ip = ClientIp::extract_with_config(&request, true).unwrap();
        assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_client_ip_ignores_forwarded_when_not_trusted() {
        let uri: http::Uri = "/test".parse().unwrap();
        let builder = http::Request::builder()
            .method(Method::GET)
            .uri(uri)
            .header("x-forwarded-for", "192.168.1.100");
        let req = builder.body(()).unwrap();
        let (mut parts, _) = req.into_parts();

        let socket_addr = std::net::SocketAddr::new(
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
            8080,
        );
        parts.extensions.insert(socket_addr);

        let request = Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        );

        let ip = ClientIp::extract_with_config(&request, false).unwrap();
        assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_extension_extractor_present() {
        #[derive(Clone, Debug, PartialEq)]
        struct MyData(String);

        let request =
            create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));

        let result = Extension::<MyData>::from_request_parts(&request);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().0, MyData("hello".to_string()));
    }

    #[test]
    fn test_extension_extractor_missing() {
        #[derive(Clone, Debug)]
        #[allow(dead_code)] // the field is never read; the type only serves as a lookup key
        struct MyData(String);

        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = Extension::<MyData>::from_request_parts(&request);
        assert!(result.is_err());
    }

    // Cookie extractor tests, compiled only with the `cookies` feature.
    #[cfg(feature = "cookies")]
    mod cookies_tests {
        use super::*;

        proptest! {
            // Property: every `name=value` pair in a single Cookie header is
            // extracted; for repeated names the last value is kept (matching
            // the HashMap used to build the expectation).
            #![proptest_config(ProptestConfig::with_cases(100))]

            #[test]
            fn prop_cookies_extractor_parsing(
                cookies in prop::collection::vec(
                    (
                        "[a-zA-Z][a-zA-Z0-9_]{0,15}", // cookie name
                        "[a-zA-Z0-9]{1,30}" // cookie value
                    ),
                    0..5
                )
            ) {
                let result: Result<(), TestCaseError> = (|| {
                    let cookie_header = cookies
                        .iter()
                        .map(|(name, value)| format!("{}={}", name, value))
                        .collect::<Vec<_>>()
                        .join("; ");

                    let headers = if !cookies.is_empty() {
                        vec![("cookie", cookie_header.as_str())]
                    } else {
                        vec![]
                    };

                    let request = create_test_request_with_headers(Method::GET, "/test", headers);

                    let extracted = Cookies::from_request_parts(&request)
                        .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;

                    let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
                    for (name, value) in &cookies {
                        expected_cookies.insert(name.as_str(), value.as_str());
                    }

                    for (name, expected_value) in &expected_cookies {
                        let cookie = extracted.get(name)
                            .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;

                        prop_assert_eq!(
                            cookie.value(),
                            *expected_value,
                            "Cookie '{}' value mismatch",
                            name
                        );
                    }

                    // No extra cookies beyond the unique names sent.
                    let extracted_count = extracted.iter().count();
                    prop_assert_eq!(
                        extracted_count,
                        expected_cookies.len(),
                        "Expected {} unique cookies, got {}",
                        expected_cookies.len(),
                        extracted_count
                    );

                    Ok(())
                })();
                result?;
            }
        }

        #[test]
        fn test_cookies_extractor_basic() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "session=abc123; user=john")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();

            assert!(cookies.contains("session"));
            assert!(cookies.contains("user"));
            assert!(!cookies.contains("other"));

            assert_eq!(cookies.get("session").unwrap().value(), "abc123");
            assert_eq!(cookies.get("user").unwrap().value(), "john");
        }

        #[test]
        fn test_cookies_extractor_empty() {
            let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 0);
        }

        #[test]
        fn test_cookies_extractor_single() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "token=xyz789")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 1);
            assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
        }
    }
}