1use crate::error::{ApiError, Result};
58use crate::json;
59use crate::request::Request;
60use crate::response::IntoResponse;
61use crate::stream::{StreamingBody, StreamingConfig};
62use crate::validation::Validatable;
63use bytes::Bytes;
64use http::{header, StatusCode};
65use rustapi_validate::v2::{AsyncValidate, ValidationContext};
66
67use rustapi_openapi::schema::{RustApiSchema, SchemaCtx, SchemaRef};
68use serde::de::DeserializeOwned;
69use serde::Serialize;
70use std::collections::BTreeMap;
71use std::future::Future;
72use std::ops::{Deref, DerefMut};
73use std::str::FromStr;
74
/// Extractor that only receives a shared `&Request`.
///
/// Implementations in this module read headers, the query string, path parameters,
/// application state and extensions. Because only `&Request` is available, this trait
/// is meant for data that does not require taking the request body.
///
/// Every implementor also gets [`FromRequest`] through the blanket impl in this
/// module, so it can be used wherever a full extractor is expected.
pub trait FromRequestParts: Sized {
    /// Extracts `Self` from the request.
    ///
    /// # Errors
    /// Returns an [`ApiError`] when the data is missing or malformed. The impls in
    /// this module use `bad_request` for malformed client input (query, path values,
    /// headers) and `internal` for server-side setup problems (missing state/extension).
    fn from_request_parts(req: &Request) -> Result<Self>;
}
103
/// Extractor that gets mutable access to the request and may therefore consume the body.
///
/// Body extractors in this module (`Json`, `ValidatedJson`, `AsyncValidatedJson`,
/// `Body`, `BodyStream`) take the body out of the request. Those that buffer it
/// report `"Body already consumed"` when it has already been taken.
pub trait FromRequest: Sized {
    /// Asynchronously extracts `Self` from the request.
    ///
    /// The returned future must be `Send` so handlers can run on multi-threaded executors.
    ///
    /// # Errors
    /// Returns an [`ApiError`] when loading, decoding or validating the request fails.
    fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
}
137
138impl<T: FromRequestParts> FromRequest for T {
140 async fn from_request(req: &mut Request) -> Result<Self> {
141 T::from_request_parts(req)
142 }
143}
144
145#[derive(Debug, Clone, Copy, Default)]
164pub struct Json<T>(pub T);
165
166impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
167 async fn from_request(req: &mut Request) -> Result<Self> {
168 req.load_body().await?;
169 let body = req
170 .take_body()
171 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
172
173 let value: T = json::from_slice(&body)?;
175 Ok(Json(value))
176 }
177}
178
179impl<T> Deref for Json<T> {
180 type Target = T;
181
182 fn deref(&self) -> &Self::Target {
183 &self.0
184 }
185}
186
187impl<T> DerefMut for Json<T> {
188 fn deref_mut(&mut self) -> &mut Self::Target {
189 &mut self.0
190 }
191}
192
193impl<T> From<T> for Json<T> {
194 fn from(value: T) -> Self {
195 Json(value)
196 }
197}
198
199const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256;
202
203impl<T: Serialize> IntoResponse for Json<T> {
205 fn into_response(self) -> crate::response::Response {
206 match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
208 Ok(body) => http::Response::builder()
209 .status(StatusCode::OK)
210 .header(header::CONTENT_TYPE, "application/json")
211 .body(crate::response::Body::from(body))
212 .unwrap(),
213 Err(err) => {
214 ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
215 }
216 }
217 }
218}
219
/// JSON body extractor that also runs synchronous validation before the handler runs.
///
/// The tuple field is public, so the value can be destructured directly.
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);

impl<T> ValidatedJson<T> {
    /// Wraps `value` as-is; no validation is performed here.
    pub fn new(value: T) -> Self {
        ValidatedJson(value)
    }

    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        let ValidatedJson(inner) = self;
        inner
    }
}
259
260impl<T: DeserializeOwned + Validatable + Send> FromRequest for ValidatedJson<T> {
261 async fn from_request(req: &mut Request) -> Result<Self> {
262 req.load_body().await?;
263 let body = req
265 .take_body()
266 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
267
268 let value: T = json::from_slice(&body)?;
269
270 value.do_validate()?;
272
273 Ok(ValidatedJson(value))
274 }
275}
276
277impl<T> Deref for ValidatedJson<T> {
278 type Target = T;
279
280 fn deref(&self) -> &Self::Target {
281 &self.0
282 }
283}
284
285impl<T> DerefMut for ValidatedJson<T> {
286 fn deref_mut(&mut self) -> &mut Self::Target {
287 &mut self.0
288 }
289}
290
291impl<T> From<T> for ValidatedJson<T> {
292 fn from(value: T) -> Self {
293 ValidatedJson(value)
294 }
295}
296
297impl<T: Serialize> IntoResponse for ValidatedJson<T> {
298 fn into_response(self) -> crate::response::Response {
299 Json(self.0).into_response()
300 }
301}
302
/// JSON body extractor whose validation may be asynchronous (see its `FromRequest` impl).
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncValidatedJson<T>(pub T);

impl<T> AsyncValidatedJson<T> {
    /// Wraps `value` as-is; no validation is performed here.
    pub fn new(value: T) -> Self {
        AsyncValidatedJson(value)
    }

    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        let AsyncValidatedJson(inner) = self;
        inner
    }
}

/// Read-only access to the wrapped value.
impl<T> Deref for AsyncValidatedJson<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let AsyncValidatedJson(inner) = self;
        inner
    }
}

/// Mutable access to the wrapped value.
impl<T> DerefMut for AsyncValidatedJson<T> {
    fn deref_mut(&mut self) -> &mut T {
        let AsyncValidatedJson(inner) = self;
        inner
    }
}

impl<T> From<T> for AsyncValidatedJson<T> {
    fn from(value: T) -> Self {
        Self(value)
    }
}
363
364impl<T: Serialize> IntoResponse for AsyncValidatedJson<T> {
365 fn into_response(self) -> crate::response::Response {
366 Json(self.0).into_response()
367 }
368}
369
370impl<T: DeserializeOwned + AsyncValidate + Send + Sync> FromRequest for AsyncValidatedJson<T> {
371 async fn from_request(req: &mut Request) -> Result<Self> {
372 req.load_body().await?;
373
374 let body = req
375 .take_body()
376 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
377
378 let value: T = json::from_slice(&body)?;
379
380 let ctx = if let Some(ctx) = req.state().get::<ValidationContext>() {
383 ctx.clone()
384 } else {
385 ValidationContext::default()
386 };
387
388 if let Err(errors) = value.validate_full(&ctx).await {
390 let field_errors: Vec<crate::error::FieldError> = errors
392 .fields
393 .iter()
394 .flat_map(|(field, errs)| {
395 let field_name = field.to_string();
396 errs.iter().map(move |e| crate::error::FieldError {
397 field: field_name.clone(),
398 code: e.code.to_string(),
399 message: e.message.clone(),
400 })
401 })
402 .collect();
403
404 return Err(ApiError::validation(field_errors));
405 }
406
407 Ok(AsyncValidatedJson(value))
408 }
409}
410
411#[derive(Debug, Clone)]
429pub struct Query<T>(pub T);
430
431impl<T: DeserializeOwned> FromRequestParts for Query<T> {
432 fn from_request_parts(req: &Request) -> Result<Self> {
433 let query = req.query_string().unwrap_or("");
434 let value: T = serde_urlencoded::from_str(query)
435 .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
436 Ok(Query(value))
437 }
438}
439
440impl<T> Deref for Query<T> {
441 type Target = T;
442
443 fn deref(&self) -> &Self::Target {
444 &self.0
445 }
446}
447
448#[derive(Debug, Clone)]
470pub struct Path<T>(pub T);
471
472impl<T: FromStr> FromRequestParts for Path<T>
473where
474 T::Err: std::fmt::Display,
475{
476 fn from_request_parts(req: &Request) -> Result<Self> {
477 let params = req.path_params();
478
479 if let Some((_, value)) = params.iter().next() {
481 let parsed = value
482 .parse::<T>()
483 .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
484 return Ok(Path(parsed));
485 }
486
487 Err(ApiError::internal("Missing path parameter"))
488 }
489}
490
491impl<T> Deref for Path<T> {
492 type Target = T;
493
494 fn deref(&self) -> &Self::Target {
495 &self.0
496 }
497}
498
499#[derive(Debug, Clone)]
519pub struct Typed<T>(pub T);
520
521impl<T: DeserializeOwned + Send> FromRequestParts for Typed<T> {
522 fn from_request_parts(req: &Request) -> Result<Self> {
523 let params = req.path_params();
524 let mut map = serde_json::Map::new();
525 for (k, v) in params.iter() {
526 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
527 }
528 let value = serde_json::Value::Object(map);
529 let parsed: T = serde_json::from_value(value)
530 .map_err(|e| ApiError::bad_request(format!("Invalid path parameters: {}", e)))?;
531 Ok(Typed(parsed))
532 }
533}
534
535impl<T> Deref for Typed<T> {
536 type Target = T;
537
538 fn deref(&self) -> &Self::Target {
539 &self.0
540 }
541}
542
543#[derive(Debug, Clone)]
560pub struct State<T>(pub T);
561
562impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
563 fn from_request_parts(req: &Request) -> Result<Self> {
564 req.state().get::<T>().cloned().map(State).ok_or_else(|| {
565 ApiError::internal(format!(
566 "State of type `{}` not found. Did you forget to call .state()?",
567 std::any::type_name::<T>()
568 ))
569 })
570 }
571}
572
573impl<T> Deref for State<T> {
574 type Target = T;
575
576 fn deref(&self) -> &Self::Target {
577 &self.0
578 }
579}
580
581#[derive(Debug, Clone)]
583pub struct Body(pub Bytes);
584
585impl FromRequest for Body {
586 async fn from_request(req: &mut Request) -> Result<Self> {
587 req.load_body().await?;
588 let body = req
589 .take_body()
590 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
591 Ok(Body(body))
592 }
593}
594
595impl Deref for Body {
596 type Target = Bytes;
597
598 fn deref(&self) -> &Self::Target {
599 &self.0
600 }
601}
602
603pub struct BodyStream(pub StreamingBody);
605
606impl FromRequest for BodyStream {
607 async fn from_request(req: &mut Request) -> Result<Self> {
608 let config = StreamingConfig::default();
609
610 if let Some(stream) = req.take_stream() {
611 Ok(BodyStream(StreamingBody::new(stream, config.max_body_size)))
612 } else if let Some(bytes) = req.take_body() {
613 let stream = futures_util::stream::once(async move { Ok(bytes) });
615 Ok(BodyStream(StreamingBody::from_stream(
616 stream,
617 config.max_body_size,
618 )))
619 } else {
620 Err(ApiError::internal("Body already consumed"))
621 }
622 }
623}
624
625impl Deref for BodyStream {
626 type Target = StreamingBody;
627
628 fn deref(&self) -> &Self::Target {
629 &self.0
630 }
631}
632
633impl DerefMut for BodyStream {
634 fn deref_mut(&mut self) -> &mut Self::Target {
635 &mut self.0
636 }
637}
638
639impl futures_util::Stream for BodyStream {
641 type Item = Result<Bytes, ApiError>;
642
643 fn poll_next(
644 mut self: std::pin::Pin<&mut Self>,
645 cx: &mut std::task::Context<'_>,
646 ) -> std::task::Poll<Option<Self::Item>> {
647 std::pin::Pin::new(&mut self.0).poll_next(cx)
648 }
649}
650
651impl<T: FromRequestParts> FromRequestParts for Option<T> {
655 fn from_request_parts(req: &Request) -> Result<Self> {
656 Ok(T::from_request_parts(req).ok())
657 }
658}
659
660#[derive(Debug, Clone)]
678pub struct Headers(pub http::HeaderMap);
679
680impl Headers {
681 pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
683 self.0.get(name)
684 }
685
686 pub fn contains(&self, name: &str) -> bool {
688 self.0.contains_key(name)
689 }
690
691 pub fn len(&self) -> usize {
693 self.0.len()
694 }
695
696 pub fn is_empty(&self) -> bool {
698 self.0.is_empty()
699 }
700
701 pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
703 self.0.iter()
704 }
705}
706
707impl FromRequestParts for Headers {
708 fn from_request_parts(req: &Request) -> Result<Self> {
709 Ok(Headers(req.headers().clone()))
710 }
711}
712
713impl Deref for Headers {
714 type Target = http::HeaderMap;
715
716 fn deref(&self) -> &Self::Target {
717 &self.0
718 }
719}
720
721#[derive(Debug, Clone)]
740pub struct HeaderValue(pub String, pub &'static str);
741
742impl HeaderValue {
743 pub fn new(name: &'static str, value: String) -> Self {
745 Self(value, name)
746 }
747
748 pub fn value(&self) -> &str {
750 &self.0
751 }
752
753 pub fn name(&self) -> &'static str {
755 self.1
756 }
757
758 pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
760 req.headers()
761 .get(name)
762 .and_then(|v| v.to_str().ok())
763 .map(|s| HeaderValue(s.to_string(), name))
764 .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
765 }
766}
767
768impl Deref for HeaderValue {
769 type Target = String;
770
771 fn deref(&self) -> &Self::Target {
772 &self.0
773 }
774}
775
776#[derive(Debug, Clone)]
794pub struct Extension<T>(pub T);
795
796impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
797 fn from_request_parts(req: &Request) -> Result<Self> {
798 req.extensions()
799 .get::<T>()
800 .cloned()
801 .map(Extension)
802 .ok_or_else(|| {
803 ApiError::internal(format!(
804 "Extension of type `{}` not found. Did middleware insert it?",
805 std::any::type_name::<T>()
806 ))
807 })
808 }
809}
810
811impl<T> Deref for Extension<T> {
812 type Target = T;
813
814 fn deref(&self) -> &Self::Target {
815 &self.0
816 }
817}
818
819impl<T> DerefMut for Extension<T> {
820 fn deref_mut(&mut self) -> &mut Self::Target {
821 &mut self.0
822 }
823}
824
825#[derive(Debug, Clone)]
840pub struct ClientIp(pub std::net::IpAddr);
841
842impl ClientIp {
843 pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
845 if trust_proxy {
846 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
848 if let Ok(forwarded_str) = forwarded.to_str() {
849 if let Some(first_ip) = forwarded_str.split(',').next() {
851 if let Ok(ip) = first_ip.trim().parse() {
852 return Ok(ClientIp(ip));
853 }
854 }
855 }
856 }
857 }
858
859 if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
861 return Ok(ClientIp(addr.ip()));
862 }
863
864 Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
866 127, 0, 0, 1,
867 ))))
868 }
869}
870
871impl FromRequestParts for ClientIp {
872 fn from_request_parts(req: &Request) -> Result<Self> {
873 Self::extract_with_config(req, true)
875 }
876}
877
#[cfg(feature = "cookies")]
/// Request cookies parsed into a [`cookie::CookieJar`] (requires the `cookies` feature).
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);

#[cfg(feature = "cookies")]
impl Cookies {
    /// Cookie named `name`, if present.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        self.0.get(name)
    }

    /// Iterates over all parsed cookies.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        self.0.iter()
    }

    /// Whether a cookie named `name` is present.
    pub fn contains(&self, name: &str) -> bool {
        self.0.get(name).is_some()
    }
}

#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    /// Parses every `Cookie` header into a jar. Never fails.
    ///
    /// Parsing is best effort: header values that are not valid text and
    /// `name=value` pairs that fail to parse are skipped. When a name repeats,
    /// the later occurrence replaces the earlier one in the jar.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();

        // HTTP/2 (RFC 9113 §8.2.3) lets clients split cookies across several `cookie`
        // header fields, so walk every occurrence rather than only the first.
        for header_value in req.headers().get_all(header::COOKIE) {
            let Ok(cookie_str) = header_value.to_str() else {
                continue;
            };
            for pair in cookie_str.split(';').map(str::trim).filter(|p| !p.is_empty()) {
                if let Ok(cookie) = cookie::Cookie::parse(pair.to_string()) {
                    jar.add_original(cookie.into_owned());
                }
            }
        }

        Ok(Cookies(jar))
    }
}

#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
948
949macro_rules! impl_from_request_parts_for_primitives {
951 ($($ty:ty),*) => {
952 $(
953 impl FromRequestParts for $ty {
954 fn from_request_parts(req: &Request) -> Result<Self> {
955 let Path(value) = Path::<$ty>::from_request_parts(req)?;
956 Ok(value)
957 }
958 }
959 )*
960 };
961}
962
963impl_from_request_parts_for_primitives!(
964 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
965);
966
967use rustapi_openapi::{
970 MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier, ResponseSpec,
971};
972
973impl<T: RustApiSchema> OperationModifier for ValidatedJson<T> {
975 fn update_operation(op: &mut Operation) {
976 let mut ctx = SchemaCtx::new();
977 let schema_ref = T::schema(&mut ctx);
978
979 let mut content = BTreeMap::new();
980 content.insert(
981 "application/json".to_string(),
982 MediaType {
983 schema: Some(schema_ref),
984 example: None,
985 },
986 );
987
988 op.request_body = Some(RequestBody {
989 description: None,
990 required: Some(true),
991 content,
992 });
993
994 let mut responses_content = BTreeMap::new();
996 responses_content.insert(
997 "application/json".to_string(),
998 MediaType {
999 schema: Some(SchemaRef::Ref {
1000 reference: "#/components/schemas/ValidationErrorSchema".to_string(),
1001 }),
1002 example: None,
1003 },
1004 );
1005
1006 op.responses.insert(
1007 "422".to_string(),
1008 ResponseSpec {
1009 description: "Validation Error".to_string(),
1010 content: responses_content,
1011 headers: BTreeMap::new(),
1012 },
1013 );
1014 }
1015}
1016
1017impl<T: RustApiSchema> OperationModifier for Json<T> {
1019 fn update_operation(op: &mut Operation) {
1020 let mut ctx = SchemaCtx::new();
1021 let schema_ref = T::schema(&mut ctx);
1022
1023 let mut content = BTreeMap::new();
1024 content.insert(
1025 "application/json".to_string(),
1026 MediaType {
1027 schema: Some(schema_ref),
1028 example: None,
1029 },
1030 );
1031
1032 op.request_body = Some(RequestBody {
1033 description: None,
1034 required: Some(true),
1035 content,
1036 });
1037 }
1038}
1039
1040impl<T> OperationModifier for Path<T> {
1042 fn update_operation(_op: &mut Operation) {}
1043}
1044
1045impl<T> OperationModifier for Typed<T> {
1047 fn update_operation(_op: &mut Operation) {}
1048}
1049
1050impl<T: RustApiSchema> OperationModifier for Query<T> {
1052 fn update_operation(op: &mut Operation) {
1053 let mut ctx = SchemaCtx::new();
1054 if let Some(fields) = T::field_schemas(&mut ctx) {
1055 let new_params: Vec<Parameter> = fields
1056 .into_iter()
1057 .map(|(name, schema)| {
1058 Parameter {
1059 name,
1060 location: "query".to_string(),
1061 required: false, deprecated: None,
1063 description: None,
1064 schema: Some(schema),
1065 }
1066 })
1067 .collect();
1068
1069 op.parameters.extend(new_params);
1070 }
1071 }
1072}
1073
1074impl<T> OperationModifier for State<T> {
1076 fn update_operation(_op: &mut Operation) {}
1077}
1078
1079impl OperationModifier for Body {
1081 fn update_operation(op: &mut Operation) {
1082 let mut content = BTreeMap::new();
1083 content.insert(
1084 "application/octet-stream".to_string(),
1085 MediaType {
1086 schema: Some(SchemaRef::Inline(
1087 serde_json::json!({ "type": "string", "format": "binary" }),
1088 )),
1089 example: None,
1090 },
1091 );
1092
1093 op.request_body = Some(RequestBody {
1094 description: None,
1095 required: Some(true),
1096 content,
1097 });
1098 }
1099}
1100
1101impl OperationModifier for BodyStream {
1103 fn update_operation(op: &mut Operation) {
1104 let mut content = BTreeMap::new();
1105 content.insert(
1106 "application/octet-stream".to_string(),
1107 MediaType {
1108 schema: Some(SchemaRef::Inline(
1109 serde_json::json!({ "type": "string", "format": "binary" }),
1110 )),
1111 example: None,
1112 },
1113 );
1114
1115 op.request_body = Some(RequestBody {
1116 description: None,
1117 required: Some(true),
1118 content,
1119 });
1120 }
1121}
1122
1123impl<T: RustApiSchema> ResponseModifier for Json<T> {
1127 fn update_response(op: &mut Operation) {
1128 let mut ctx = SchemaCtx::new();
1129 let schema_ref = T::schema(&mut ctx);
1130
1131 let mut content = BTreeMap::new();
1132 content.insert(
1133 "application/json".to_string(),
1134 MediaType {
1135 schema: Some(schema_ref),
1136 example: None,
1137 },
1138 );
1139
1140 op.responses.insert(
1141 "200".to_string(),
1142 ResponseSpec {
1143 description: "Successful response".to_string(),
1144 content,
1145 headers: BTreeMap::new(),
1146 },
1147 );
1148 }
1149}
1150
1151impl<T: RustApiSchema> RustApiSchema for Json<T> {
1154 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1155 T::schema(ctx)
1156 }
1157}
1158
1159impl<T: RustApiSchema> RustApiSchema for ValidatedJson<T> {
1160 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1161 T::schema(ctx)
1162 }
1163}
1164
1165impl<T: RustApiSchema> RustApiSchema for AsyncValidatedJson<T> {
1166 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1167 T::schema(ctx)
1168 }
1169}
1170
1171impl<T: RustApiSchema> RustApiSchema for Query<T> {
1172 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1173 T::schema(ctx)
1174 }
1175 fn field_schemas(ctx: &mut SchemaCtx) -> Option<BTreeMap<String, SchemaRef>> {
1176 T::field_schemas(ctx)
1177 }
1178}
1179
1180#[cfg(test)]
1181mod tests {
1182 use super::*;
1183 use crate::path_params::PathParams;
1184 use bytes::Bytes;
1185 use http::{Extensions, Method};
1186 use proptest::prelude::*;
1187 use proptest::test_runner::TestCaseError;
1188 use std::sync::Arc;
1189
1190 fn create_test_request_with_headers(
1192 method: Method,
1193 path: &str,
1194 headers: Vec<(&str, &str)>,
1195 ) -> Request {
1196 let uri: http::Uri = path.parse().unwrap();
1197 let mut builder = http::Request::builder().method(method).uri(uri);
1198
1199 for (name, value) in headers {
1200 builder = builder.header(name, value);
1201 }
1202
1203 let req = builder.body(()).unwrap();
1204 let (parts, _) = req.into_parts();
1205
1206 Request::new(
1207 parts,
1208 crate::request::BodyVariant::Buffered(Bytes::new()),
1209 Arc::new(Extensions::new()),
1210 PathParams::new(),
1211 )
1212 }
1213
1214 fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
1216 method: Method,
1217 path: &str,
1218 extension: T,
1219 ) -> Request {
1220 let uri: http::Uri = path.parse().unwrap();
1221 let builder = http::Request::builder().method(method).uri(uri);
1222
1223 let req = builder.body(()).unwrap();
1224 let (mut parts, _) = req.into_parts();
1225 parts.extensions.insert(extension);
1226
1227 Request::new(
1228 parts,
1229 crate::request::BodyVariant::Buffered(Bytes::new()),
1230 Arc::new(Extensions::new()),
1231 PathParams::new(),
1232 )
1233 }
1234
1235 proptest! {
1242 #![proptest_config(ProptestConfig::with_cases(100))]
1243
1244 #[test]
1245 fn prop_headers_extractor_completeness(
1246 headers in prop::collection::vec(
1249 (
1250 "[a-z][a-z0-9-]{0,20}", "[a-zA-Z0-9 ]{1,50}" ),
1253 0..10
1254 )
1255 ) {
1256 let result: Result<(), TestCaseError> = (|| {
1257 let header_tuples: Vec<(&str, &str)> = headers
1259 .iter()
1260 .map(|(k, v)| (k.as_str(), v.as_str()))
1261 .collect();
1262
1263 let request = create_test_request_with_headers(
1265 Method::GET,
1266 "/test",
1267 header_tuples.clone(),
1268 );
1269
1270 let extracted = Headers::from_request_parts(&request)
1272 .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;
1273
1274 for (name, value) in &headers {
1277 let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
1279 prop_assert!(
1280 !all_values.is_empty(),
1281 "Header '{}' not found",
1282 name
1283 );
1284
1285 let value_found = all_values.iter().any(|v| {
1287 v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
1288 });
1289
1290 prop_assert!(
1291 value_found,
1292 "Header '{}' value '{}' not found in extracted values",
1293 name,
1294 value
1295 );
1296 }
1297
1298 Ok(())
1299 })();
1300 result?;
1301 }
1302 }
1303
1304 proptest! {
1311 #![proptest_config(ProptestConfig::with_cases(100))]
1312
1313 #[test]
1314 fn prop_header_value_extractor_correctness(
1315 header_name in "[a-z][a-z0-9-]{0,20}",
1316 header_value in "[a-zA-Z0-9 ]{1,50}",
1317 has_header in prop::bool::ANY,
1318 ) {
1319 let result: Result<(), TestCaseError> = (|| {
1320 let headers = if has_header {
1321 vec![(header_name.as_str(), header_value.as_str())]
1322 } else {
1323 vec![]
1324 };
1325
1326 let _request = create_test_request_with_headers(Method::GET, "/test", headers);
1327
1328 let test_header = "x-test-header";
1331 let request_with_known_header = if has_header {
1332 create_test_request_with_headers(
1333 Method::GET,
1334 "/test",
1335 vec![(test_header, header_value.as_str())],
1336 )
1337 } else {
1338 create_test_request_with_headers(Method::GET, "/test", vec![])
1339 };
1340
1341 let result = HeaderValue::extract(&request_with_known_header, test_header);
1342
1343 if has_header {
1344 let extracted = result
1345 .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
1346 prop_assert_eq!(
1347 extracted.value(),
1348 header_value.as_str(),
1349 "Header value mismatch"
1350 );
1351 } else {
1352 prop_assert!(
1353 result.is_err(),
1354 "Expected error when header is missing"
1355 );
1356 }
1357
1358 Ok(())
1359 })();
1360 result?;
1361 }
1362 }
1363
1364 proptest! {
1371 #![proptest_config(ProptestConfig::with_cases(100))]
1372
1373 #[test]
1374 fn prop_client_ip_extractor_with_forwarding(
1375 forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1377 .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
1378 socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1379 .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
1380 has_forwarded_header in prop::bool::ANY,
1381 trust_proxy in prop::bool::ANY,
1382 ) {
1383 let result: Result<(), TestCaseError> = (|| {
1384 let headers = if has_forwarded_header {
1385 vec![("x-forwarded-for", forwarded_ip.as_str())]
1386 } else {
1387 vec![]
1388 };
1389
1390 let uri: http::Uri = "/test".parse().unwrap();
1392 let mut builder = http::Request::builder().method(Method::GET).uri(uri);
1393 for (name, value) in &headers {
1394 builder = builder.header(*name, *value);
1395 }
1396 let req = builder.body(()).unwrap();
1397 let (mut parts, _) = req.into_parts();
1398
1399 let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
1401 parts.extensions.insert(socket_addr);
1402
1403 let request = Request::new(
1404 parts,
1405 crate::request::BodyVariant::Buffered(Bytes::new()),
1406 Arc::new(Extensions::new()),
1407 PathParams::new(),
1408 );
1409
1410 let extracted = ClientIp::extract_with_config(&request, trust_proxy)
1411 .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;
1412
1413 if trust_proxy && has_forwarded_header {
1414 let expected_ip: std::net::IpAddr = forwarded_ip.parse()
1416 .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
1417 prop_assert_eq!(
1418 extracted.0,
1419 expected_ip,
1420 "Should use X-Forwarded-For IP when trust_proxy is enabled"
1421 );
1422 } else {
1423 prop_assert_eq!(
1425 extracted.0,
1426 socket_ip,
1427 "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
1428 );
1429 }
1430
1431 Ok(())
1432 })();
1433 result?;
1434 }
1435 }
1436
1437 proptest! {
1444 #![proptest_config(ProptestConfig::with_cases(100))]
1445
1446 #[test]
1447 fn prop_extension_extractor_retrieval(
1448 value in any::<i64>(),
1449 has_extension in prop::bool::ANY,
1450 ) {
1451 let result: Result<(), TestCaseError> = (|| {
1452 #[derive(Clone, Debug, PartialEq)]
1454 struct TestExtension(i64);
1455
1456 let uri: http::Uri = "/test".parse().unwrap();
1457 let builder = http::Request::builder().method(Method::GET).uri(uri);
1458 let req = builder.body(()).unwrap();
1459 let (mut parts, _) = req.into_parts();
1460
1461 if has_extension {
1462 parts.extensions.insert(TestExtension(value));
1463 }
1464
1465 let request = Request::new(
1466 parts,
1467 crate::request::BodyVariant::Buffered(Bytes::new()),
1468 Arc::new(Extensions::new()),
1469 PathParams::new(),
1470 );
1471
1472 let result = Extension::<TestExtension>::from_request_parts(&request);
1473
1474 if has_extension {
1475 let extracted = result
1476 .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
1477 prop_assert_eq!(
1478 extracted.0,
1479 TestExtension(value),
1480 "Extension value mismatch"
1481 );
1482 } else {
1483 prop_assert!(
1484 result.is_err(),
1485 "Expected error when extension is missing"
1486 );
1487 }
1488
1489 Ok(())
1490 })();
1491 result?;
1492 }
1493 }
1494
1495 #[test]
1498 fn test_headers_extractor_basic() {
1499 let request = create_test_request_with_headers(
1500 Method::GET,
1501 "/test",
1502 vec![
1503 ("content-type", "application/json"),
1504 ("accept", "text/html"),
1505 ],
1506 );
1507
1508 let headers = Headers::from_request_parts(&request).unwrap();
1509
1510 assert!(headers.contains("content-type"));
1511 assert!(headers.contains("accept"));
1512 assert!(!headers.contains("x-custom"));
1513 assert_eq!(headers.len(), 2);
1514 }
1515
1516 #[test]
1517 fn test_header_value_extractor_present() {
1518 let request = create_test_request_with_headers(
1519 Method::GET,
1520 "/test",
1521 vec![("authorization", "Bearer token123")],
1522 );
1523
1524 let result = HeaderValue::extract(&request, "authorization");
1525 assert!(result.is_ok());
1526 assert_eq!(result.unwrap().value(), "Bearer token123");
1527 }
1528
1529 #[test]
1530 fn test_header_value_extractor_missing() {
1531 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1532
1533 let result = HeaderValue::extract(&request, "authorization");
1534 assert!(result.is_err());
1535 }
1536
1537 #[test]
1538 fn test_client_ip_from_forwarded_header() {
1539 let request = create_test_request_with_headers(
1540 Method::GET,
1541 "/test",
1542 vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
1543 );
1544
1545 let ip = ClientIp::extract_with_config(&request, true).unwrap();
1546 assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
1547 }
1548
1549 #[test]
1550 fn test_client_ip_ignores_forwarded_when_not_trusted() {
1551 let uri: http::Uri = "/test".parse().unwrap();
1552 let builder = http::Request::builder()
1553 .method(Method::GET)
1554 .uri(uri)
1555 .header("x-forwarded-for", "192.168.1.100");
1556 let req = builder.body(()).unwrap();
1557 let (mut parts, _) = req.into_parts();
1558
1559 let socket_addr = std::net::SocketAddr::new(
1560 std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
1561 8080,
1562 );
1563 parts.extensions.insert(socket_addr);
1564
1565 let request = Request::new(
1566 parts,
1567 crate::request::BodyVariant::Buffered(Bytes::new()),
1568 Arc::new(Extensions::new()),
1569 PathParams::new(),
1570 );
1571
1572 let ip = ClientIp::extract_with_config(&request, false).unwrap();
1573 assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
1574 }
1575
1576 #[test]
1577 fn test_extension_extractor_present() {
1578 #[derive(Clone, Debug, PartialEq)]
1579 struct MyData(String);
1580
1581 let request =
1582 create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));
1583
1584 let result = Extension::<MyData>::from_request_parts(&request);
1585 assert!(result.is_ok());
1586 assert_eq!(result.unwrap().0, MyData("hello".to_string()));
1587 }
1588
1589 #[test]
1590 fn test_extension_extractor_missing() {
1591 #[derive(Clone, Debug)]
1592 #[allow(dead_code)]
1593 struct MyData(String);
1594
1595 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1596
1597 let result = Extension::<MyData>::from_request_parts(&request);
1598 assert!(result.is_err());
1599 }
1600
1601 #[cfg(feature = "cookies")]
1603 mod cookies_tests {
1604 use super::*;
1605
1606 proptest! {
1614 #![proptest_config(ProptestConfig::with_cases(100))]
1615
1616 #[test]
1617 fn prop_cookies_extractor_parsing(
1618 cookies in prop::collection::vec(
1621 (
1622 "[a-zA-Z][a-zA-Z0-9_]{0,15}", "[a-zA-Z0-9]{1,30}" ),
1625 0..5
1626 )
1627 ) {
1628 let result: Result<(), TestCaseError> = (|| {
1629 let cookie_header = cookies
1631 .iter()
1632 .map(|(name, value)| format!("{}={}", name, value))
1633 .collect::<Vec<_>>()
1634 .join("; ");
1635
1636 let headers = if !cookies.is_empty() {
1637 vec![("cookie", cookie_header.as_str())]
1638 } else {
1639 vec![]
1640 };
1641
1642 let request = create_test_request_with_headers(Method::GET, "/test", headers);
1643
1644 let extracted = Cookies::from_request_parts(&request)
1646 .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;
1647
1648 let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
1650 for (name, value) in &cookies {
1651 expected_cookies.insert(name.as_str(), value.as_str());
1652 }
1653
1654 for (name, expected_value) in &expected_cookies {
1656 let cookie = extracted.get(name)
1657 .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;
1658
1659 prop_assert_eq!(
1660 cookie.value(),
1661 *expected_value,
1662 "Cookie '{}' value mismatch",
1663 name
1664 );
1665 }
1666
1667 let extracted_count = extracted.iter().count();
1669 prop_assert_eq!(
1670 extracted_count,
1671 expected_cookies.len(),
1672 "Expected {} unique cookies, got {}",
1673 expected_cookies.len(),
1674 extracted_count
1675 );
1676
1677 Ok(())
1678 })();
1679 result?;
1680 }
1681 }
1682
1683 #[test]
1684 fn test_cookies_extractor_basic() {
1685 let request = create_test_request_with_headers(
1686 Method::GET,
1687 "/test",
1688 vec![("cookie", "session=abc123; user=john")],
1689 );
1690
1691 let cookies = Cookies::from_request_parts(&request).unwrap();
1692
1693 assert!(cookies.contains("session"));
1694 assert!(cookies.contains("user"));
1695 assert!(!cookies.contains("other"));
1696
1697 assert_eq!(cookies.get("session").unwrap().value(), "abc123");
1698 assert_eq!(cookies.get("user").unwrap().value(), "john");
1699 }
1700
1701 #[test]
1702 fn test_cookies_extractor_empty() {
1703 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1704
1705 let cookies = Cookies::from_request_parts(&request).unwrap();
1706 assert_eq!(cookies.iter().count(), 0);
1707 }
1708
1709 #[test]
1710 fn test_cookies_extractor_single() {
1711 let request = create_test_request_with_headers(
1712 Method::GET,
1713 "/test",
1714 vec![("cookie", "token=xyz789")],
1715 );
1716
1717 let cookies = Cookies::from_request_parts(&request).unwrap();
1718 assert_eq!(cookies.iter().count(), 1);
1719 assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
1720 }
1721 }
1722
1723 #[tokio::test]
1724 async fn test_async_validated_json_with_state_context() {
1725 use async_trait::async_trait;
1726 use rustapi_validate::prelude::*;
1727 use rustapi_validate::v2::{
1728 AsyncValidationRule, DatabaseValidator, ValidationContextBuilder,
1729 };
1730 use serde::{Deserialize, Serialize};
1731
1732 struct MockDbValidator {
1733 unique_values: Vec<String>,
1734 }
1735
1736 #[async_trait]
1737 impl DatabaseValidator for MockDbValidator {
1738 async fn exists(
1739 &self,
1740 _table: &str,
1741 _column: &str,
1742 _value: &str,
1743 ) -> Result<bool, String> {
1744 Ok(true)
1745 }
1746 async fn is_unique(
1747 &self,
1748 _table: &str,
1749 _column: &str,
1750 value: &str,
1751 ) -> Result<bool, String> {
1752 Ok(!self.unique_values.contains(&value.to_string()))
1753 }
1754 async fn is_unique_except(
1755 &self,
1756 _table: &str,
1757 _column: &str,
1758 value: &str,
1759 _except_id: &str,
1760 ) -> Result<bool, String> {
1761 Ok(!self.unique_values.contains(&value.to_string()))
1762 }
1763 }
1764
1765 #[derive(Debug, Deserialize, Serialize)]
1766 struct TestUser {
1767 email: String,
1768 }
1769
1770 impl Validate for TestUser {
1771 fn validate_with_group(
1772 &self,
1773 _group: rustapi_validate::v2::ValidationGroup,
1774 ) -> Result<(), rustapi_validate::v2::ValidationErrors> {
1775 Ok(())
1776 }
1777 }
1778
1779 #[async_trait]
1780 impl AsyncValidate for TestUser {
1781 async fn validate_async_with_group(
1782 &self,
1783 ctx: &ValidationContext,
1784 _group: rustapi_validate::v2::ValidationGroup,
1785 ) -> Result<(), rustapi_validate::v2::ValidationErrors> {
1786 let mut errors = rustapi_validate::v2::ValidationErrors::new();
1787
1788 let rule = AsyncUniqueRule::new("users", "email");
1789 if let Err(e) = rule.validate_async(&self.email, ctx).await {
1790 errors.add("email", e);
1791 }
1792
1793 errors.into_result()
1794 }
1795 }
1796
1797 let uri: http::Uri = "/test".parse().unwrap();
1799 let user = TestUser {
1800 email: "new@example.com".to_string(),
1801 };
1802 let body_bytes = serde_json::to_vec(&user).unwrap();
1803
1804 let builder = http::Request::builder()
1805 .method(Method::POST)
1806 .uri(uri.clone())
1807 .header("content-type", "application/json");
1808 let req = builder.body(()).unwrap();
1809 let (parts, _) = req.into_parts();
1810
1811 let mut request = Request::new(
1813 parts,
1814 crate::request::BodyVariant::Buffered(Bytes::from(body_bytes.clone())),
1815 Arc::new(Extensions::new()),
1816 PathParams::new(),
1817 );
1818
1819 let result = AsyncValidatedJson::<TestUser>::from_request(&mut request).await;
1820
1821 assert!(result.is_err(), "Expected error when validator is missing");
1822 let err = result.unwrap_err();
1823 let err_str = format!("{:?}", err);
1824 assert!(
1825 err_str.contains("Database validator not configured")
1826 || err_str.contains("async_unique"),
1827 "Error should mention missing configuration or rule: {:?}",
1828 err_str
1829 );
1830
1831 let db_validator = MockDbValidator {
1833 unique_values: vec!["taken@example.com".to_string()],
1834 };
1835 let ctx = ValidationContextBuilder::new()
1836 .database(db_validator)
1837 .build();
1838
1839 let mut extensions = Extensions::new();
1840 extensions.insert(ctx);
1841
1842 let builder = http::Request::builder()
1843 .method(Method::POST)
1844 .uri(uri.clone())
1845 .header("content-type", "application/json");
1846 let req = builder.body(()).unwrap();
1847 let (parts, _) = req.into_parts();
1848
1849 let mut request = Request::new(
1850 parts,
1851 crate::request::BodyVariant::Buffered(Bytes::from(body_bytes.clone())),
1852 Arc::new(extensions),
1853 PathParams::new(),
1854 );
1855
1856 let result = AsyncValidatedJson::<TestUser>::from_request(&mut request).await;
1857 assert!(
1858 result.is_ok(),
1859 "Expected success when validator is present and value is unique. Error: {:?}",
1860 result.err()
1861 );
1862
1863 let user_taken = TestUser {
1865 email: "taken@example.com".to_string(),
1866 };
1867 let body_taken = serde_json::to_vec(&user_taken).unwrap();
1868
1869 let db_validator = MockDbValidator {
1870 unique_values: vec!["taken@example.com".to_string()],
1871 };
1872 let ctx = ValidationContextBuilder::new()
1873 .database(db_validator)
1874 .build();
1875
1876 let mut extensions = Extensions::new();
1877 extensions.insert(ctx);
1878
1879 let builder = http::Request::builder()
1880 .method(Method::POST)
1881 .uri("/test")
1882 .header("content-type", "application/json");
1883 let req = builder.body(()).unwrap();
1884 let (parts, _) = req.into_parts();
1885
1886 let mut request = Request::new(
1887 parts,
1888 crate::request::BodyVariant::Buffered(Bytes::from(body_taken)),
1889 Arc::new(extensions),
1890 PathParams::new(),
1891 );
1892
1893 let result = AsyncValidatedJson::<TestUser>::from_request(&mut request).await;
1894 assert!(result.is_err(), "Expected validation error for taken email");
1895 }
1896}