1use crate::error::{ApiError, Result};
58use crate::json;
59use crate::request::Request;
60use crate::response::IntoResponse;
61use crate::stream::{StreamingBody, StreamingConfig};
62use crate::validation::Validatable;
63use bytes::Bytes;
64use http::{header, StatusCode};
65use rustapi_validate::v2::{AsyncValidate, ValidationContext};
66
67use rustapi_openapi::schema::{RustApiSchema, SchemaCtx, SchemaRef};
68use serde::de::DeserializeOwned;
69use serde::Serialize;
70use std::collections::BTreeMap;
71use std::future::Future;
72use std::ops::{Deref, DerefMut};
73use std::str::FromStr;
74
75pub trait FromRequestParts: Sized {
100 fn from_request_parts(req: &Request) -> Result<Self>;
102}
103
104pub trait FromRequest: Sized {
134 fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
136}
137
138impl<T: FromRequestParts> FromRequest for T {
140 async fn from_request(req: &mut Request) -> Result<Self> {
141 T::from_request_parts(req)
142 }
143}
144
145#[derive(Debug, Clone, Copy, Default)]
164pub struct Json<T>(pub T);
165
166impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
167 async fn from_request(req: &mut Request) -> Result<Self> {
168 req.load_body().await?;
169 let body = req
170 .take_body()
171 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
172
173 let value: T = json::from_slice(&body)?;
175 Ok(Json(value))
176 }
177}
178
179impl<T> Deref for Json<T> {
180 type Target = T;
181
182 fn deref(&self) -> &Self::Target {
183 &self.0
184 }
185}
186
187impl<T> DerefMut for Json<T> {
188 fn deref_mut(&mut self) -> &mut Self::Target {
189 &mut self.0
190 }
191}
192
193impl<T> From<T> for Json<T> {
194 fn from(value: T) -> Self {
195 Json(value)
196 }
197}
198
199const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256;
202
203impl<T: Serialize> IntoResponse for Json<T> {
205 fn into_response(self) -> crate::response::Response {
206 match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
208 Ok(body) => http::Response::builder()
209 .status(StatusCode::OK)
210 .header(header::CONTENT_TYPE, "application/json")
211 .body(crate::response::Body::from(body))
212 .unwrap(),
213 Err(err) => {
214 ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
215 }
216 }
217 }
218}
219
220#[derive(Debug, Clone, Copy, Default)]
246pub struct ValidatedJson<T>(pub T);
247
248impl<T> ValidatedJson<T> {
249 pub fn new(value: T) -> Self {
251 Self(value)
252 }
253
254 pub fn into_inner(self) -> T {
256 self.0
257 }
258}
259
260impl<T: DeserializeOwned + Validatable + Send> FromRequest for ValidatedJson<T> {
261 async fn from_request(req: &mut Request) -> Result<Self> {
262 req.load_body().await?;
263 let body = req
265 .take_body()
266 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
267
268 let value: T = json::from_slice(&body)?;
269
270 value.do_validate()?;
272
273 Ok(ValidatedJson(value))
274 }
275}
276
277impl<T> Deref for ValidatedJson<T> {
278 type Target = T;
279
280 fn deref(&self) -> &Self::Target {
281 &self.0
282 }
283}
284
285impl<T> DerefMut for ValidatedJson<T> {
286 fn deref_mut(&mut self) -> &mut Self::Target {
287 &mut self.0
288 }
289}
290
291impl<T> From<T> for ValidatedJson<T> {
292 fn from(value: T) -> Self {
293 ValidatedJson(value)
294 }
295}
296
297impl<T: Serialize> IntoResponse for ValidatedJson<T> {
298 fn into_response(self) -> crate::response::Response {
299 Json(self.0).into_response()
300 }
301}
302
303#[derive(Debug, Clone, Copy, Default)]
330pub struct AsyncValidatedJson<T>(pub T);
331
332impl<T> AsyncValidatedJson<T> {
333 pub fn new(value: T) -> Self {
335 Self(value)
336 }
337
338 pub fn into_inner(self) -> T {
340 self.0
341 }
342}
343
344impl<T> Deref for AsyncValidatedJson<T> {
345 type Target = T;
346
347 fn deref(&self) -> &Self::Target {
348 &self.0
349 }
350}
351
352impl<T> DerefMut for AsyncValidatedJson<T> {
353 fn deref_mut(&mut self) -> &mut Self::Target {
354 &mut self.0
355 }
356}
357
358impl<T> From<T> for AsyncValidatedJson<T> {
359 fn from(value: T) -> Self {
360 AsyncValidatedJson(value)
361 }
362}
363
364impl<T: Serialize> IntoResponse for AsyncValidatedJson<T> {
365 fn into_response(self) -> crate::response::Response {
366 Json(self.0).into_response()
367 }
368}
369
370impl<T: DeserializeOwned + AsyncValidate + Send + Sync> FromRequest for AsyncValidatedJson<T> {
371 async fn from_request(req: &mut Request) -> Result<Self> {
372 req.load_body().await?;
373
374 let body = req
375 .take_body()
376 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
377
378 let value: T = json::from_slice(&body)?;
379
380 let ctx = ValidationContext::default();
383
384 if let Err(errors) = value.validate_full(&ctx).await {
386 let field_errors: Vec<crate::error::FieldError> = errors
388 .fields
389 .iter()
390 .flat_map(|(field, errs)| {
391 let field_name = field.to_string();
392 errs.iter().map(move |e| crate::error::FieldError {
393 field: field_name.clone(),
394 code: e.code.to_string(),
395 message: e.message.clone(),
396 })
397 })
398 .collect();
399
400 return Err(ApiError::validation(field_errors));
401 }
402
403 Ok(AsyncValidatedJson(value))
404 }
405}
406
407#[derive(Debug, Clone)]
425pub struct Query<T>(pub T);
426
427impl<T: DeserializeOwned> FromRequestParts for Query<T> {
428 fn from_request_parts(req: &Request) -> Result<Self> {
429 let query = req.query_string().unwrap_or("");
430 let value: T = serde_urlencoded::from_str(query)
431 .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
432 Ok(Query(value))
433 }
434}
435
436impl<T> Deref for Query<T> {
437 type Target = T;
438
439 fn deref(&self) -> &Self::Target {
440 &self.0
441 }
442}
443
444#[derive(Debug, Clone)]
466pub struct Path<T>(pub T);
467
468impl<T: FromStr> FromRequestParts for Path<T>
469where
470 T::Err: std::fmt::Display,
471{
472 fn from_request_parts(req: &Request) -> Result<Self> {
473 let params = req.path_params();
474
475 if let Some((_, value)) = params.iter().next() {
477 let parsed = value
478 .parse::<T>()
479 .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
480 return Ok(Path(parsed));
481 }
482
483 Err(ApiError::internal("Missing path parameter"))
484 }
485}
486
487impl<T> Deref for Path<T> {
488 type Target = T;
489
490 fn deref(&self) -> &Self::Target {
491 &self.0
492 }
493}
494
495#[derive(Debug, Clone)]
515pub struct Typed<T>(pub T);
516
517impl<T: DeserializeOwned + Send> FromRequestParts for Typed<T> {
518 fn from_request_parts(req: &Request) -> Result<Self> {
519 let params = req.path_params();
520 let mut map = serde_json::Map::new();
521 for (k, v) in params.iter() {
522 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
523 }
524 let value = serde_json::Value::Object(map);
525 let parsed: T = serde_json::from_value(value)
526 .map_err(|e| ApiError::bad_request(format!("Invalid path parameters: {}", e)))?;
527 Ok(Typed(parsed))
528 }
529}
530
531impl<T> Deref for Typed<T> {
532 type Target = T;
533
534 fn deref(&self) -> &Self::Target {
535 &self.0
536 }
537}
538
539#[derive(Debug, Clone)]
556pub struct State<T>(pub T);
557
558impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
559 fn from_request_parts(req: &Request) -> Result<Self> {
560 req.state().get::<T>().cloned().map(State).ok_or_else(|| {
561 ApiError::internal(format!(
562 "State of type `{}` not found. Did you forget to call .state()?",
563 std::any::type_name::<T>()
564 ))
565 })
566 }
567}
568
569impl<T> Deref for State<T> {
570 type Target = T;
571
572 fn deref(&self) -> &Self::Target {
573 &self.0
574 }
575}
576
577#[derive(Debug, Clone)]
579pub struct Body(pub Bytes);
580
581impl FromRequest for Body {
582 async fn from_request(req: &mut Request) -> Result<Self> {
583 req.load_body().await?;
584 let body = req
585 .take_body()
586 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
587 Ok(Body(body))
588 }
589}
590
591impl Deref for Body {
592 type Target = Bytes;
593
594 fn deref(&self) -> &Self::Target {
595 &self.0
596 }
597}
598
599pub struct BodyStream(pub StreamingBody);
601
602impl FromRequest for BodyStream {
603 async fn from_request(req: &mut Request) -> Result<Self> {
604 let config = StreamingConfig::default();
605
606 if let Some(stream) = req.take_stream() {
607 Ok(BodyStream(StreamingBody::new(stream, config.max_body_size)))
608 } else if let Some(bytes) = req.take_body() {
609 let stream = futures_util::stream::once(async move { Ok(bytes) });
611 Ok(BodyStream(StreamingBody::from_stream(
612 stream,
613 config.max_body_size,
614 )))
615 } else {
616 Err(ApiError::internal("Body already consumed"))
617 }
618 }
619}
620
621impl Deref for BodyStream {
622 type Target = StreamingBody;
623
624 fn deref(&self) -> &Self::Target {
625 &self.0
626 }
627}
628
629impl DerefMut for BodyStream {
630 fn deref_mut(&mut self) -> &mut Self::Target {
631 &mut self.0
632 }
633}
634
635impl futures_util::Stream for BodyStream {
637 type Item = Result<Bytes, ApiError>;
638
639 fn poll_next(
640 mut self: std::pin::Pin<&mut Self>,
641 cx: &mut std::task::Context<'_>,
642 ) -> std::task::Poll<Option<Self::Item>> {
643 std::pin::Pin::new(&mut self.0).poll_next(cx)
644 }
645}
646
647impl<T: FromRequestParts> FromRequestParts for Option<T> {
651 fn from_request_parts(req: &Request) -> Result<Self> {
652 Ok(T::from_request_parts(req).ok())
653 }
654}
655
656#[derive(Debug, Clone)]
674pub struct Headers(pub http::HeaderMap);
675
676impl Headers {
677 pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
679 self.0.get(name)
680 }
681
682 pub fn contains(&self, name: &str) -> bool {
684 self.0.contains_key(name)
685 }
686
687 pub fn len(&self) -> usize {
689 self.0.len()
690 }
691
692 pub fn is_empty(&self) -> bool {
694 self.0.is_empty()
695 }
696
697 pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
699 self.0.iter()
700 }
701}
702
703impl FromRequestParts for Headers {
704 fn from_request_parts(req: &Request) -> Result<Self> {
705 Ok(Headers(req.headers().clone()))
706 }
707}
708
709impl Deref for Headers {
710 type Target = http::HeaderMap;
711
712 fn deref(&self) -> &Self::Target {
713 &self.0
714 }
715}
716
717#[derive(Debug, Clone)]
736pub struct HeaderValue(pub String, pub &'static str);
737
738impl HeaderValue {
739 pub fn new(name: &'static str, value: String) -> Self {
741 Self(value, name)
742 }
743
744 pub fn value(&self) -> &str {
746 &self.0
747 }
748
749 pub fn name(&self) -> &'static str {
751 self.1
752 }
753
754 pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
756 req.headers()
757 .get(name)
758 .and_then(|v| v.to_str().ok())
759 .map(|s| HeaderValue(s.to_string(), name))
760 .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
761 }
762}
763
764impl Deref for HeaderValue {
765 type Target = String;
766
767 fn deref(&self) -> &Self::Target {
768 &self.0
769 }
770}
771
772#[derive(Debug, Clone)]
790pub struct Extension<T>(pub T);
791
792impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
793 fn from_request_parts(req: &Request) -> Result<Self> {
794 req.extensions()
795 .get::<T>()
796 .cloned()
797 .map(Extension)
798 .ok_or_else(|| {
799 ApiError::internal(format!(
800 "Extension of type `{}` not found. Did middleware insert it?",
801 std::any::type_name::<T>()
802 ))
803 })
804 }
805}
806
807impl<T> Deref for Extension<T> {
808 type Target = T;
809
810 fn deref(&self) -> &Self::Target {
811 &self.0
812 }
813}
814
815impl<T> DerefMut for Extension<T> {
816 fn deref_mut(&mut self) -> &mut Self::Target {
817 &mut self.0
818 }
819}
820
821#[derive(Debug, Clone)]
836pub struct ClientIp(pub std::net::IpAddr);
837
838impl ClientIp {
839 pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
841 if trust_proxy {
842 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
844 if let Ok(forwarded_str) = forwarded.to_str() {
845 if let Some(first_ip) = forwarded_str.split(',').next() {
847 if let Ok(ip) = first_ip.trim().parse() {
848 return Ok(ClientIp(ip));
849 }
850 }
851 }
852 }
853 }
854
855 if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
857 return Ok(ClientIp(addr.ip()));
858 }
859
860 Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
862 127, 0, 0, 1,
863 ))))
864 }
865}
866
867impl FromRequestParts for ClientIp {
868 fn from_request_parts(req: &Request) -> Result<Self> {
869 Self::extract_with_config(req, true)
871 }
872}
873
/// Request cookies parsed into a `cookie::CookieJar` (requires the `cookies`
/// feature).
#[cfg(feature = "cookies")]
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);

#[cfg(feature = "cookies")]
impl Cookies {
    /// Look up a cookie by name.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.get(name)
    }

    /// Iterate over all cookies in the jar.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.iter()
    }

    /// Whether a cookie called `name` exists.
    pub fn contains(&self, name: &str) -> bool {
        let Cookies(jar) = self;
        jar.get(name).is_some()
    }
}
912
/// Parses every `Cookie` request header into a `cookie::CookieJar`.
///
/// HTTP/2 clients may split cookies across several `Cookie` header fields
/// (RFC 9113 §8.2.3). All of them are read now; previously only the first
/// field was read and cookies in later fields were lost. Within and across
/// headers, pairs are processed in order. Non-UTF-8 header values and pairs
/// that fail `Cookie::parse` are skipped, so this extractor never fails.
#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();

        for value in req.headers().get_all(header::COOKIE) {
            let Ok(text) = value.to_str() else {
                continue;
            };
            for pair in text.split(';').map(str::trim).filter(|p| !p.is_empty()) {
                if let Ok(cookie) = cookie::Cookie::parse(pair.to_string()) {
                    jar.add_original(cookie.into_owned());
                }
            }
        }

        Ok(Cookies(jar))
    }
}

#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
944
945macro_rules! impl_from_request_parts_for_primitives {
947 ($($ty:ty),*) => {
948 $(
949 impl FromRequestParts for $ty {
950 fn from_request_parts(req: &Request) -> Result<Self> {
951 let Path(value) = Path::<$ty>::from_request_parts(req)?;
952 Ok(value)
953 }
954 }
955 )*
956 };
957}
958
959impl_from_request_parts_for_primitives!(
960 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
961);
962
963use rustapi_openapi::{
966 MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier, ResponseSpec,
967};
968
969impl<T: RustApiSchema> OperationModifier for ValidatedJson<T> {
971 fn update_operation(op: &mut Operation) {
972 let mut ctx = SchemaCtx::new();
973 let schema_ref = T::schema(&mut ctx);
974
975 let mut content = BTreeMap::new();
976 content.insert(
977 "application/json".to_string(),
978 MediaType {
979 schema: Some(schema_ref),
980 example: None,
981 },
982 );
983
984 op.request_body = Some(RequestBody {
985 description: None,
986 required: Some(true),
987 content,
988 });
989
990 let mut responses_content = BTreeMap::new();
992 responses_content.insert(
993 "application/json".to_string(),
994 MediaType {
995 schema: Some(SchemaRef::Ref {
996 reference: "#/components/schemas/ValidationErrorSchema".to_string(),
997 }),
998 example: None,
999 },
1000 );
1001
1002 op.responses.insert(
1003 "422".to_string(),
1004 ResponseSpec {
1005 description: "Validation Error".to_string(),
1006 content: responses_content,
1007 headers: BTreeMap::new(),
1008 },
1009 );
1010 }
1011}
1012
1013impl<T: RustApiSchema> OperationModifier for Json<T> {
1015 fn update_operation(op: &mut Operation) {
1016 let mut ctx = SchemaCtx::new();
1017 let schema_ref = T::schema(&mut ctx);
1018
1019 let mut content = BTreeMap::new();
1020 content.insert(
1021 "application/json".to_string(),
1022 MediaType {
1023 schema: Some(schema_ref),
1024 example: None,
1025 },
1026 );
1027
1028 op.request_body = Some(RequestBody {
1029 description: None,
1030 required: Some(true),
1031 content,
1032 });
1033 }
1034}
1035
1036impl<T> OperationModifier for Path<T> {
1038 fn update_operation(_op: &mut Operation) {}
1039}
1040
1041impl<T> OperationModifier for Typed<T> {
1043 fn update_operation(_op: &mut Operation) {}
1044}
1045
1046impl<T: RustApiSchema> OperationModifier for Query<T> {
1048 fn update_operation(op: &mut Operation) {
1049 let mut ctx = SchemaCtx::new();
1050 if let Some(fields) = T::field_schemas(&mut ctx) {
1051 let new_params: Vec<Parameter> = fields
1052 .into_iter()
1053 .map(|(name, schema)| {
1054 Parameter {
1055 name,
1056 location: "query".to_string(),
1057 required: false, deprecated: None,
1059 description: None,
1060 schema: Some(schema),
1061 }
1062 })
1063 .collect();
1064
1065 op.parameters.extend(new_params);
1066 }
1067 }
1068}
1069
1070impl<T> OperationModifier for State<T> {
1072 fn update_operation(_op: &mut Operation) {}
1073}
1074
1075impl OperationModifier for Body {
1077 fn update_operation(op: &mut Operation) {
1078 let mut content = BTreeMap::new();
1079 content.insert(
1080 "application/octet-stream".to_string(),
1081 MediaType {
1082 schema: Some(SchemaRef::Inline(
1083 serde_json::json!({ "type": "string", "format": "binary" }),
1084 )),
1085 example: None,
1086 },
1087 );
1088
1089 op.request_body = Some(RequestBody {
1090 description: None,
1091 required: Some(true),
1092 content,
1093 });
1094 }
1095}
1096
1097impl OperationModifier for BodyStream {
1099 fn update_operation(op: &mut Operation) {
1100 let mut content = BTreeMap::new();
1101 content.insert(
1102 "application/octet-stream".to_string(),
1103 MediaType {
1104 schema: Some(SchemaRef::Inline(
1105 serde_json::json!({ "type": "string", "format": "binary" }),
1106 )),
1107 example: None,
1108 },
1109 );
1110
1111 op.request_body = Some(RequestBody {
1112 description: None,
1113 required: Some(true),
1114 content,
1115 });
1116 }
1117}
1118
1119impl<T: RustApiSchema> ResponseModifier for Json<T> {
1123 fn update_response(op: &mut Operation) {
1124 let mut ctx = SchemaCtx::new();
1125 let schema_ref = T::schema(&mut ctx);
1126
1127 let mut content = BTreeMap::new();
1128 content.insert(
1129 "application/json".to_string(),
1130 MediaType {
1131 schema: Some(schema_ref),
1132 example: None,
1133 },
1134 );
1135
1136 op.responses.insert(
1137 "200".to_string(),
1138 ResponseSpec {
1139 description: "Successful response".to_string(),
1140 content,
1141 headers: BTreeMap::new(),
1142 },
1143 );
1144 }
1145}
1146
1147impl<T: RustApiSchema> RustApiSchema for Json<T> {
1150 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1151 T::schema(ctx)
1152 }
1153}
1154
1155impl<T: RustApiSchema> RustApiSchema for ValidatedJson<T> {
1156 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1157 T::schema(ctx)
1158 }
1159}
1160
1161impl<T: RustApiSchema> RustApiSchema for AsyncValidatedJson<T> {
1162 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1163 T::schema(ctx)
1164 }
1165}
1166
1167impl<T: RustApiSchema> RustApiSchema for Query<T> {
1168 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1169 T::schema(ctx)
1170 }
1171 fn field_schemas(ctx: &mut SchemaCtx) -> Option<BTreeMap<String, SchemaRef>> {
1172 T::field_schemas(ctx)
1173 }
1174}
1175
1176#[cfg(test)]
1177mod tests {
1178 use super::*;
1179 use crate::path_params::PathParams;
1180 use bytes::Bytes;
1181 use http::{Extensions, Method};
1182 use proptest::prelude::*;
1183 use proptest::test_runner::TestCaseError;
1184 use std::sync::Arc;
1185
1186 fn create_test_request_with_headers(
1188 method: Method,
1189 path: &str,
1190 headers: Vec<(&str, &str)>,
1191 ) -> Request {
1192 let uri: http::Uri = path.parse().unwrap();
1193 let mut builder = http::Request::builder().method(method).uri(uri);
1194
1195 for (name, value) in headers {
1196 builder = builder.header(name, value);
1197 }
1198
1199 let req = builder.body(()).unwrap();
1200 let (parts, _) = req.into_parts();
1201
1202 Request::new(
1203 parts,
1204 crate::request::BodyVariant::Buffered(Bytes::new()),
1205 Arc::new(Extensions::new()),
1206 PathParams::new(),
1207 )
1208 }
1209
1210 fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
1212 method: Method,
1213 path: &str,
1214 extension: T,
1215 ) -> Request {
1216 let uri: http::Uri = path.parse().unwrap();
1217 let builder = http::Request::builder().method(method).uri(uri);
1218
1219 let req = builder.body(()).unwrap();
1220 let (mut parts, _) = req.into_parts();
1221 parts.extensions.insert(extension);
1222
1223 Request::new(
1224 parts,
1225 crate::request::BodyVariant::Buffered(Bytes::new()),
1226 Arc::new(Extensions::new()),
1227 PathParams::new(),
1228 )
1229 }
1230
    // Property: every generated (name, value) header is visible through the
    // `Headers` extractor. Values are checked against all values for that name
    // (`get_all`), because duplicate generated names produce multiple entries.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_headers_extractor_completeness(
            headers in prop::collection::vec(
                (
                    "[a-z][a-z0-9-]{0,20}", "[a-zA-Z0-9 ]{1,50}" ),
                0..10
            )
        ) {
            // The closure lets `?` and `prop_assert!` short-circuit into a
            // single TestCaseError result.
            let result: Result<(), TestCaseError> = (|| {
                let header_tuples: Vec<(&str, &str)> = headers
                    .iter()
                    .map(|(k, v)| (k.as_str(), v.as_str()))
                    .collect();

                let request = create_test_request_with_headers(
                    Method::GET,
                    "/test",
                    header_tuples.clone(),
                );

                let extracted = Headers::from_request_parts(&request)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;

                for (name, value) in &headers {
                    // `get_all` (via Deref to HeaderMap) sees every value for `name`.
                    let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
                    prop_assert!(
                        !all_values.is_empty(),
                        "Header '{}' not found",
                        name
                    );

                    let value_found = all_values.iter().any(|v| {
                        v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
                    });

                    prop_assert!(
                        value_found,
                        "Header '{}' value '{}' not found in extracted values",
                        name,
                        value
                    );
                }

                Ok(())
            })();
            result?;
        }
    }
1299
    // Property: `HeaderValue::extract` returns the stored value when the header
    // is present and an error when it is absent.
    // NOTE(review): `extract` takes a `&'static str` name, so the check uses the
    // fixed "x-test-header". The randomly generated `header_name` only feeds
    // `_request`, which is never inspected.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_header_value_extractor_correctness(
            header_name in "[a-z][a-z0-9-]{0,20}",
            header_value in "[a-zA-Z0-9 ]{1,50}",
            has_header in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let headers = if has_header {
                    vec![(header_name.as_str(), header_value.as_str())]
                } else {
                    vec![]
                };

                // Unused: see the NOTE above.
                let _request = create_test_request_with_headers(Method::GET, "/test", headers);

                let test_header = "x-test-header";
                let request_with_known_header = if has_header {
                    create_test_request_with_headers(
                        Method::GET,
                        "/test",
                        vec![(test_header, header_value.as_str())],
                    )
                } else {
                    create_test_request_with_headers(Method::GET, "/test", vec![])
                };

                let result = HeaderValue::extract(&request_with_known_header, test_header);

                if has_header {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.value(),
                        header_value.as_str(),
                        "Header value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when header is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }
1359
    // Property: `ClientIp::extract_with_config` uses the X-Forwarded-For IP only
    // when `trust_proxy` is set and the header exists. Otherwise it uses the
    // socket address from the extensions.
    // NOTE(review): a SocketAddr is always inserted, so the 127.0.0.1 fallback
    // path is not exercised here.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_client_ip_extractor_with_forwarding(
            forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
            socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
            has_forwarded_header in prop::bool::ANY,
            trust_proxy in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let headers = if has_forwarded_header {
                    vec![("x-forwarded-for", forwarded_ip.as_str())]
                } else {
                    vec![]
                };

                // Built by hand (not via the helper) so the SocketAddr extension
                // can be inserted into the parts.
                let uri: http::Uri = "/test".parse().unwrap();
                let mut builder = http::Request::builder().method(Method::GET).uri(uri);
                for (name, value) in &headers {
                    builder = builder.header(*name, *value);
                }
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
                parts.extensions.insert(socket_addr);

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let extracted = ClientIp::extract_with_config(&request, trust_proxy)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;

                if trust_proxy && has_forwarded_header {
                    let expected_ip: std::net::IpAddr = forwarded_ip.parse()
                        .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        expected_ip,
                        "Should use X-Forwarded-For IP when trust_proxy is enabled"
                    );
                } else {
                    prop_assert_eq!(
                        extracted.0,
                        socket_ip,
                        "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }
1432
    // Property: `Extension<T>` returns the inserted value when present and an
    // error when it is absent.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_extension_extractor_retrieval(
            value in any::<i64>(),
            has_extension in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // Local newtype so the lookup key cannot collide with other extensions.
                #[derive(Clone, Debug, PartialEq)]
                struct TestExtension(i64);

                let uri: http::Uri = "/test".parse().unwrap();
                let builder = http::Request::builder().method(Method::GET).uri(uri);
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                if has_extension {
                    parts.extensions.insert(TestExtension(value));
                }

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let result = Extension::<TestExtension>::from_request_parts(&request);

                if has_extension {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        TestExtension(value),
                        "Extension value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when extension is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }
1490
1491 #[test]
1494 fn test_headers_extractor_basic() {
1495 let request = create_test_request_with_headers(
1496 Method::GET,
1497 "/test",
1498 vec![
1499 ("content-type", "application/json"),
1500 ("accept", "text/html"),
1501 ],
1502 );
1503
1504 let headers = Headers::from_request_parts(&request).unwrap();
1505
1506 assert!(headers.contains("content-type"));
1507 assert!(headers.contains("accept"));
1508 assert!(!headers.contains("x-custom"));
1509 assert_eq!(headers.len(), 2);
1510 }
1511
1512 #[test]
1513 fn test_header_value_extractor_present() {
1514 let request = create_test_request_with_headers(
1515 Method::GET,
1516 "/test",
1517 vec![("authorization", "Bearer token123")],
1518 );
1519
1520 let result = HeaderValue::extract(&request, "authorization");
1521 assert!(result.is_ok());
1522 assert_eq!(result.unwrap().value(), "Bearer token123");
1523 }
1524
1525 #[test]
1526 fn test_header_value_extractor_missing() {
1527 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1528
1529 let result = HeaderValue::extract(&request, "authorization");
1530 assert!(result.is_err());
1531 }
1532
1533 #[test]
1534 fn test_client_ip_from_forwarded_header() {
1535 let request = create_test_request_with_headers(
1536 Method::GET,
1537 "/test",
1538 vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
1539 );
1540
1541 let ip = ClientIp::extract_with_config(&request, true).unwrap();
1542 assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
1543 }
1544
1545 #[test]
1546 fn test_client_ip_ignores_forwarded_when_not_trusted() {
1547 let uri: http::Uri = "/test".parse().unwrap();
1548 let builder = http::Request::builder()
1549 .method(Method::GET)
1550 .uri(uri)
1551 .header("x-forwarded-for", "192.168.1.100");
1552 let req = builder.body(()).unwrap();
1553 let (mut parts, _) = req.into_parts();
1554
1555 let socket_addr = std::net::SocketAddr::new(
1556 std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
1557 8080,
1558 );
1559 parts.extensions.insert(socket_addr);
1560
1561 let request = Request::new(
1562 parts,
1563 crate::request::BodyVariant::Buffered(Bytes::new()),
1564 Arc::new(Extensions::new()),
1565 PathParams::new(),
1566 );
1567
1568 let ip = ClientIp::extract_with_config(&request, false).unwrap();
1569 assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
1570 }
1571
1572 #[test]
1573 fn test_extension_extractor_present() {
1574 #[derive(Clone, Debug, PartialEq)]
1575 struct MyData(String);
1576
1577 let request =
1578 create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));
1579
1580 let result = Extension::<MyData>::from_request_parts(&request);
1581 assert!(result.is_ok());
1582 assert_eq!(result.unwrap().0, MyData("hello".to_string()));
1583 }
1584
1585 #[test]
1586 fn test_extension_extractor_missing() {
1587 #[derive(Clone, Debug)]
1588 #[allow(dead_code)]
1589 struct MyData(String);
1590
1591 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1592
1593 let result = Extension::<MyData>::from_request_parts(&request);
1594 assert!(result.is_err());
1595 }
1596
    // Tests for the `Cookies` extractor. They are compiled only with the
    // `cookies` feature.
    #[cfg(feature = "cookies")]
    mod cookies_tests {
        use super::*;

        // Property: every generated cookie is found with its value, and the jar
        // holds exactly one entry per distinct name.
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            #[test]
            fn prop_cookies_extractor_parsing(
                cookies in prop::collection::vec(
                    (
                        "[a-zA-Z][a-zA-Z0-9_]{0,15}", "[a-zA-Z0-9]{1,30}" ),
                    0..5
                )
            ) {
                let result: Result<(), TestCaseError> = (|| {
                    // Single `Cookie` header: "n1=v1; n2=v2; ...".
                    let cookie_header = cookies
                        .iter()
                        .map(|(name, value)| format!("{}={}", name, value))
                        .collect::<Vec<_>>()
                        .join("; ");

                    let headers = if !cookies.is_empty() {
                        vec![("cookie", cookie_header.as_str())]
                    } else {
                        vec![]
                    };

                    let request = create_test_request_with_headers(Method::GET, "/test", headers);

                    let extracted = Cookies::from_request_parts(&request)
                        .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;

                    // Duplicate names collapse to the last generated value
                    // (HashMap::insert overwrites). This assumes the jar also
                    // keeps the last occurrence.
                    let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
                    for (name, value) in &cookies {
                        expected_cookies.insert(name.as_str(), value.as_str());
                    }

                    for (name, expected_value) in &expected_cookies {
                        let cookie = extracted.get(name)
                            .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;

                        prop_assert_eq!(
                            cookie.value(),
                            *expected_value,
                            "Cookie '{}' value mismatch",
                            name
                        );
                    }

                    let extracted_count = extracted.iter().count();
                    prop_assert_eq!(
                        extracted_count,
                        expected_cookies.len(),
                        "Expected {} unique cookies, got {}",
                        expected_cookies.len(),
                        extracted_count
                    );

                    Ok(())
                })();
                result?;
            }
        }

        // Two cookies in one header are both parsed; absent names are not reported.
        #[test]
        fn test_cookies_extractor_basic() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "session=abc123; user=john")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();

            assert!(cookies.contains("session"));
            assert!(cookies.contains("user"));
            assert!(!cookies.contains("other"));

            assert_eq!(cookies.get("session").unwrap().value(), "abc123");
            assert_eq!(cookies.get("user").unwrap().value(), "john");
        }

        // No `Cookie` header yields an empty jar rather than an error.
        #[test]
        fn test_cookies_extractor_empty() {
            let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 0);
        }

        // A single cookie without a trailing separator is parsed.
        #[test]
        fn test_cookies_extractor_single() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "token=xyz789")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 1);
            assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
        }
    }
1718}