1use crate::error::{ApiError, Result};
58use crate::json;
59use crate::request::Request;
60use crate::response::IntoResponse;
61use crate::stream::{StreamingBody, StreamingConfig};
62use crate::validation::Validatable;
63use bytes::Bytes;
64use http::{header, StatusCode};
65use rustapi_validate::v2::{AsyncValidate, ValidationContext};
66
67use serde::de::DeserializeOwned;
68use serde::Serialize;
69use std::future::Future;
70use std::ops::{Deref, DerefMut};
71use std::str::FromStr;
72
73pub trait FromRequestParts: Sized {
98 fn from_request_parts(req: &Request) -> Result<Self>;
100}
101
102pub trait FromRequest: Sized {
132 fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
134}
135
136impl<T: FromRequestParts> FromRequest for T {
138 async fn from_request(req: &mut Request) -> Result<Self> {
139 T::from_request_parts(req)
140 }
141}
142
/// Newtype around `T` used in this module both as a JSON request-body
/// extractor and as a JSON response.
///
/// The inner value is public and can be reached with `.0`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);
163
164impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
165 async fn from_request(req: &mut Request) -> Result<Self> {
166 req.load_body().await?;
167 let body = req
168 .take_body()
169 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
170
171 let value: T = json::from_slice(&body)?;
173 Ok(Json(value))
174 }
175}
176
177impl<T> Deref for Json<T> {
178 type Target = T;
179
180 fn deref(&self) -> &Self::Target {
181 &self.0
182 }
183}
184
185impl<T> DerefMut for Json<T> {
186 fn deref_mut(&mut self) -> &mut Self::Target {
187 &mut self.0
188 }
189}
190
191impl<T> From<T> for Json<T> {
192 fn from(value: T) -> Self {
193 Json(value)
194 }
195}
196
/// Capacity hint handed to `json::to_vec_with_capacity` when a [`Json`]
/// value is serialized into a response body.
const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256;
200
201impl<T: Serialize> IntoResponse for Json<T> {
203 fn into_response(self) -> crate::response::Response {
204 match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
206 Ok(body) => http::Response::builder()
207 .status(StatusCode::OK)
208 .header(header::CONTENT_TYPE, "application/json")
209 .body(crate::response::Body::from(body))
210 .unwrap(),
211 Err(err) => {
212 ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
213 }
214 }
215 }
216}
217
/// Newtype around `T` for JSON bodies that are validated during extraction.
///
/// The extractor implementation in this module deserializes the body and
/// then calls `do_validate` on the value before wrapping it.
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);
245
246impl<T> ValidatedJson<T> {
247 pub fn new(value: T) -> Self {
249 Self(value)
250 }
251
252 pub fn into_inner(self) -> T {
254 self.0
255 }
256}
257
258impl<T: DeserializeOwned + Validatable + Send> FromRequest for ValidatedJson<T> {
259 async fn from_request(req: &mut Request) -> Result<Self> {
260 req.load_body().await?;
261 let body = req
263 .take_body()
264 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
265
266 let value: T = json::from_slice(&body)?;
267
268 value.do_validate()?;
270
271 Ok(ValidatedJson(value))
272 }
273}
274
275impl<T> Deref for ValidatedJson<T> {
276 type Target = T;
277
278 fn deref(&self) -> &Self::Target {
279 &self.0
280 }
281}
282
283impl<T> DerefMut for ValidatedJson<T> {
284 fn deref_mut(&mut self) -> &mut Self::Target {
285 &mut self.0
286 }
287}
288
289impl<T> From<T> for ValidatedJson<T> {
290 fn from(value: T) -> Self {
291 ValidatedJson(value)
292 }
293}
294
295impl<T: Serialize> IntoResponse for ValidatedJson<T> {
296 fn into_response(self) -> crate::response::Response {
297 Json(self.0).into_response()
298 }
299}
300
/// Newtype around `T` for JSON bodies validated asynchronously during
/// extraction.
///
/// The extractor implementation in this module deserializes the body and
/// awaits `validate_full` before wrapping the value.
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncValidatedJson<T>(pub T);
329
330impl<T> AsyncValidatedJson<T> {
331 pub fn new(value: T) -> Self {
333 Self(value)
334 }
335
336 pub fn into_inner(self) -> T {
338 self.0
339 }
340}
341
342impl<T> Deref for AsyncValidatedJson<T> {
343 type Target = T;
344
345 fn deref(&self) -> &Self::Target {
346 &self.0
347 }
348}
349
350impl<T> DerefMut for AsyncValidatedJson<T> {
351 fn deref_mut(&mut self) -> &mut Self::Target {
352 &mut self.0
353 }
354}
355
356impl<T> From<T> for AsyncValidatedJson<T> {
357 fn from(value: T) -> Self {
358 AsyncValidatedJson(value)
359 }
360}
361
362impl<T: Serialize> IntoResponse for AsyncValidatedJson<T> {
363 fn into_response(self) -> crate::response::Response {
364 Json(self.0).into_response()
365 }
366}
367
368impl<T: DeserializeOwned + AsyncValidate + Send + Sync> FromRequest for AsyncValidatedJson<T> {
369 async fn from_request(req: &mut Request) -> Result<Self> {
370 req.load_body().await?;
371
372 let body = req
373 .take_body()
374 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
375
376 let value: T = json::from_slice(&body)?;
377
378 let ctx = ValidationContext::default();
381
382 if let Err(errors) = value.validate_full(&ctx).await {
384 let field_errors: Vec<crate::error::FieldError> = errors
386 .fields
387 .iter()
388 .flat_map(|(field, errs)| {
389 let field_name = field.to_string();
390 errs.iter().map(move |e| crate::error::FieldError {
391 field: field_name.clone(),
392 code: e.code.to_string(),
393 message: e.message.clone(),
394 })
395 })
396 .collect();
397
398 return Err(ApiError::validation(field_errors));
399 }
400
401 Ok(AsyncValidatedJson(value))
402 }
403}
404
/// Newtype around `T` deserialized from the request's query string.
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);
424
425impl<T: DeserializeOwned> FromRequestParts for Query<T> {
426 fn from_request_parts(req: &Request) -> Result<Self> {
427 let query = req.query_string().unwrap_or("");
428 let value: T = serde_urlencoded::from_str(query)
429 .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
430 Ok(Query(value))
431 }
432}
433
434impl<T> Deref for Query<T> {
435 type Target = T;
436
437 fn deref(&self) -> &Self::Target {
438 &self.0
439 }
440}
441
/// Newtype around a single path parameter parsed into `T`.
#[derive(Debug, Clone)]
pub struct Path<T>(pub T);
465
466impl<T: FromStr> FromRequestParts for Path<T>
467where
468 T::Err: std::fmt::Display,
469{
470 fn from_request_parts(req: &Request) -> Result<Self> {
471 let params = req.path_params();
472
473 if let Some((_, value)) = params.iter().next() {
475 let parsed = value
476 .parse::<T>()
477 .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
478 return Ok(Path(parsed));
479 }
480
481 Err(ApiError::internal("Missing path parameter"))
482 }
483}
484
485impl<T> Deref for Path<T> {
486 type Target = T;
487
488 fn deref(&self) -> &Self::Target {
489 &self.0
490 }
491}
492
/// Newtype around `T` deserialized from all of the request's path parameters.
#[derive(Debug, Clone)]
pub struct Typed<T>(pub T);
514
515impl<T: DeserializeOwned + Send> FromRequestParts for Typed<T> {
516 fn from_request_parts(req: &Request) -> Result<Self> {
517 let params = req.path_params();
518 let mut map = serde_json::Map::new();
519 for (k, v) in params.iter() {
520 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
521 }
522 let value = serde_json::Value::Object(map);
523 let parsed: T = serde_json::from_value(value)
524 .map_err(|e| ApiError::bad_request(format!("Invalid path parameters: {}", e)))?;
525 Ok(Typed(parsed))
526 }
527}
528
529impl<T> Deref for Typed<T> {
530 type Target = T;
531
532 fn deref(&self) -> &Self::Target {
533 &self.0
534 }
535}
536
/// Newtype around a value of type `T` cloned out of the request's state.
#[derive(Debug, Clone)]
pub struct State<T>(pub T);
555
556impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
557 fn from_request_parts(req: &Request) -> Result<Self> {
558 req.state().get::<T>().cloned().map(State).ok_or_else(|| {
559 ApiError::internal(format!(
560 "State of type `{}` not found. Did you forget to call .state()?",
561 std::any::type_name::<T>()
562 ))
563 })
564 }
565}
566
567impl<T> Deref for State<T> {
568 type Target = T;
569
570 fn deref(&self) -> &Self::Target {
571 &self.0
572 }
573}
574
575#[derive(Debug, Clone)]
577pub struct Body(pub Bytes);
578
579impl FromRequest for Body {
580 async fn from_request(req: &mut Request) -> Result<Self> {
581 req.load_body().await?;
582 let body = req
583 .take_body()
584 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
585 Ok(Body(body))
586 }
587}
588
589impl Deref for Body {
590 type Target = Bytes;
591
592 fn deref(&self) -> &Self::Target {
593 &self.0
594 }
595}
596
597pub struct BodyStream(pub StreamingBody);
599
600impl FromRequest for BodyStream {
601 async fn from_request(req: &mut Request) -> Result<Self> {
602 let config = StreamingConfig::default();
603
604 if let Some(stream) = req.take_stream() {
605 Ok(BodyStream(StreamingBody::new(stream, config.max_body_size)))
606 } else if let Some(bytes) = req.take_body() {
607 let stream = futures_util::stream::once(async move { Ok(bytes) });
609 Ok(BodyStream(StreamingBody::from_stream(
610 stream,
611 config.max_body_size,
612 )))
613 } else {
614 Err(ApiError::internal("Body already consumed"))
615 }
616 }
617}
618
619impl Deref for BodyStream {
620 type Target = StreamingBody;
621
622 fn deref(&self) -> &Self::Target {
623 &self.0
624 }
625}
626
627impl DerefMut for BodyStream {
628 fn deref_mut(&mut self) -> &mut Self::Target {
629 &mut self.0
630 }
631}
632
633impl futures_util::Stream for BodyStream {
635 type Item = Result<Bytes, ApiError>;
636
637 fn poll_next(
638 mut self: std::pin::Pin<&mut Self>,
639 cx: &mut std::task::Context<'_>,
640 ) -> std::task::Poll<Option<Self::Item>> {
641 std::pin::Pin::new(&mut self.0).poll_next(cx)
642 }
643}
644
645impl<T: FromRequestParts> FromRequestParts for Option<T> {
649 fn from_request_parts(req: &Request) -> Result<Self> {
650 Ok(T::from_request_parts(req).ok())
651 }
652}
653
654#[derive(Debug, Clone)]
672pub struct Headers(pub http::HeaderMap);
673
674impl Headers {
675 pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
677 self.0.get(name)
678 }
679
680 pub fn contains(&self, name: &str) -> bool {
682 self.0.contains_key(name)
683 }
684
685 pub fn len(&self) -> usize {
687 self.0.len()
688 }
689
690 pub fn is_empty(&self) -> bool {
692 self.0.is_empty()
693 }
694
695 pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
697 self.0.iter()
698 }
699}
700
701impl FromRequestParts for Headers {
702 fn from_request_parts(req: &Request) -> Result<Self> {
703 Ok(Headers(req.headers().clone()))
704 }
705}
706
707impl Deref for Headers {
708 type Target = http::HeaderMap;
709
710 fn deref(&self) -> &Self::Target {
711 &self.0
712 }
713}
714
/// A single header captured as text.
///
/// Field `0` is the header value and field `1` is the header name.
#[derive(Debug, Clone)]
pub struct HeaderValue(pub String, pub &'static str);
735
736impl HeaderValue {
737 pub fn new(name: &'static str, value: String) -> Self {
739 Self(value, name)
740 }
741
742 pub fn value(&self) -> &str {
744 &self.0
745 }
746
747 pub fn name(&self) -> &'static str {
749 self.1
750 }
751
752 pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
754 req.headers()
755 .get(name)
756 .and_then(|v| v.to_str().ok())
757 .map(|s| HeaderValue(s.to_string(), name))
758 .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
759 }
760}
761
762impl Deref for HeaderValue {
763 type Target = String;
764
765 fn deref(&self) -> &Self::Target {
766 &self.0
767 }
768}
769
/// Newtype around a value of type `T` cloned out of the request extensions.
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);
789
790impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
791 fn from_request_parts(req: &Request) -> Result<Self> {
792 req.extensions()
793 .get::<T>()
794 .cloned()
795 .map(Extension)
796 .ok_or_else(|| {
797 ApiError::internal(format!(
798 "Extension of type `{}` not found. Did middleware insert it?",
799 std::any::type_name::<T>()
800 ))
801 })
802 }
803}
804
805impl<T> Deref for Extension<T> {
806 type Target = T;
807
808 fn deref(&self) -> &Self::Target {
809 &self.0
810 }
811}
812
813impl<T> DerefMut for Extension<T> {
814 fn deref_mut(&mut self) -> &mut Self::Target {
815 &mut self.0
816 }
817}
818
/// Newtype around the IP address determined for the requesting client.
#[derive(Debug, Clone)]
pub struct ClientIp(pub std::net::IpAddr);
835
836impl ClientIp {
837 pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
839 if trust_proxy {
840 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
842 if let Ok(forwarded_str) = forwarded.to_str() {
843 if let Some(first_ip) = forwarded_str.split(',').next() {
845 if let Ok(ip) = first_ip.trim().parse() {
846 return Ok(ClientIp(ip));
847 }
848 }
849 }
850 }
851 }
852
853 if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
855 return Ok(ClientIp(addr.ip()));
856 }
857
858 Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
860 127, 0, 0, 1,
861 ))))
862 }
863}
864
865impl FromRequestParts for ClientIp {
866 fn from_request_parts(req: &Request) -> Result<Self> {
867 Self::extract_with_config(req, true)
869 }
870}
871
/// Newtype around a `cookie::CookieJar` filled from the request's cookies.
///
/// Only available with the `cookies` feature.
#[cfg(feature = "cookies")]
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);
892
/// Convenience accessors that delegate to the wrapped `cookie::CookieJar`.
#[cfg(feature = "cookies")]
impl Cookies {
    /// Returns the cookie named `name`, if the jar has one.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        let Self(jar) = self;
        jar.get(name)
    }

    /// Iterates over the cookies in the jar.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        let Self(jar) = self;
        jar.iter()
    }

    /// Reports whether the jar has a cookie named `name`.
    pub fn contains(&self, name: &str) -> bool {
        let Self(jar) = self;
        matches!(jar.get(name), Some(_))
    }
}
910
/// Parses the request's `Cookie` header fields into a cookie jar.
///
/// Every `Cookie` header field is read, not just the first one, because
/// HTTP/2 allows a client to split its cookies across several fields. Each
/// field is split on `;`; empty segments and segments that fail to parse
/// are skipped, as are fields whose value is not a valid string. Extraction
/// itself never fails, and a request without cookies yields an empty jar.
#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();

        for raw in req.headers().get_all(header::COOKIE).iter() {
            // Best effort: an undecodable field must not discard the others.
            let Ok(line) = raw.to_str() else {
                continue;
            };

            let pairs = line
                .split(';')
                .map(str::trim)
                .filter(|pair| !pair.is_empty());

            for pair in pairs {
                if let Ok(parsed) = cookie::Cookie::parse(pair.to_string()) {
                    jar.add_original(parsed.into_owned());
                }
            }
        }

        Ok(Cookies(jar))
    }
}
933
/// Gives transparent shared access to the wrapped `cookie::CookieJar`.
#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    fn deref(&self) -> &cookie::CookieJar {
        let Self(jar) = self;
        jar
    }
}
942
/// Implements [`FromRequestParts`] for each listed type by delegating to
/// [`Path`], so a bare value can be used as a path-parameter extractor.
///
/// Takes a comma-separated list of types.
macro_rules! impl_from_request_parts_for_primitives {
    ($($prim:ty),*) => {
        $(
            impl FromRequestParts for $prim {
                fn from_request_parts(req: &Request) -> Result<Self> {
                    Path::<$prim>::from_request_parts(req).map(|Path(value)| value)
                }
            }
        )*
    };
}
956
957impl_from_request_parts_for_primitives!(
958 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
959);
960
961use rustapi_openapi::utoipa_types::openapi;
964use rustapi_openapi::{
965 IntoParams, MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier,
966 ResponseSpec, Schema, SchemaRef,
967};
968use std::collections::HashMap;
969
970impl<T: for<'a> Schema<'a>> OperationModifier for ValidatedJson<T> {
972 fn update_operation(op: &mut Operation) {
973 let (name, _) = T::schema();
974
975 let schema_ref = SchemaRef::Ref {
976 reference: format!("#/components/schemas/{}", name),
977 };
978
979 let mut content = HashMap::new();
980 content.insert(
981 "application/json".to_string(),
982 MediaType { schema: schema_ref },
983 );
984
985 op.request_body = Some(RequestBody {
986 required: true,
987 content,
988 });
989
990 op.responses.insert(
992 "422".to_string(),
993 ResponseSpec {
994 description: "Validation Error".to_string(),
995 content: {
996 let mut map = HashMap::new();
997 map.insert(
998 "application/json".to_string(),
999 MediaType {
1000 schema: SchemaRef::Ref {
1001 reference: "#/components/schemas/ValidationErrorSchema".to_string(),
1002 },
1003 },
1004 );
1005 Some(map)
1006 },
1007 },
1008 );
1009 }
1010}
1011
1012impl<T: for<'a> Schema<'a>> OperationModifier for Json<T> {
1014 fn update_operation(op: &mut Operation) {
1015 let (name, _) = T::schema();
1016
1017 let schema_ref = SchemaRef::Ref {
1018 reference: format!("#/components/schemas/{}", name),
1019 };
1020
1021 let mut content = HashMap::new();
1022 content.insert(
1023 "application/json".to_string(),
1024 MediaType { schema: schema_ref },
1025 );
1026
1027 op.request_body = Some(RequestBody {
1028 required: true,
1029 content,
1030 });
1031 }
1032}
1033
1034impl<T> OperationModifier for Path<T> {
1038 fn update_operation(_op: &mut Operation) {
1039 }
1046}
1047
1048impl<T> OperationModifier for Typed<T> {
1050 fn update_operation(_op: &mut Operation) {
1051 }
1053}
1054
1055impl<T: IntoParams> OperationModifier for Query<T> {
1057 fn update_operation(op: &mut Operation) {
1058 let params = T::into_params(|| Some(openapi::path::ParameterIn::Query));
1059
1060 let new_params: Vec<Parameter> = params
1061 .into_iter()
1062 .map(|p| {
1063 let schema = match p.schema {
1064 Some(schema) => match schema {
1065 openapi::RefOr::Ref(r) => SchemaRef::Ref {
1066 reference: r.ref_location,
1067 },
1068 openapi::RefOr::T(s) => {
1069 let value = serde_json::to_value(s).unwrap_or(serde_json::Value::Null);
1070 SchemaRef::Inline(value)
1071 }
1072 },
1073 None => SchemaRef::Inline(serde_json::Value::Null),
1074 };
1075
1076 let required = match p.required {
1077 openapi::Required::True => true,
1078 openapi::Required::False => false,
1079 };
1080
1081 Parameter {
1082 name: p.name,
1083 location: "query".to_string(), required,
1085 description: p.description,
1086 schema,
1087 }
1088 })
1089 .collect();
1090
1091 if let Some(existing) = &mut op.parameters {
1092 existing.extend(new_params);
1093 } else {
1094 op.parameters = Some(new_params);
1095 }
1096 }
1097}
1098
1099impl<T> OperationModifier for State<T> {
1101 fn update_operation(_op: &mut Operation) {}
1102}
1103
1104impl OperationModifier for Body {
1106 fn update_operation(op: &mut Operation) {
1107 let mut content = HashMap::new();
1108 content.insert(
1109 "application/octet-stream".to_string(),
1110 MediaType {
1111 schema: SchemaRef::Inline(
1112 serde_json::json!({ "type": "string", "format": "binary" }),
1113 ),
1114 },
1115 );
1116
1117 op.request_body = Some(RequestBody {
1118 required: true,
1119 content,
1120 });
1121 }
1122}
1123
1124impl OperationModifier for BodyStream {
1126 fn update_operation(op: &mut Operation) {
1127 let mut content = HashMap::new();
1128 content.insert(
1129 "application/octet-stream".to_string(),
1130 MediaType {
1131 schema: SchemaRef::Inline(
1132 serde_json::json!({ "type": "string", "format": "binary" }),
1133 ),
1134 },
1135 );
1136
1137 op.request_body = Some(RequestBody {
1138 required: true,
1139 content,
1140 });
1141 }
1142}
1143
1144impl<T: for<'a> Schema<'a>> ResponseModifier for Json<T> {
1148 fn update_response(op: &mut Operation) {
1149 let (name, _) = T::schema();
1150
1151 let schema_ref = SchemaRef::Ref {
1152 reference: format!("#/components/schemas/{}", name),
1153 };
1154
1155 op.responses.insert(
1156 "200".to_string(),
1157 ResponseSpec {
1158 description: "Successful response".to_string(),
1159 content: {
1160 let mut map = HashMap::new();
1161 map.insert(
1162 "application/json".to_string(),
1163 MediaType { schema: schema_ref },
1164 );
1165 Some(map)
1166 },
1167 },
1168 );
1169 }
1170}
1171
1172#[cfg(test)]
1173mod tests {
1174 use super::*;
1175 use crate::path_params::PathParams;
1176 use bytes::Bytes;
1177 use http::{Extensions, Method};
1178 use proptest::prelude::*;
1179 use proptest::test_runner::TestCaseError;
1180 use std::sync::Arc;
1181
1182 fn create_test_request_with_headers(
1184 method: Method,
1185 path: &str,
1186 headers: Vec<(&str, &str)>,
1187 ) -> Request {
1188 let uri: http::Uri = path.parse().unwrap();
1189 let mut builder = http::Request::builder().method(method).uri(uri);
1190
1191 for (name, value) in headers {
1192 builder = builder.header(name, value);
1193 }
1194
1195 let req = builder.body(()).unwrap();
1196 let (parts, _) = req.into_parts();
1197
1198 Request::new(
1199 parts,
1200 crate::request::BodyVariant::Buffered(Bytes::new()),
1201 Arc::new(Extensions::new()),
1202 PathParams::new(),
1203 )
1204 }
1205
1206 fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
1208 method: Method,
1209 path: &str,
1210 extension: T,
1211 ) -> Request {
1212 let uri: http::Uri = path.parse().unwrap();
1213 let builder = http::Request::builder().method(method).uri(uri);
1214
1215 let req = builder.body(()).unwrap();
1216 let (mut parts, _) = req.into_parts();
1217 parts.extensions.insert(extension);
1218
1219 Request::new(
1220 parts,
1221 crate::request::BodyVariant::Buffered(Bytes::new()),
1222 Arc::new(Extensions::new()),
1223 PathParams::new(),
1224 )
1225 }
1226
1227 proptest! {
1234 #![proptest_config(ProptestConfig::with_cases(100))]
1235
1236 #[test]
1237 fn prop_headers_extractor_completeness(
1238 headers in prop::collection::vec(
1241 (
1242 "[a-z][a-z0-9-]{0,20}", "[a-zA-Z0-9 ]{1,50}" ),
1245 0..10
1246 )
1247 ) {
1248 let result: Result<(), TestCaseError> = (|| {
1249 let header_tuples: Vec<(&str, &str)> = headers
1251 .iter()
1252 .map(|(k, v)| (k.as_str(), v.as_str()))
1253 .collect();
1254
1255 let request = create_test_request_with_headers(
1257 Method::GET,
1258 "/test",
1259 header_tuples.clone(),
1260 );
1261
1262 let extracted = Headers::from_request_parts(&request)
1264 .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;
1265
1266 for (name, value) in &headers {
1269 let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
1271 prop_assert!(
1272 !all_values.is_empty(),
1273 "Header '{}' not found",
1274 name
1275 );
1276
1277 let value_found = all_values.iter().any(|v| {
1279 v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
1280 });
1281
1282 prop_assert!(
1283 value_found,
1284 "Header '{}' value '{}' not found in extracted values",
1285 name,
1286 value
1287 );
1288 }
1289
1290 Ok(())
1291 })();
1292 result?;
1293 }
1294 }
1295
1296 proptest! {
1303 #![proptest_config(ProptestConfig::with_cases(100))]
1304
1305 #[test]
1306 fn prop_header_value_extractor_correctness(
1307 header_name in "[a-z][a-z0-9-]{0,20}",
1308 header_value in "[a-zA-Z0-9 ]{1,50}",
1309 has_header in prop::bool::ANY,
1310 ) {
1311 let result: Result<(), TestCaseError> = (|| {
1312 let headers = if has_header {
1313 vec![(header_name.as_str(), header_value.as_str())]
1314 } else {
1315 vec![]
1316 };
1317
1318 let _request = create_test_request_with_headers(Method::GET, "/test", headers);
1319
1320 let test_header = "x-test-header";
1323 let request_with_known_header = if has_header {
1324 create_test_request_with_headers(
1325 Method::GET,
1326 "/test",
1327 vec![(test_header, header_value.as_str())],
1328 )
1329 } else {
1330 create_test_request_with_headers(Method::GET, "/test", vec![])
1331 };
1332
1333 let result = HeaderValue::extract(&request_with_known_header, test_header);
1334
1335 if has_header {
1336 let extracted = result
1337 .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
1338 prop_assert_eq!(
1339 extracted.value(),
1340 header_value.as_str(),
1341 "Header value mismatch"
1342 );
1343 } else {
1344 prop_assert!(
1345 result.is_err(),
1346 "Expected error when header is missing"
1347 );
1348 }
1349
1350 Ok(())
1351 })();
1352 result?;
1353 }
1354 }
1355
1356 proptest! {
1363 #![proptest_config(ProptestConfig::with_cases(100))]
1364
1365 #[test]
1366 fn prop_client_ip_extractor_with_forwarding(
1367 forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1369 .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
1370 socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1371 .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
1372 has_forwarded_header in prop::bool::ANY,
1373 trust_proxy in prop::bool::ANY,
1374 ) {
1375 let result: Result<(), TestCaseError> = (|| {
1376 let headers = if has_forwarded_header {
1377 vec![("x-forwarded-for", forwarded_ip.as_str())]
1378 } else {
1379 vec![]
1380 };
1381
1382 let uri: http::Uri = "/test".parse().unwrap();
1384 let mut builder = http::Request::builder().method(Method::GET).uri(uri);
1385 for (name, value) in &headers {
1386 builder = builder.header(*name, *value);
1387 }
1388 let req = builder.body(()).unwrap();
1389 let (mut parts, _) = req.into_parts();
1390
1391 let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
1393 parts.extensions.insert(socket_addr);
1394
1395 let request = Request::new(
1396 parts,
1397 crate::request::BodyVariant::Buffered(Bytes::new()),
1398 Arc::new(Extensions::new()),
1399 PathParams::new(),
1400 );
1401
1402 let extracted = ClientIp::extract_with_config(&request, trust_proxy)
1403 .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;
1404
1405 if trust_proxy && has_forwarded_header {
1406 let expected_ip: std::net::IpAddr = forwarded_ip.parse()
1408 .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
1409 prop_assert_eq!(
1410 extracted.0,
1411 expected_ip,
1412 "Should use X-Forwarded-For IP when trust_proxy is enabled"
1413 );
1414 } else {
1415 prop_assert_eq!(
1417 extracted.0,
1418 socket_ip,
1419 "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
1420 );
1421 }
1422
1423 Ok(())
1424 })();
1425 result?;
1426 }
1427 }
1428
1429 proptest! {
1436 #![proptest_config(ProptestConfig::with_cases(100))]
1437
1438 #[test]
1439 fn prop_extension_extractor_retrieval(
1440 value in any::<i64>(),
1441 has_extension in prop::bool::ANY,
1442 ) {
1443 let result: Result<(), TestCaseError> = (|| {
1444 #[derive(Clone, Debug, PartialEq)]
1446 struct TestExtension(i64);
1447
1448 let uri: http::Uri = "/test".parse().unwrap();
1449 let builder = http::Request::builder().method(Method::GET).uri(uri);
1450 let req = builder.body(()).unwrap();
1451 let (mut parts, _) = req.into_parts();
1452
1453 if has_extension {
1454 parts.extensions.insert(TestExtension(value));
1455 }
1456
1457 let request = Request::new(
1458 parts,
1459 crate::request::BodyVariant::Buffered(Bytes::new()),
1460 Arc::new(Extensions::new()),
1461 PathParams::new(),
1462 );
1463
1464 let result = Extension::<TestExtension>::from_request_parts(&request);
1465
1466 if has_extension {
1467 let extracted = result
1468 .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
1469 prop_assert_eq!(
1470 extracted.0,
1471 TestExtension(value),
1472 "Extension value mismatch"
1473 );
1474 } else {
1475 prop_assert!(
1476 result.is_err(),
1477 "Expected error when extension is missing"
1478 );
1479 }
1480
1481 Ok(())
1482 })();
1483 result?;
1484 }
1485 }
1486
1487 #[test]
1490 fn test_headers_extractor_basic() {
1491 let request = create_test_request_with_headers(
1492 Method::GET,
1493 "/test",
1494 vec![
1495 ("content-type", "application/json"),
1496 ("accept", "text/html"),
1497 ],
1498 );
1499
1500 let headers = Headers::from_request_parts(&request).unwrap();
1501
1502 assert!(headers.contains("content-type"));
1503 assert!(headers.contains("accept"));
1504 assert!(!headers.contains("x-custom"));
1505 assert_eq!(headers.len(), 2);
1506 }
1507
1508 #[test]
1509 fn test_header_value_extractor_present() {
1510 let request = create_test_request_with_headers(
1511 Method::GET,
1512 "/test",
1513 vec![("authorization", "Bearer token123")],
1514 );
1515
1516 let result = HeaderValue::extract(&request, "authorization");
1517 assert!(result.is_ok());
1518 assert_eq!(result.unwrap().value(), "Bearer token123");
1519 }
1520
1521 #[test]
1522 fn test_header_value_extractor_missing() {
1523 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1524
1525 let result = HeaderValue::extract(&request, "authorization");
1526 assert!(result.is_err());
1527 }
1528
1529 #[test]
1530 fn test_client_ip_from_forwarded_header() {
1531 let request = create_test_request_with_headers(
1532 Method::GET,
1533 "/test",
1534 vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
1535 );
1536
1537 let ip = ClientIp::extract_with_config(&request, true).unwrap();
1538 assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
1539 }
1540
1541 #[test]
1542 fn test_client_ip_ignores_forwarded_when_not_trusted() {
1543 let uri: http::Uri = "/test".parse().unwrap();
1544 let builder = http::Request::builder()
1545 .method(Method::GET)
1546 .uri(uri)
1547 .header("x-forwarded-for", "192.168.1.100");
1548 let req = builder.body(()).unwrap();
1549 let (mut parts, _) = req.into_parts();
1550
1551 let socket_addr = std::net::SocketAddr::new(
1552 std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
1553 8080,
1554 );
1555 parts.extensions.insert(socket_addr);
1556
1557 let request = Request::new(
1558 parts,
1559 crate::request::BodyVariant::Buffered(Bytes::new()),
1560 Arc::new(Extensions::new()),
1561 PathParams::new(),
1562 );
1563
1564 let ip = ClientIp::extract_with_config(&request, false).unwrap();
1565 assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
1566 }
1567
1568 #[test]
1569 fn test_extension_extractor_present() {
1570 #[derive(Clone, Debug, PartialEq)]
1571 struct MyData(String);
1572
1573 let request =
1574 create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));
1575
1576 let result = Extension::<MyData>::from_request_parts(&request);
1577 assert!(result.is_ok());
1578 assert_eq!(result.unwrap().0, MyData("hello".to_string()));
1579 }
1580
1581 #[test]
1582 fn test_extension_extractor_missing() {
1583 #[derive(Clone, Debug)]
1584 #[allow(dead_code)]
1585 struct MyData(String);
1586
1587 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1588
1589 let result = Extension::<MyData>::from_request_parts(&request);
1590 assert!(result.is_err());
1591 }
1592
1593 #[cfg(feature = "cookies")]
1595 mod cookies_tests {
1596 use super::*;
1597
1598 proptest! {
1606 #![proptest_config(ProptestConfig::with_cases(100))]
1607
1608 #[test]
1609 fn prop_cookies_extractor_parsing(
1610 cookies in prop::collection::vec(
1613 (
1614 "[a-zA-Z][a-zA-Z0-9_]{0,15}", "[a-zA-Z0-9]{1,30}" ),
1617 0..5
1618 )
1619 ) {
1620 let result: Result<(), TestCaseError> = (|| {
1621 let cookie_header = cookies
1623 .iter()
1624 .map(|(name, value)| format!("{}={}", name, value))
1625 .collect::<Vec<_>>()
1626 .join("; ");
1627
1628 let headers = if !cookies.is_empty() {
1629 vec![("cookie", cookie_header.as_str())]
1630 } else {
1631 vec![]
1632 };
1633
1634 let request = create_test_request_with_headers(Method::GET, "/test", headers);
1635
1636 let extracted = Cookies::from_request_parts(&request)
1638 .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;
1639
1640 let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
1642 for (name, value) in &cookies {
1643 expected_cookies.insert(name.as_str(), value.as_str());
1644 }
1645
1646 for (name, expected_value) in &expected_cookies {
1648 let cookie = extracted.get(name)
1649 .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;
1650
1651 prop_assert_eq!(
1652 cookie.value(),
1653 *expected_value,
1654 "Cookie '{}' value mismatch",
1655 name
1656 );
1657 }
1658
1659 let extracted_count = extracted.iter().count();
1661 prop_assert_eq!(
1662 extracted_count,
1663 expected_cookies.len(),
1664 "Expected {} unique cookies, got {}",
1665 expected_cookies.len(),
1666 extracted_count
1667 );
1668
1669 Ok(())
1670 })();
1671 result?;
1672 }
1673 }
1674
1675 #[test]
1676 fn test_cookies_extractor_basic() {
1677 let request = create_test_request_with_headers(
1678 Method::GET,
1679 "/test",
1680 vec![("cookie", "session=abc123; user=john")],
1681 );
1682
1683 let cookies = Cookies::from_request_parts(&request).unwrap();
1684
1685 assert!(cookies.contains("session"));
1686 assert!(cookies.contains("user"));
1687 assert!(!cookies.contains("other"));
1688
1689 assert_eq!(cookies.get("session").unwrap().value(), "abc123");
1690 assert_eq!(cookies.get("user").unwrap().value(), "john");
1691 }
1692
1693 #[test]
1694 fn test_cookies_extractor_empty() {
1695 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1696
1697 let cookies = Cookies::from_request_parts(&request).unwrap();
1698 assert_eq!(cookies.iter().count(), 0);
1699 }
1700
1701 #[test]
1702 fn test_cookies_extractor_single() {
1703 let request = create_test_request_with_headers(
1704 Method::GET,
1705 "/test",
1706 vec![("cookie", "token=xyz789")],
1707 );
1708
1709 let cookies = Cookies::from_request_parts(&request).unwrap();
1710 assert_eq!(cookies.iter().count(), 1);
1711 assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
1712 }
1713 }
1714}