1use crate::error::{ApiError, Result};
58use crate::json;
59use crate::request::Request;
60use crate::response::IntoResponse;
61use crate::stream::{StreamingBody, StreamingConfig};
62use crate::validation::Validatable;
63use bytes::Bytes;
64use http::{header, StatusCode};
65use rustapi_validate::v2::{AsyncValidate, ValidationContext};
66
67use rustapi_openapi::schema::{RustApiSchema, SchemaCtx, SchemaRef};
68use serde::de::DeserializeOwned;
69use serde::Serialize;
70use std::collections::BTreeMap;
71use std::future::Future;
72use std::ops::{Deref, DerefMut};
73use std::str::FromStr;
74
75pub trait FromRequestParts: Sized {
100 fn from_request_parts(req: &Request) -> Result<Self>;
102}
103
104pub trait FromRequest: Sized {
134 fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
136}
137
138impl<T: FromRequestParts> FromRequest for T {
140 async fn from_request(req: &mut Request) -> Result<Self> {
141 T::from_request_parts(req)
142 }
143}
144
/// Transparent wrapper around a value that is read from, or written as, a
/// JSON body.
#[derive(Clone, Copy, Debug, Default)]
pub struct Json<T>(pub T);
165
166impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
167 async fn from_request(req: &mut Request) -> Result<Self> {
168 req.load_body().await?;
169 let body = req
170 .take_body()
171 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
172
173 let value: T = json::from_slice(&body)?;
175 Ok(Json(value))
176 }
177}
178
179impl<T> Deref for Json<T> {
180 type Target = T;
181
182 fn deref(&self) -> &Self::Target {
183 &self.0
184 }
185}
186
187impl<T> DerefMut for Json<T> {
188 fn deref_mut(&mut self) -> &mut Self::Target {
189 &mut self.0
190 }
191}
192
193impl<T> From<T> for Json<T> {
194 fn from(value: T) -> Self {
195 Json(value)
196 }
197}
198
/// Initial capacity, in bytes, requested for the buffer a JSON response is
/// serialized into (256).
const JSON_RESPONSE_INITIAL_CAPACITY: usize = 1 << 8;
202
203impl<T: Serialize> IntoResponse for Json<T> {
205 fn into_response(self) -> crate::response::Response {
206 match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
208 Ok(body) => http::Response::builder()
209 .status(StatusCode::OK)
210 .header(header::CONTENT_TYPE, "application/json")
211 .body(crate::response::Body::from(body))
212 .unwrap(),
213 Err(err) => {
214 ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
215 }
216 }
217 }
218}
219
/// Wrapper for a JSON payload that is validated synchronously when it is
/// extracted from a request.
#[derive(Clone, Copy, Debug, Default)]
pub struct ValidatedJson<T>(pub T);
247
248impl<T> ValidatedJson<T> {
249 pub fn new(value: T) -> Self {
251 Self(value)
252 }
253
254 pub fn into_inner(self) -> T {
256 self.0
257 }
258}
259
260impl<T: DeserializeOwned + Validatable + Send> FromRequest for ValidatedJson<T> {
261 async fn from_request(req: &mut Request) -> Result<Self> {
262 req.load_body().await?;
263 let body = req
265 .take_body()
266 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
267
268 let value: T = json::from_slice(&body)?;
269
270 value.do_validate()?;
272
273 Ok(ValidatedJson(value))
274 }
275}
276
277impl<T> Deref for ValidatedJson<T> {
278 type Target = T;
279
280 fn deref(&self) -> &Self::Target {
281 &self.0
282 }
283}
284
285impl<T> DerefMut for ValidatedJson<T> {
286 fn deref_mut(&mut self) -> &mut Self::Target {
287 &mut self.0
288 }
289}
290
291impl<T> From<T> for ValidatedJson<T> {
292 fn from(value: T) -> Self {
293 ValidatedJson(value)
294 }
295}
296
297impl<T: Serialize> IntoResponse for ValidatedJson<T> {
298 fn into_response(self) -> crate::response::Response {
299 Json(self.0).into_response()
300 }
301}
302
/// Wrapper for a JSON payload that is validated asynchronously when it is
/// extracted from a request.
#[derive(Clone, Copy, Debug, Default)]
pub struct AsyncValidatedJson<T>(pub T);
331
332impl<T> AsyncValidatedJson<T> {
333 pub fn new(value: T) -> Self {
335 Self(value)
336 }
337
338 pub fn into_inner(self) -> T {
340 self.0
341 }
342}
343
344impl<T> Deref for AsyncValidatedJson<T> {
345 type Target = T;
346
347 fn deref(&self) -> &Self::Target {
348 &self.0
349 }
350}
351
352impl<T> DerefMut for AsyncValidatedJson<T> {
353 fn deref_mut(&mut self) -> &mut Self::Target {
354 &mut self.0
355 }
356}
357
358impl<T> From<T> for AsyncValidatedJson<T> {
359 fn from(value: T) -> Self {
360 AsyncValidatedJson(value)
361 }
362}
363
364impl<T: Serialize> IntoResponse for AsyncValidatedJson<T> {
365 fn into_response(self) -> crate::response::Response {
366 Json(self.0).into_response()
367 }
368}
369
370impl<T: DeserializeOwned + AsyncValidate + Send + Sync> FromRequest for AsyncValidatedJson<T> {
371 async fn from_request(req: &mut Request) -> Result<Self> {
372 req.load_body().await?;
373
374 let body = req
375 .take_body()
376 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
377
378 let value: T = json::from_slice(&body)?;
379
380 let ctx = if let Some(ctx) = req.state().get::<ValidationContext>() {
383 ctx.clone()
384 } else {
385 ValidationContext::default()
386 };
387
388 if let Err(errors) = value.validate_full(&ctx).await {
390 let field_errors: Vec<crate::error::FieldError> = errors
392 .fields
393 .iter()
394 .flat_map(|(field, errs)| {
395 let field_name = field.to_string();
396 errs.iter().map(move |e| crate::error::FieldError {
397 field: field_name.clone(),
398 code: e.code.to_string(),
399 message: e.message.clone(),
400 })
401 })
402 .collect();
403
404 return Err(ApiError::validation(field_errors));
405 }
406
407 Ok(AsyncValidatedJson(value))
408 }
409}
410
/// Wrapper for a value deserialized from the request's query string.
#[derive(Clone, Debug)]
pub struct Query<T>(pub T);
430
431impl<T: DeserializeOwned> FromRequestParts for Query<T> {
432 fn from_request_parts(req: &Request) -> Result<Self> {
433 let query = req.query_string().unwrap_or("");
434 let value: T = serde_urlencoded::from_str(query)
435 .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
436 Ok(Query(value))
437 }
438}
439
440impl<T> Deref for Query<T> {
441 type Target = T;
442
443 fn deref(&self) -> &Self::Target {
444 &self.0
445 }
446}
447
/// Wrapper for a single value parsed from a path parameter.
#[derive(Clone, Debug)]
pub struct Path<T>(pub T);
471
472impl<T: FromStr> FromRequestParts for Path<T>
473where
474 T::Err: std::fmt::Display,
475{
476 fn from_request_parts(req: &Request) -> Result<Self> {
477 let params = req.path_params();
478
479 if let Some((_, value)) = params.iter().next() {
481 let parsed = value
482 .parse::<T>()
483 .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
484 return Ok(Path(parsed));
485 }
486
487 Err(ApiError::internal("Missing path parameter"))
488 }
489}
490
491impl<T> Deref for Path<T> {
492 type Target = T;
493
494 fn deref(&self) -> &Self::Target {
495 &self.0
496 }
497}
498
/// Wrapper for a value deserialized from all path parameters at once.
#[derive(Clone, Debug)]
pub struct Typed<T>(pub T);
520
521impl<T: DeserializeOwned + Send> FromRequestParts for Typed<T> {
522 fn from_request_parts(req: &Request) -> Result<Self> {
523 let params = req.path_params();
524 let mut map = serde_json::Map::new();
525 for (k, v) in params.iter() {
526 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
527 }
528 let value = serde_json::Value::Object(map);
529 let parsed: T = serde_json::from_value(value)
530 .map_err(|e| ApiError::bad_request(format!("Invalid path parameters: {}", e)))?;
531 Ok(Typed(parsed))
532 }
533}
534
535impl<T> Deref for Typed<T> {
536 type Target = T;
537
538 fn deref(&self) -> &Self::Target {
539 &self.0
540 }
541}
542
/// Wrapper for a value cloned out of the request's shared state.
#[derive(Clone, Debug)]
pub struct State<T>(pub T);
561
562impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
563 fn from_request_parts(req: &Request) -> Result<Self> {
564 req.state().get::<T>().cloned().map(State).ok_or_else(|| {
565 ApiError::internal(format!(
566 "State of type `{}` not found. Did you forget to call .state()?",
567 std::any::type_name::<T>()
568 ))
569 })
570 }
571}
572
573impl<T> Deref for State<T> {
574 type Target = T;
575
576 fn deref(&self) -> &Self::Target {
577 &self.0
578 }
579}
580
581#[derive(Debug, Clone)]
583pub struct Body(pub Bytes);
584
585impl FromRequest for Body {
586 async fn from_request(req: &mut Request) -> Result<Self> {
587 req.load_body().await?;
588 let body = req
589 .take_body()
590 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
591 Ok(Body(body))
592 }
593}
594
595impl Deref for Body {
596 type Target = Bytes;
597
598 fn deref(&self) -> &Self::Target {
599 &self.0
600 }
601}
602
603pub struct BodyStream(pub StreamingBody);
605
606impl FromRequest for BodyStream {
607 async fn from_request(req: &mut Request) -> Result<Self> {
608 let config = StreamingConfig::default();
609
610 if let Some(stream) = req.take_stream() {
611 Ok(BodyStream(StreamingBody::new(stream, config.max_body_size)))
612 } else if let Some(bytes) = req.take_body() {
613 let stream = futures_util::stream::once(async move { Ok(bytes) });
615 Ok(BodyStream(StreamingBody::from_stream(
616 stream,
617 config.max_body_size,
618 )))
619 } else {
620 Err(ApiError::internal("Body already consumed"))
621 }
622 }
623}
624
625impl Deref for BodyStream {
626 type Target = StreamingBody;
627
628 fn deref(&self) -> &Self::Target {
629 &self.0
630 }
631}
632
633impl DerefMut for BodyStream {
634 fn deref_mut(&mut self) -> &mut Self::Target {
635 &mut self.0
636 }
637}
638
639impl futures_util::Stream for BodyStream {
641 type Item = Result<Bytes, ApiError>;
642
643 fn poll_next(
644 mut self: std::pin::Pin<&mut Self>,
645 cx: &mut std::task::Context<'_>,
646 ) -> std::task::Poll<Option<Self::Item>> {
647 std::pin::Pin::new(&mut self.0).poll_next(cx)
648 }
649}
650
651impl<T: FromRequestParts> FromRequestParts for Option<T> {
655 fn from_request_parts(req: &Request) -> Result<Self> {
656 Ok(T::from_request_parts(req).ok())
657 }
658}
659
660#[derive(Debug, Clone)]
678pub struct Headers(pub http::HeaderMap);
679
680impl Headers {
681 pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
683 self.0.get(name)
684 }
685
686 pub fn contains(&self, name: &str) -> bool {
688 self.0.contains_key(name)
689 }
690
691 pub fn len(&self) -> usize {
693 self.0.len()
694 }
695
696 pub fn is_empty(&self) -> bool {
698 self.0.is_empty()
699 }
700
701 pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
703 self.0.iter()
704 }
705}
706
707impl FromRequestParts for Headers {
708 fn from_request_parts(req: &Request) -> Result<Self> {
709 Ok(Headers(req.headers().clone()))
710 }
711}
712
713impl Deref for Headers {
714 type Target = http::HeaderMap;
715
716 fn deref(&self) -> &Self::Target {
717 &self.0
718 }
719}
720
/// A single header captured as text.
///
/// Field `0` is the header value, field `1` is the header name.
#[derive(Clone, Debug)]
pub struct HeaderValue(pub String, pub &'static str);
741
742impl HeaderValue {
743 pub fn new(name: &'static str, value: String) -> Self {
745 Self(value, name)
746 }
747
748 pub fn value(&self) -> &str {
750 &self.0
751 }
752
753 pub fn name(&self) -> &'static str {
755 self.1
756 }
757
758 pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
760 req.headers()
761 .get(name)
762 .and_then(|v| v.to_str().ok())
763 .map(|s| HeaderValue(s.to_string(), name))
764 .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
765 }
766}
767
768impl Deref for HeaderValue {
769 type Target = String;
770
771 fn deref(&self) -> &Self::Target {
772 &self.0
773 }
774}
775
/// Wrapper for a value cloned out of the request's extensions.
#[derive(Clone, Debug)]
pub struct Extension<T>(pub T);
795
796impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
797 fn from_request_parts(req: &Request) -> Result<Self> {
798 req.extensions()
799 .get::<T>()
800 .cloned()
801 .map(Extension)
802 .ok_or_else(|| {
803 ApiError::internal(format!(
804 "Extension of type `{}` not found. Did middleware insert it?",
805 std::any::type_name::<T>()
806 ))
807 })
808 }
809}
810
811impl<T> Deref for Extension<T> {
812 type Target = T;
813
814 fn deref(&self) -> &Self::Target {
815 &self.0
816 }
817}
818
819impl<T> DerefMut for Extension<T> {
820 fn deref_mut(&mut self) -> &mut Self::Target {
821 &mut self.0
822 }
823}
824
/// The IP address attributed to the client of a request.
#[derive(Clone, Debug)]
pub struct ClientIp(pub std::net::IpAddr);
841
842impl ClientIp {
843 pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
845 if trust_proxy {
846 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
848 if let Ok(forwarded_str) = forwarded.to_str() {
849 if let Some(first_ip) = forwarded_str.split(',').next() {
851 if let Ok(ip) = first_ip.trim().parse() {
852 return Ok(ClientIp(ip));
853 }
854 }
855 }
856 }
857 }
858
859 if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
861 return Ok(ClientIp(addr.ip()));
862 }
863
864 Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
866 127, 0, 0, 1,
867 ))))
868 }
869}
870
871impl FromRequestParts for ClientIp {
872 fn from_request_parts(req: &Request) -> Result<Self> {
873 Self::extract_with_config(req, true)
875 }
876}
877
/// The cookies sent with a request, collected into a jar.
#[cfg(feature = "cookies")]
#[derive(Clone, Debug)]
pub struct Cookies(pub cookie::CookieJar);
898
#[cfg(feature = "cookies")]
impl Cookies {
    /// Returns the cookie called `name`, if the jar has one.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        let Self(jar) = self;
        jar.get(name)
    }

    /// Iterates over the cookies in the jar.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        let Self(jar) = self;
        jar.iter()
    }

    /// Returns `true` when the jar has a cookie called `name`.
    pub fn contains(&self, name: &str) -> bool {
        self.get(name).is_some()
    }
}
916
/// Builds a cookie jar from the request's `Cookie` header(s).
#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    /// Parsing is best-effort: header values that are not valid strings and
    /// `name=value` pairs that fail to parse are skipped, so this function
    /// always returns `Ok`.
    ///
    /// Every `Cookie` header field is read, not just the first one, because
    /// HTTP/2 allows a client to split its cookies across several fields.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();

        for cookie_header in req.headers().get_all(header::COOKIE).iter() {
            let Ok(cookie_str) = cookie_header.to_str() else {
                continue;
            };

            for cookie_part in cookie_str.split(';') {
                let trimmed = cookie_part.trim();
                if trimmed.is_empty() {
                    continue;
                }

                if let Ok(cookie) = cookie::Cookie::parse(trimmed.to_string()) {
                    jar.add_original(cookie.into_owned());
                }
            }
        }

        Ok(Cookies(jar))
    }
}
939
/// Lets `Cookies` be used wherever `&cookie::CookieJar` is expected.
#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    fn deref(&self) -> &cookie::CookieJar {
        let Cookies(jar) = self;
        jar
    }
}
948
949macro_rules! impl_from_request_parts_for_primitives {
951 ($($ty:ty),*) => {
952 $(
953 impl FromRequestParts for $ty {
954 fn from_request_parts(req: &Request) -> Result<Self> {
955 let Path(value) = Path::<$ty>::from_request_parts(req)?;
956 Ok(value)
957 }
958 }
959 )*
960 };
961}
962
963impl_from_request_parts_for_primitives!(
964 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
965);
966
967use rustapi_openapi::{
970 MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier, ResponseSpec,
971};
972
973impl<T: RustApiSchema> OperationModifier for ValidatedJson<T> {
975 fn update_operation(op: &mut Operation) {
976 let mut ctx = SchemaCtx::new();
977 let schema_ref = T::schema(&mut ctx);
978
979 let mut content = BTreeMap::new();
980 content.insert(
981 "application/json".to_string(),
982 MediaType {
983 schema: Some(schema_ref),
984 example: None,
985 },
986 );
987
988 op.request_body = Some(RequestBody {
989 description: None,
990 required: Some(true),
991 content,
992 });
993
994 let mut responses_content = BTreeMap::new();
996 responses_content.insert(
997 "application/json".to_string(),
998 MediaType {
999 schema: Some(SchemaRef::Ref {
1000 reference: "#/components/schemas/ValidationErrorSchema".to_string(),
1001 }),
1002 example: None,
1003 },
1004 );
1005
1006 op.responses.insert(
1007 "422".to_string(),
1008 ResponseSpec {
1009 description: "Validation Error".to_string(),
1010 content: responses_content,
1011 headers: BTreeMap::new(),
1012 },
1013 );
1014 }
1015}
1016
1017impl<T: RustApiSchema> OperationModifier for AsyncValidatedJson<T> {
1019 fn update_operation(op: &mut Operation) {
1020 let mut ctx = SchemaCtx::new();
1021 let schema_ref = T::schema(&mut ctx);
1022
1023 let mut content = BTreeMap::new();
1024 content.insert(
1025 "application/json".to_string(),
1026 MediaType {
1027 schema: Some(schema_ref),
1028 example: None,
1029 },
1030 );
1031
1032 op.request_body = Some(RequestBody {
1033 description: None,
1034 required: Some(true),
1035 content,
1036 });
1037
1038 let mut responses_content = BTreeMap::new();
1040 responses_content.insert(
1041 "application/json".to_string(),
1042 MediaType {
1043 schema: Some(SchemaRef::Ref {
1044 reference: "#/components/schemas/ValidationErrorSchema".to_string(),
1045 }),
1046 example: None,
1047 },
1048 );
1049
1050 op.responses.insert(
1051 "422".to_string(),
1052 ResponseSpec {
1053 description: "Validation Error".to_string(),
1054 content: responses_content,
1055 headers: BTreeMap::new(),
1056 },
1057 );
1058 }
1059}
1060
1061impl<T: RustApiSchema> OperationModifier for Json<T> {
1063 fn update_operation(op: &mut Operation) {
1064 let mut ctx = SchemaCtx::new();
1065 let schema_ref = T::schema(&mut ctx);
1066
1067 let mut content = BTreeMap::new();
1068 content.insert(
1069 "application/json".to_string(),
1070 MediaType {
1071 schema: Some(schema_ref),
1072 example: None,
1073 },
1074 );
1075
1076 op.request_body = Some(RequestBody {
1077 description: None,
1078 required: Some(true),
1079 content,
1080 });
1081 }
1082}
1083
1084impl<T> OperationModifier for Path<T> {
1086 fn update_operation(_op: &mut Operation) {}
1087}
1088
1089impl<T> OperationModifier for Typed<T> {
1091 fn update_operation(_op: &mut Operation) {}
1092}
1093
1094impl<T: RustApiSchema> OperationModifier for Query<T> {
1096 fn update_operation(op: &mut Operation) {
1097 let mut ctx = SchemaCtx::new();
1098 if let Some(fields) = T::field_schemas(&mut ctx) {
1099 let new_params: Vec<Parameter> = fields
1100 .into_iter()
1101 .map(|(name, schema)| {
1102 Parameter {
1103 name,
1104 location: "query".to_string(),
1105 required: false, deprecated: None,
1107 description: None,
1108 schema: Some(schema),
1109 }
1110 })
1111 .collect();
1112
1113 op.parameters.extend(new_params);
1114 }
1115 }
1116}
1117
1118impl<T> OperationModifier for State<T> {
1120 fn update_operation(_op: &mut Operation) {}
1121}
1122
1123impl OperationModifier for Body {
1125 fn update_operation(op: &mut Operation) {
1126 let mut content = BTreeMap::new();
1127 content.insert(
1128 "application/octet-stream".to_string(),
1129 MediaType {
1130 schema: Some(SchemaRef::Inline(
1131 serde_json::json!({ "type": "string", "format": "binary" }),
1132 )),
1133 example: None,
1134 },
1135 );
1136
1137 op.request_body = Some(RequestBody {
1138 description: None,
1139 required: Some(true),
1140 content,
1141 });
1142 }
1143}
1144
1145impl OperationModifier for BodyStream {
1147 fn update_operation(op: &mut Operation) {
1148 let mut content = BTreeMap::new();
1149 content.insert(
1150 "application/octet-stream".to_string(),
1151 MediaType {
1152 schema: Some(SchemaRef::Inline(
1153 serde_json::json!({ "type": "string", "format": "binary" }),
1154 )),
1155 example: None,
1156 },
1157 );
1158
1159 op.request_body = Some(RequestBody {
1160 description: None,
1161 required: Some(true),
1162 content,
1163 });
1164 }
1165}
1166
1167impl<T: RustApiSchema> ResponseModifier for Json<T> {
1171 fn update_response(op: &mut Operation) {
1172 let mut ctx = SchemaCtx::new();
1173 let schema_ref = T::schema(&mut ctx);
1174
1175 let mut content = BTreeMap::new();
1176 content.insert(
1177 "application/json".to_string(),
1178 MediaType {
1179 schema: Some(schema_ref),
1180 example: None,
1181 },
1182 );
1183
1184 op.responses.insert(
1185 "200".to_string(),
1186 ResponseSpec {
1187 description: "Successful response".to_string(),
1188 content,
1189 headers: BTreeMap::new(),
1190 },
1191 );
1192 }
1193}
1194
1195impl<T: RustApiSchema> RustApiSchema for Json<T> {
1198 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1199 T::schema(ctx)
1200 }
1201}
1202
1203impl<T: RustApiSchema> RustApiSchema for ValidatedJson<T> {
1204 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1205 T::schema(ctx)
1206 }
1207}
1208
1209impl<T: RustApiSchema> RustApiSchema for AsyncValidatedJson<T> {
1210 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1211 T::schema(ctx)
1212 }
1213}
1214
1215impl<T: RustApiSchema> RustApiSchema for Query<T> {
1216 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1217 T::schema(ctx)
1218 }
1219 fn field_schemas(ctx: &mut SchemaCtx) -> Option<BTreeMap<String, SchemaRef>> {
1220 T::field_schemas(ctx)
1221 }
1222}
1223
/// Upper bound applied to `per_page` by the pagination constructors.
const MAX_PER_PAGE: u64 = 100;
/// Page size used when the request does not specify one.
const DEFAULT_PER_PAGE: u64 = 20;
/// Page number used when the request does not specify one.
const DEFAULT_PAGE: u64 = 1;
1232
/// Page-number based pagination parameters.
#[derive(Clone, Copy, Debug)]
pub struct Paginate {
    /// Requested page number.
    pub page: u64,
    /// Number of items per page.
    pub per_page: u64,
}
1256
1257impl Paginate {
1258 pub fn new(page: u64, per_page: u64) -> Self {
1260 Self {
1261 page: page.max(1),
1262 per_page: per_page.clamp(1, MAX_PER_PAGE),
1263 }
1264 }
1265
1266 pub fn offset(&self) -> u64 {
1268 (self.page - 1) * self.per_page
1269 }
1270
1271 pub fn limit(&self) -> u64 {
1273 self.per_page
1274 }
1275
1276 pub fn paginate<T>(self, items: Vec<T>, total: u64) -> crate::hateoas::Paginated<T> {
1278 crate::hateoas::Paginated {
1279 items,
1280 page: self.page,
1281 per_page: self.per_page,
1282 total,
1283 }
1284 }
1285}
1286
1287impl Default for Paginate {
1288 fn default() -> Self {
1289 Self {
1290 page: DEFAULT_PAGE,
1291 per_page: DEFAULT_PER_PAGE,
1292 }
1293 }
1294}
1295
1296impl FromRequestParts for Paginate {
1297 fn from_request_parts(req: &Request) -> Result<Self> {
1298 let query = req.query_string().unwrap_or("");
1299
1300 #[derive(serde::Deserialize)]
1301 struct PaginateQuery {
1302 page: Option<u64>,
1303 per_page: Option<u64>,
1304 }
1305
1306 let params: PaginateQuery = serde_urlencoded::from_str(query).unwrap_or(PaginateQuery {
1307 page: None,
1308 per_page: None,
1309 });
1310
1311 Ok(Paginate::new(
1312 params.page.unwrap_or(DEFAULT_PAGE),
1313 params.per_page.unwrap_or(DEFAULT_PER_PAGE),
1314 ))
1315 }
1316}
1317
/// Cursor based pagination parameters.
#[derive(Clone, Debug)]
pub struct CursorPaginate {
    /// Cursor to continue from; `None` for the first page.
    pub cursor: Option<String>,
    /// Number of items per page.
    pub per_page: u64,
}
1347
1348impl CursorPaginate {
1349 pub fn new(cursor: Option<String>, per_page: u64) -> Self {
1351 Self {
1352 cursor,
1353 per_page: per_page.clamp(1, MAX_PER_PAGE),
1354 }
1355 }
1356
1357 pub fn after(&self) -> Option<&str> {
1359 self.cursor.as_deref()
1360 }
1361
1362 pub fn limit(&self) -> u64 {
1364 self.per_page
1365 }
1366
1367 pub fn is_first_page(&self) -> bool {
1369 self.cursor.is_none()
1370 }
1371}
1372
1373impl Default for CursorPaginate {
1374 fn default() -> Self {
1375 Self {
1376 cursor: None,
1377 per_page: DEFAULT_PER_PAGE,
1378 }
1379 }
1380}
1381
1382impl FromRequestParts for CursorPaginate {
1383 fn from_request_parts(req: &Request) -> Result<Self> {
1384 let query = req.query_string().unwrap_or("");
1385
1386 #[derive(serde::Deserialize)]
1387 struct CursorQuery {
1388 cursor: Option<String>,
1389 limit: Option<u64>,
1390 }
1391
1392 let params: CursorQuery = serde_urlencoded::from_str(query).unwrap_or(CursorQuery {
1393 cursor: None,
1394 limit: None,
1395 });
1396
1397 Ok(CursorPaginate::new(
1398 params.cursor,
1399 params.limit.unwrap_or(DEFAULT_PER_PAGE),
1400 ))
1401 }
1402}
1403
1404#[cfg(test)]
1405mod tests {
1406 use super::*;
1407 use crate::path_params::PathParams;
1408 use bytes::Bytes;
1409 use http::{Extensions, Method};
1410 use proptest::prelude::*;
1411 use proptest::test_runner::TestCaseError;
1412 use std::sync::Arc;
1413
1414 fn create_test_request_with_headers(
1416 method: Method,
1417 path: &str,
1418 headers: Vec<(&str, &str)>,
1419 ) -> Request {
1420 let uri: http::Uri = path.parse().unwrap();
1421 let mut builder = http::Request::builder().method(method).uri(uri);
1422
1423 for (name, value) in headers {
1424 builder = builder.header(name, value);
1425 }
1426
1427 let req = builder.body(()).unwrap();
1428 let (parts, _) = req.into_parts();
1429
1430 Request::new(
1431 parts,
1432 crate::request::BodyVariant::Buffered(Bytes::new()),
1433 Arc::new(Extensions::new()),
1434 PathParams::new(),
1435 )
1436 }
1437
1438 fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
1440 method: Method,
1441 path: &str,
1442 extension: T,
1443 ) -> Request {
1444 let uri: http::Uri = path.parse().unwrap();
1445 let builder = http::Request::builder().method(method).uri(uri);
1446
1447 let req = builder.body(()).unwrap();
1448 let (mut parts, _) = req.into_parts();
1449 parts.extensions.insert(extension);
1450
1451 Request::new(
1452 parts,
1453 crate::request::BodyVariant::Buffered(Bytes::new()),
1454 Arc::new(Extensions::new()),
1455 PathParams::new(),
1456 )
1457 }
1458
    // Property: every header put on the request is visible through the
    // `Headers` extractor with its original value.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_headers_extractor_completeness(
            // Up to nine (name, value) pairs: lowercase names, alphanumeric/space values.
            headers in prop::collection::vec(
                (
                    "[a-z][a-z0-9-]{0,20}",
                    "[a-zA-Z0-9 ]{1,50}"
                ),
                0..10
            )
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let header_tuples: Vec<(&str, &str)> = headers
                    .iter()
                    .map(|(k, v)| (k.as_str(), v.as_str()))
                    .collect();

                let request = create_test_request_with_headers(
                    Method::GET,
                    "/test",
                    header_tuples.clone(),
                );

                let extracted = Headers::from_request_parts(&request)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;

                for (name, value) in &headers {
                    // Generated names may repeat, so look at every value stored under the name.
                    let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
                    prop_assert!(
                        !all_values.is_empty(),
                        "Header '{}' not found",
                        name
                    );

                    let value_found = all_values.iter().any(|v| {
                        v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
                    });

                    prop_assert!(
                        value_found,
                        "Header '{}' value '{}' not found in extracted values",
                        name,
                        value
                    );
                }

                Ok(())
            })();
            result?;
        }
    }
1527
    // Property: `HeaderValue::extract` returns the header's value when the
    // header is present and an error when it is absent.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_header_value_extractor_correctness(
            header_name in "[a-z][a-z0-9-]{0,20}",
            header_value in "[a-zA-Z0-9 ]{1,50}",
            has_header in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // NOTE(review): `headers` and `_request` below are built from the
                // generated name but never inspected; the assertions only use the
                // fixed `test_header` name. This looks like leftover setup.
                let headers = if has_header {
                    vec![(header_name.as_str(), header_value.as_str())]
                } else {
                    vec![]
                };

                let _request = create_test_request_with_headers(Method::GET, "/test", headers);

                // `extract` takes a `&'static str` name, so a fixed name is used here.
                let test_header = "x-test-header";
                let request_with_known_header = if has_header {
                    create_test_request_with_headers(
                        Method::GET,
                        "/test",
                        vec![(test_header, header_value.as_str())],
                    )
                } else {
                    create_test_request_with_headers(Method::GET, "/test", vec![])
                };

                let result = HeaderValue::extract(&request_with_known_header, test_header);

                if has_header {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.value(),
                        header_value.as_str(),
                        "Header value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when header is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }
1587
    // Property: `ClientIp::extract_with_config` uses the `x-forwarded-for`
    // address only when `trust_proxy` is set and the header is present;
    // otherwise it uses the socket address stored in the extensions.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_client_ip_extractor_with_forwarding(
            // Arbitrary dotted-quad IPv4 address for the forwarding header.
            forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
            // Arbitrary IPv4 address for the connection's socket.
            socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
            has_forwarded_header in prop::bool::ANY,
            trust_proxy in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let headers = if has_forwarded_header {
                    vec![("x-forwarded-for", forwarded_ip.as_str())]
                } else {
                    vec![]
                };

                // Built by hand (not via the helper) because the socket address
                // has to be inserted into the request parts' extensions.
                let uri: http::Uri = "/test".parse().unwrap();
                let mut builder = http::Request::builder().method(Method::GET).uri(uri);
                for (name, value) in &headers {
                    builder = builder.header(*name, *value);
                }
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
                parts.extensions.insert(socket_addr);

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let extracted = ClientIp::extract_with_config(&request, trust_proxy)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;

                if trust_proxy && has_forwarded_header {
                    let expected_ip: std::net::IpAddr = forwarded_ip.parse()
                        .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        expected_ip,
                        "Should use X-Forwarded-For IP when trust_proxy is enabled"
                    );
                } else {
                    prop_assert_eq!(
                        extracted.0,
                        socket_ip,
                        "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }
1660
    // Property: the `Extension` extractor returns the inserted value when one
    // is present in the request extensions and an error otherwise.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_extension_extractor_retrieval(
            value in any::<i64>(),
            has_extension in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // Local type so the lookup by type cannot match anything else.
                #[derive(Clone, Debug, PartialEq)]
                struct TestExtension(i64);

                let uri: http::Uri = "/test".parse().unwrap();
                let builder = http::Request::builder().method(Method::GET).uri(uri);
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                if has_extension {
                    parts.extensions.insert(TestExtension(value));
                }

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let result = Extension::<TestExtension>::from_request_parts(&request);

                if has_extension {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        TestExtension(value),
                        "Extension value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when extension is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }
1718
1719 #[test]
1722 fn test_headers_extractor_basic() {
1723 let request = create_test_request_with_headers(
1724 Method::GET,
1725 "/test",
1726 vec![
1727 ("content-type", "application/json"),
1728 ("accept", "text/html"),
1729 ],
1730 );
1731
1732 let headers = Headers::from_request_parts(&request).unwrap();
1733
1734 assert!(headers.contains("content-type"));
1735 assert!(headers.contains("accept"));
1736 assert!(!headers.contains("x-custom"));
1737 assert_eq!(headers.len(), 2);
1738 }
1739
1740 #[test]
1741 fn test_header_value_extractor_present() {
1742 let request = create_test_request_with_headers(
1743 Method::GET,
1744 "/test",
1745 vec![("authorization", "Bearer token123")],
1746 );
1747
1748 let result = HeaderValue::extract(&request, "authorization");
1749 assert!(result.is_ok());
1750 assert_eq!(result.unwrap().value(), "Bearer token123");
1751 }
1752
1753 #[test]
1754 fn test_header_value_extractor_missing() {
1755 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1756
1757 let result = HeaderValue::extract(&request, "authorization");
1758 assert!(result.is_err());
1759 }
1760
1761 #[test]
1762 fn test_client_ip_from_forwarded_header() {
1763 let request = create_test_request_with_headers(
1764 Method::GET,
1765 "/test",
1766 vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
1767 );
1768
1769 let ip = ClientIp::extract_with_config(&request, true).unwrap();
1770 assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
1771 }
1772
/// With the config flag set to `false`, the socket address stored in the
/// request extensions is reported instead of the `x-forwarded-for` value.
#[test]
fn test_client_ip_ignores_forwarded_when_not_trusted() {
    let (mut parts, _) = http::Request::builder()
        .method(Method::GET)
        .uri("/test".parse::<http::Uri>().unwrap())
        .header("x-forwarded-for", "192.168.1.100")
        .body(())
        .unwrap()
        .into_parts();

    // Attach a peer socket address (10.0.0.1:8080) to the request-level extensions.
    let peer = std::net::SocketAddr::from(([10, 0, 0, 1], 8080));
    parts.extensions.insert(peer);

    let req = Request::new(
        parts,
        crate::request::BodyVariant::Buffered(Bytes::new()),
        Arc::new(Extensions::new()),
        PathParams::new(),
    );

    let expected: std::net::IpAddr = "10.0.0.1".parse().unwrap();
    let resolved = ClientIp::extract_with_config(&req, false).unwrap();

    assert_eq!(resolved.0, expected);
}
1799
/// An extension value attached to the request is returned by `Extension<T>`.
#[test]
fn test_extension_extractor_present() {
    #[derive(Clone, Debug, PartialEq)]
    struct MyData(String);

    let payload = MyData(String::from("hello"));
    let req = create_test_request_with_extensions(Method::GET, "/test", payload.clone());

    let outcome = Extension::<MyData>::from_request_parts(&req);

    assert!(outcome.is_ok());
    let extracted = outcome.unwrap();
    assert_eq!(extracted.0, payload);
}
1812
/// Requesting an extension type that was never attached yields an error.
#[test]
fn test_extension_extractor_missing() {
    // The inner `String` is never read in this test, hence the lint override.
    #[derive(Clone, Debug)]
    #[allow(dead_code)]
    struct MyData(String);

    let req = create_test_request_with_headers(Method::GET, "/test", Vec::new());

    assert!(Extension::<MyData>::from_request_parts(&req).is_err());
}
1824
/// Tests for the `Cookies` extractor; compiled only with the `cookies` feature.
#[cfg(feature = "cookies")]
mod cookies_tests {
    use super::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Property: every cookie sent in the `cookie` header is extracted with
        // its value, and the number of extracted cookies equals the number of
        // distinct names that were sent.
        #[test]
        fn prop_cookies_extractor_parsing(
            cookies in prop::collection::vec(
                // (name, value) pairs built from plain ASCII alphanumerics.
                ("[a-zA-Z][a-zA-Z0-9_]{0,15}", "[a-zA-Z0-9]{1,30}"),
                0..5
            )
        ) {
            let outcome: Result<(), TestCaseError> = (|| {
                let rendered: Vec<String> = cookies
                    .iter()
                    .map(|(name, value)| format!("{}={}", name, value))
                    .collect();
                let header_value = rendered.join("; ");

                // Only send a `cookie` header when there is something to send.
                let headers = if cookies.is_empty() {
                    vec![]
                } else {
                    vec![("cookie", header_value.as_str())]
                };

                let request = create_test_request_with_headers(Method::GET, "/test", headers);

                let jar = Cookies::from_request_parts(&request).map_err(|e| {
                    TestCaseError::fail(format!("Failed to extract cookies: {}", e))
                })?;

                // Collapse duplicate names; a later pair overwrites an earlier one.
                let expected: std::collections::HashMap<&str, &str> = cookies
                    .iter()
                    .map(|(name, value)| (name.as_str(), value.as_str()))
                    .collect();

                for (name, wanted) in &expected {
                    let cookie = jar.get(name).ok_or_else(|| {
                        TestCaseError::fail(format!("Cookie '{}' not found", name))
                    })?;

                    prop_assert_eq!(
                        cookie.value(),
                        *wanted,
                        "Cookie '{}' value mismatch",
                        name
                    );
                }

                let found = jar.iter().count();
                prop_assert_eq!(
                    found,
                    expected.len(),
                    "Expected {} unique cookies, got {}",
                    expected.len(),
                    found
                );

                Ok(())
            })();
            outcome?;
        }
    }

    /// Two cookies in one header are both extracted with their values.
    #[test]
    fn test_cookies_extractor_basic() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("cookie", "session=abc123; user=john")],
        );
        let jar = Cookies::from_request_parts(&request).unwrap();

        let expected = [("session", "abc123"), ("user", "john")];

        for (name, _) in expected {
            assert!(jar.contains(name));
        }
        assert!(!jar.contains("other"));

        for (name, value) in expected {
            assert_eq!(jar.get(name).unwrap().value(), value);
        }
    }

    /// A request without a `cookie` header produces an empty cookie set.
    #[test]
    fn test_cookies_extractor_empty() {
        let request = create_test_request_with_headers(Method::GET, "/test", Vec::new());

        let jar = Cookies::from_request_parts(&request).unwrap();

        assert_eq!(jar.iter().count(), 0);
    }

    /// A header with a single cookie yields exactly that one cookie.
    #[test]
    fn test_cookies_extractor_single() {
        let single = vec![("cookie", "token=xyz789")];
        let request = create_test_request_with_headers(Method::GET, "/test", single);

        let jar = Cookies::from_request_parts(&request).unwrap();

        assert_eq!(jar.iter().count(), 1);
        assert_eq!(jar.get("token").unwrap().value(), "xyz789");
    }
}
1946
1947 #[tokio::test]
1948 async fn test_async_validated_json_with_state_context() {
1949 use async_trait::async_trait;
1950 use rustapi_validate::prelude::*;
1951 use rustapi_validate::v2::{
1952 AsyncValidationRule, DatabaseValidator, ValidationContextBuilder,
1953 };
1954 use serde::{Deserialize, Serialize};
1955
1956 struct MockDbValidator {
1957 unique_values: Vec<String>,
1958 }
1959
1960 #[async_trait]
1961 impl DatabaseValidator for MockDbValidator {
1962 async fn exists(
1963 &self,
1964 _table: &str,
1965 _column: &str,
1966 _value: &str,
1967 ) -> Result<bool, String> {
1968 Ok(true)
1969 }
1970 async fn is_unique(
1971 &self,
1972 _table: &str,
1973 _column: &str,
1974 value: &str,
1975 ) -> Result<bool, String> {
1976 Ok(!self.unique_values.contains(&value.to_string()))
1977 }
1978 async fn is_unique_except(
1979 &self,
1980 _table: &str,
1981 _column: &str,
1982 value: &str,
1983 _except_id: &str,
1984 ) -> Result<bool, String> {
1985 Ok(!self.unique_values.contains(&value.to_string()))
1986 }
1987 }
1988
1989 #[derive(Debug, Deserialize, Serialize)]
1990 struct TestUser {
1991 email: String,
1992 }
1993
1994 impl Validate for TestUser {
1995 fn validate_with_group(
1996 &self,
1997 _group: rustapi_validate::v2::ValidationGroup,
1998 ) -> Result<(), rustapi_validate::v2::ValidationErrors> {
1999 Ok(())
2000 }
2001 }
2002
2003 #[async_trait]
2004 impl AsyncValidate for TestUser {
2005 async fn validate_async_with_group(
2006 &self,
2007 ctx: &ValidationContext,
2008 _group: rustapi_validate::v2::ValidationGroup,
2009 ) -> Result<(), rustapi_validate::v2::ValidationErrors> {
2010 let mut errors = rustapi_validate::v2::ValidationErrors::new();
2011
2012 let rule = AsyncUniqueRule::new("users", "email");
2013 if let Err(e) = rule.validate_async(&self.email, ctx).await {
2014 errors.add("email", e);
2015 }
2016
2017 errors.into_result()
2018 }
2019 }
2020
2021 let uri: http::Uri = "/test".parse().unwrap();
2023 let user = TestUser {
2024 email: "new@example.com".to_string(),
2025 };
2026 let body_bytes = serde_json::to_vec(&user).unwrap();
2027
2028 let builder = http::Request::builder()
2029 .method(Method::POST)
2030 .uri(uri.clone())
2031 .header("content-type", "application/json");
2032 let req = builder.body(()).unwrap();
2033 let (parts, _) = req.into_parts();
2034
2035 let mut request = Request::new(
2037 parts,
2038 crate::request::BodyVariant::Buffered(Bytes::from(body_bytes.clone())),
2039 Arc::new(Extensions::new()),
2040 PathParams::new(),
2041 );
2042
2043 let result = AsyncValidatedJson::<TestUser>::from_request(&mut request).await;
2044
2045 assert!(result.is_err(), "Expected error when validator is missing");
2046 let err = result.unwrap_err();
2047 let err_str = format!("{:?}", err);
2048 assert!(
2049 err_str.contains("Database validator not configured")
2050 || err_str.contains("async_unique"),
2051 "Error should mention missing configuration or rule: {:?}",
2052 err_str
2053 );
2054
2055 let db_validator = MockDbValidator {
2057 unique_values: vec!["taken@example.com".to_string()],
2058 };
2059 let ctx = ValidationContextBuilder::new()
2060 .database(db_validator)
2061 .build();
2062
2063 let mut extensions = Extensions::new();
2064 extensions.insert(ctx);
2065
2066 let builder = http::Request::builder()
2067 .method(Method::POST)
2068 .uri(uri.clone())
2069 .header("content-type", "application/json");
2070 let req = builder.body(()).unwrap();
2071 let (parts, _) = req.into_parts();
2072
2073 let mut request = Request::new(
2074 parts,
2075 crate::request::BodyVariant::Buffered(Bytes::from(body_bytes.clone())),
2076 Arc::new(extensions),
2077 PathParams::new(),
2078 );
2079
2080 let result = AsyncValidatedJson::<TestUser>::from_request(&mut request).await;
2081 assert!(
2082 result.is_ok(),
2083 "Expected success when validator is present and value is unique. Error: {:?}",
2084 result.err()
2085 );
2086
2087 let user_taken = TestUser {
2089 email: "taken@example.com".to_string(),
2090 };
2091 let body_taken = serde_json::to_vec(&user_taken).unwrap();
2092
2093 let db_validator = MockDbValidator {
2094 unique_values: vec!["taken@example.com".to_string()],
2095 };
2096 let ctx = ValidationContextBuilder::new()
2097 .database(db_validator)
2098 .build();
2099
2100 let mut extensions = Extensions::new();
2101 extensions.insert(ctx);
2102
2103 let builder = http::Request::builder()
2104 .method(Method::POST)
2105 .uri("/test")
2106 .header("content-type", "application/json");
2107 let req = builder.body(()).unwrap();
2108 let (parts, _) = req.into_parts();
2109
2110 let mut request = Request::new(
2111 parts,
2112 crate::request::BodyVariant::Buffered(Bytes::from(body_taken)),
2113 Arc::new(extensions),
2114 PathParams::new(),
2115 );
2116
2117 let result = AsyncValidatedJson::<TestUser>::from_request(&mut request).await;
2118 assert!(result.is_err(), "Expected validation error for taken email");
2119 }
2120}