1use crate::error::{ApiError, Result};
58use crate::json;
59use crate::request::Request;
60use crate::response::IntoResponse;
61use crate::stream::{StreamingBody, StreamingConfig};
62use crate::validation::Validatable;
63use bytes::Bytes;
64use http::{header, StatusCode};
65use rustapi_validate::v2::{AsyncValidate, ValidationContext};
66
67use rustapi_openapi::schema::{RustApiSchema, SchemaCtx, SchemaRef};
68use serde::de::DeserializeOwned;
69use serde::Serialize;
70use std::collections::BTreeMap;
71use std::future::Future;
72use std::ops::{Deref, DerefMut};
73use std::str::FromStr;
74
75pub trait FromRequestParts: Sized {
100 fn from_request_parts(req: &Request) -> Result<Self>;
102}
103
104pub trait FromRequest: Sized {
134 fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
136}
137
138impl<T: FromRequestParts> FromRequest for T {
140 async fn from_request(req: &mut Request) -> Result<Self> {
141 T::from_request_parts(req)
142 }
143}
144
145#[derive(Debug, Clone, Copy, Default)]
164pub struct Json<T>(pub T);
165
166impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
167 async fn from_request(req: &mut Request) -> Result<Self> {
168 req.load_body().await?;
169 let body = req
170 .take_body()
171 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
172
173 let value: T = json::from_slice(&body)?;
175 Ok(Json(value))
176 }
177}
178
179impl<T> Deref for Json<T> {
180 type Target = T;
181
182 fn deref(&self) -> &Self::Target {
183 &self.0
184 }
185}
186
187impl<T> DerefMut for Json<T> {
188 fn deref_mut(&mut self) -> &mut Self::Target {
189 &mut self.0
190 }
191}
192
193impl<T> From<T> for Json<T> {
194 fn from(value: T) -> Self {
195 Json(value)
196 }
197}
198
199const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256;
202
203impl<T: Serialize> IntoResponse for Json<T> {
205 fn into_response(self) -> crate::response::Response {
206 match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
208 Ok(body) => http::Response::builder()
209 .status(StatusCode::OK)
210 .header(header::CONTENT_TYPE, "application/json")
211 .body(crate::response::Body::from(body))
212 .unwrap(),
213 Err(err) => {
214 ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
215 }
216 }
217 }
218}
219
220#[derive(Debug, Clone, Copy, Default)]
246pub struct ValidatedJson<T>(pub T);
247
248impl<T> ValidatedJson<T> {
249 pub fn new(value: T) -> Self {
251 Self(value)
252 }
253
254 pub fn into_inner(self) -> T {
256 self.0
257 }
258}
259
260impl<T: DeserializeOwned + Validatable + Send> FromRequest for ValidatedJson<T> {
261 async fn from_request(req: &mut Request) -> Result<Self> {
262 req.load_body().await?;
263 let body = req
265 .take_body()
266 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
267
268 let value: T = json::from_slice(&body)?;
269
270 value.do_validate()?;
272
273 Ok(ValidatedJson(value))
274 }
275}
276
277impl<T> Deref for ValidatedJson<T> {
278 type Target = T;
279
280 fn deref(&self) -> &Self::Target {
281 &self.0
282 }
283}
284
285impl<T> DerefMut for ValidatedJson<T> {
286 fn deref_mut(&mut self) -> &mut Self::Target {
287 &mut self.0
288 }
289}
290
291impl<T> From<T> for ValidatedJson<T> {
292 fn from(value: T) -> Self {
293 ValidatedJson(value)
294 }
295}
296
297impl<T: Serialize> IntoResponse for ValidatedJson<T> {
298 fn into_response(self) -> crate::response::Response {
299 Json(self.0).into_response()
300 }
301}
302
303#[derive(Debug, Clone, Copy, Default)]
330pub struct AsyncValidatedJson<T>(pub T);
331
332impl<T> AsyncValidatedJson<T> {
333 pub fn new(value: T) -> Self {
335 Self(value)
336 }
337
338 pub fn into_inner(self) -> T {
340 self.0
341 }
342}
343
344impl<T> Deref for AsyncValidatedJson<T> {
345 type Target = T;
346
347 fn deref(&self) -> &Self::Target {
348 &self.0
349 }
350}
351
352impl<T> DerefMut for AsyncValidatedJson<T> {
353 fn deref_mut(&mut self) -> &mut Self::Target {
354 &mut self.0
355 }
356}
357
358impl<T> From<T> for AsyncValidatedJson<T> {
359 fn from(value: T) -> Self {
360 AsyncValidatedJson(value)
361 }
362}
363
364impl<T: Serialize> IntoResponse for AsyncValidatedJson<T> {
365 fn into_response(self) -> crate::response::Response {
366 Json(self.0).into_response()
367 }
368}
369
370impl<T: DeserializeOwned + AsyncValidate + Send + Sync> FromRequest for AsyncValidatedJson<T> {
371 async fn from_request(req: &mut Request) -> Result<Self> {
372 req.load_body().await?;
373
374 let body = req
375 .take_body()
376 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
377
378 let value: T = json::from_slice(&body)?;
379
380 let ctx = if let Some(ctx) = req.state().get::<ValidationContext>() {
383 ctx.clone()
384 } else {
385 ValidationContext::default()
386 };
387
388 if let Err(errors) = value.validate_full(&ctx).await {
390 let field_errors: Vec<crate::error::FieldError> = errors
392 .fields
393 .iter()
394 .flat_map(|(field, errs)| {
395 let field_name = field.to_string();
396 errs.iter().map(move |e| crate::error::FieldError {
397 field: field_name.clone(),
398 code: e.code.to_string(),
399 message: e.message.clone(),
400 })
401 })
402 .collect();
403
404 return Err(ApiError::validation(field_errors));
405 }
406
407 Ok(AsyncValidatedJson(value))
408 }
409}
410
411#[derive(Debug, Clone)]
429pub struct Query<T>(pub T);
430
431impl<T: DeserializeOwned> FromRequestParts for Query<T> {
432 fn from_request_parts(req: &Request) -> Result<Self> {
433 let query = req.query_string().unwrap_or("");
434 let value: T = serde_urlencoded::from_str(query)
435 .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
436 Ok(Query(value))
437 }
438}
439
440impl<T> Deref for Query<T> {
441 type Target = T;
442
443 fn deref(&self) -> &Self::Target {
444 &self.0
445 }
446}
447
448#[derive(Debug, Clone)]
470pub struct Path<T>(pub T);
471
472impl<T: FromStr> FromRequestParts for Path<T>
473where
474 T::Err: std::fmt::Display,
475{
476 fn from_request_parts(req: &Request) -> Result<Self> {
477 let params = req.path_params();
478
479 if let Some((_, value)) = params.iter().next() {
481 let parsed = value
482 .parse::<T>()
483 .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
484 return Ok(Path(parsed));
485 }
486
487 Err(ApiError::internal("Missing path parameter"))
488 }
489}
490
491impl<T> Deref for Path<T> {
492 type Target = T;
493
494 fn deref(&self) -> &Self::Target {
495 &self.0
496 }
497}
498
499#[derive(Debug, Clone)]
519pub struct Typed<T>(pub T);
520
521impl<T: DeserializeOwned + Send> FromRequestParts for Typed<T> {
522 fn from_request_parts(req: &Request) -> Result<Self> {
523 let params = req.path_params();
524 let mut map = serde_json::Map::new();
525 for (k, v) in params.iter() {
526 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
527 }
528 let value = serde_json::Value::Object(map);
529 let parsed: T = serde_json::from_value(value)
530 .map_err(|e| ApiError::bad_request(format!("Invalid path parameters: {}", e)))?;
531 Ok(Typed(parsed))
532 }
533}
534
535impl<T> Deref for Typed<T> {
536 type Target = T;
537
538 fn deref(&self) -> &Self::Target {
539 &self.0
540 }
541}
542
543#[derive(Debug, Clone)]
560pub struct State<T>(pub T);
561
562impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
563 fn from_request_parts(req: &Request) -> Result<Self> {
564 req.state().get::<T>().cloned().map(State).ok_or_else(|| {
565 ApiError::internal(format!(
566 "State of type `{}` not found. Did you forget to call .state()?",
567 std::any::type_name::<T>()
568 ))
569 })
570 }
571}
572
573impl<T> Deref for State<T> {
574 type Target = T;
575
576 fn deref(&self) -> &Self::Target {
577 &self.0
578 }
579}
580
581#[derive(Debug, Clone)]
583pub struct Body(pub Bytes);
584
585impl FromRequest for Body {
586 async fn from_request(req: &mut Request) -> Result<Self> {
587 req.load_body().await?;
588 let body = req
589 .take_body()
590 .ok_or_else(|| ApiError::internal("Body already consumed"))?;
591 Ok(Body(body))
592 }
593}
594
595impl Deref for Body {
596 type Target = Bytes;
597
598 fn deref(&self) -> &Self::Target {
599 &self.0
600 }
601}
602
603pub struct BodyStream(pub StreamingBody);
605
606impl FromRequest for BodyStream {
607 async fn from_request(req: &mut Request) -> Result<Self> {
608 let config = StreamingConfig::default();
609
610 if let Some(stream) = req.take_stream() {
611 Ok(BodyStream(StreamingBody::new(stream, config.max_body_size)))
612 } else if let Some(bytes) = req.take_body() {
613 let stream = futures_util::stream::once(async move { Ok(bytes) });
615 Ok(BodyStream(StreamingBody::from_stream(
616 stream,
617 config.max_body_size,
618 )))
619 } else {
620 Err(ApiError::internal("Body already consumed"))
621 }
622 }
623}
624
625impl Deref for BodyStream {
626 type Target = StreamingBody;
627
628 fn deref(&self) -> &Self::Target {
629 &self.0
630 }
631}
632
633impl DerefMut for BodyStream {
634 fn deref_mut(&mut self) -> &mut Self::Target {
635 &mut self.0
636 }
637}
638
639impl futures_util::Stream for BodyStream {
641 type Item = Result<Bytes, ApiError>;
642
643 fn poll_next(
644 mut self: std::pin::Pin<&mut Self>,
645 cx: &mut std::task::Context<'_>,
646 ) -> std::task::Poll<Option<Self::Item>> {
647 std::pin::Pin::new(&mut self.0).poll_next(cx)
648 }
649}
650
651impl<T: FromRequestParts> FromRequestParts for Option<T> {
655 fn from_request_parts(req: &Request) -> Result<Self> {
656 Ok(T::from_request_parts(req).ok())
657 }
658}
659
660#[derive(Debug, Clone)]
678pub struct Headers(pub http::HeaderMap);
679
680impl Headers {
681 pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
683 self.0.get(name)
684 }
685
686 pub fn contains(&self, name: &str) -> bool {
688 self.0.contains_key(name)
689 }
690
691 pub fn len(&self) -> usize {
693 self.0.len()
694 }
695
696 pub fn is_empty(&self) -> bool {
698 self.0.is_empty()
699 }
700
701 pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
703 self.0.iter()
704 }
705}
706
707impl FromRequestParts for Headers {
708 fn from_request_parts(req: &Request) -> Result<Self> {
709 Ok(Headers(req.headers().clone()))
710 }
711}
712
713impl Deref for Headers {
714 type Target = http::HeaderMap;
715
716 fn deref(&self) -> &Self::Target {
717 &self.0
718 }
719}
720
721#[derive(Debug, Clone)]
740pub struct HeaderValue(pub String, pub &'static str);
741
742impl HeaderValue {
743 pub fn new(name: &'static str, value: String) -> Self {
745 Self(value, name)
746 }
747
748 pub fn value(&self) -> &str {
750 &self.0
751 }
752
753 pub fn name(&self) -> &'static str {
755 self.1
756 }
757
758 pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
760 req.headers()
761 .get(name)
762 .and_then(|v| v.to_str().ok())
763 .map(|s| HeaderValue(s.to_string(), name))
764 .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
765 }
766}
767
768impl Deref for HeaderValue {
769 type Target = String;
770
771 fn deref(&self) -> &Self::Target {
772 &self.0
773 }
774}
775
776#[derive(Debug, Clone)]
794pub struct Extension<T>(pub T);
795
796impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
797 fn from_request_parts(req: &Request) -> Result<Self> {
798 req.extensions()
799 .get::<T>()
800 .cloned()
801 .map(Extension)
802 .ok_or_else(|| {
803 ApiError::internal(format!(
804 "Extension of type `{}` not found. Did middleware insert it?",
805 std::any::type_name::<T>()
806 ))
807 })
808 }
809}
810
811impl<T> Deref for Extension<T> {
812 type Target = T;
813
814 fn deref(&self) -> &Self::Target {
815 &self.0
816 }
817}
818
819impl<T> DerefMut for Extension<T> {
820 fn deref_mut(&mut self) -> &mut Self::Target {
821 &mut self.0
822 }
823}
824
825#[derive(Debug, Clone)]
840pub struct ClientIp(pub std::net::IpAddr);
841
842impl ClientIp {
843 pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
845 if trust_proxy {
846 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
848 if let Ok(forwarded_str) = forwarded.to_str() {
849 if let Some(first_ip) = forwarded_str.split(',').next() {
851 if let Ok(ip) = first_ip.trim().parse() {
852 return Ok(ClientIp(ip));
853 }
854 }
855 }
856 }
857 }
858
859 if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
861 return Ok(ClientIp(addr.ip()));
862 }
863
864 Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
866 127, 0, 0, 1,
867 ))))
868 }
869}
870
871impl FromRequestParts for ClientIp {
872 fn from_request_parts(req: &Request) -> Result<Self> {
873 Self::extract_with_config(req, true)
875 }
876}
877
#[cfg(feature = "cookies")]
/// Cookies parsed from the request's `Cookie` header.
///
/// Malformed or empty entries are skipped silently. If the header is
/// missing or not valid text, the jar is empty.
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);

#[cfg(feature = "cookies")]
impl Cookies {
    /// Looks up a cookie by name.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        let jar = &self.0;
        jar.get(name)
    }

    /// Iterates over all parsed cookies.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        let jar = &self.0;
        jar.iter()
    }

    /// Whether a cookie named `name` was sent.
    pub fn contains(&self, name: &str) -> bool {
        !self.0.get(name).is_none()
    }
}

#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    fn from_request_parts(req: &Request) -> Result<Self> {
        let raw_header = req
            .headers()
            .get(header::COOKIE)
            .and_then(|value| value.to_str().ok());

        let jar = match raw_header {
            None => cookie::CookieJar::new(),
            Some(raw) => raw
                .split(';')
                .map(str::trim)
                .filter(|piece| !piece.is_empty())
                .filter_map(|piece| cookie::Cookie::parse(piece.to_string()).ok())
                .fold(cookie::CookieJar::new(), |mut acc, parsed| {
                    acc.add_original(parsed.into_owned());
                    acc
                }),
        };

        Ok(Cookies(jar))
    }
}

#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    fn deref(&self) -> &cookie::CookieJar {
        let Cookies(jar) = self;
        jar
    }
}
948
949macro_rules! impl_from_request_parts_for_primitives {
951 ($($ty:ty),*) => {
952 $(
953 impl FromRequestParts for $ty {
954 fn from_request_parts(req: &Request) -> Result<Self> {
955 let Path(value) = Path::<$ty>::from_request_parts(req)?;
956 Ok(value)
957 }
958 }
959 )*
960 };
961}
962
963impl_from_request_parts_for_primitives!(
964 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
965);
966
967use rustapi_openapi::{
970 MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier, ResponseSpec,
971};
972
973impl<T: RustApiSchema> OperationModifier for ValidatedJson<T> {
975 fn update_operation(op: &mut Operation) {
976 let mut ctx = SchemaCtx::new();
977 let schema_ref = T::schema(&mut ctx);
978
979 let mut content = BTreeMap::new();
980 content.insert(
981 "application/json".to_string(),
982 MediaType {
983 schema: Some(schema_ref),
984 example: None,
985 },
986 );
987
988 op.request_body = Some(RequestBody {
989 description: None,
990 required: Some(true),
991 content,
992 });
993
994 let mut responses_content = BTreeMap::new();
996 responses_content.insert(
997 "application/json".to_string(),
998 MediaType {
999 schema: Some(SchemaRef::Ref {
1000 reference: "#/components/schemas/ValidationErrorSchema".to_string(),
1001 }),
1002 example: None,
1003 },
1004 );
1005
1006 op.responses.insert(
1007 "422".to_string(),
1008 ResponseSpec {
1009 description: "Validation Error".to_string(),
1010 content: responses_content,
1011 headers: BTreeMap::new(),
1012 },
1013 );
1014 }
1015}
1016
1017impl<T: RustApiSchema> OperationModifier for Json<T> {
1019 fn update_operation(op: &mut Operation) {
1020 let mut ctx = SchemaCtx::new();
1021 let schema_ref = T::schema(&mut ctx);
1022
1023 let mut content = BTreeMap::new();
1024 content.insert(
1025 "application/json".to_string(),
1026 MediaType {
1027 schema: Some(schema_ref),
1028 example: None,
1029 },
1030 );
1031
1032 op.request_body = Some(RequestBody {
1033 description: None,
1034 required: Some(true),
1035 content,
1036 });
1037 }
1038}
1039
1040impl<T> OperationModifier for Path<T> {
1042 fn update_operation(_op: &mut Operation) {}
1043}
1044
1045impl<T> OperationModifier for Typed<T> {
1047 fn update_operation(_op: &mut Operation) {}
1048}
1049
1050impl<T: RustApiSchema> OperationModifier for Query<T> {
1052 fn update_operation(op: &mut Operation) {
1053 let mut ctx = SchemaCtx::new();
1054 if let Some(fields) = T::field_schemas(&mut ctx) {
1055 let new_params: Vec<Parameter> = fields
1056 .into_iter()
1057 .map(|(name, schema)| {
1058 Parameter {
1059 name,
1060 location: "query".to_string(),
1061 required: false, deprecated: None,
1063 description: None,
1064 schema: Some(schema),
1065 }
1066 })
1067 .collect();
1068
1069 op.parameters.extend(new_params);
1070 }
1071 }
1072}
1073
1074impl<T> OperationModifier for State<T> {
1076 fn update_operation(_op: &mut Operation) {}
1077}
1078
1079impl OperationModifier for Body {
1081 fn update_operation(op: &mut Operation) {
1082 let mut content = BTreeMap::new();
1083 content.insert(
1084 "application/octet-stream".to_string(),
1085 MediaType {
1086 schema: Some(SchemaRef::Inline(
1087 serde_json::json!({ "type": "string", "format": "binary" }),
1088 )),
1089 example: None,
1090 },
1091 );
1092
1093 op.request_body = Some(RequestBody {
1094 description: None,
1095 required: Some(true),
1096 content,
1097 });
1098 }
1099}
1100
1101impl OperationModifier for BodyStream {
1103 fn update_operation(op: &mut Operation) {
1104 let mut content = BTreeMap::new();
1105 content.insert(
1106 "application/octet-stream".to_string(),
1107 MediaType {
1108 schema: Some(SchemaRef::Inline(
1109 serde_json::json!({ "type": "string", "format": "binary" }),
1110 )),
1111 example: None,
1112 },
1113 );
1114
1115 op.request_body = Some(RequestBody {
1116 description: None,
1117 required: Some(true),
1118 content,
1119 });
1120 }
1121}
1122
1123impl<T: RustApiSchema> ResponseModifier for Json<T> {
1127 fn update_response(op: &mut Operation) {
1128 let mut ctx = SchemaCtx::new();
1129 let schema_ref = T::schema(&mut ctx);
1130
1131 let mut content = BTreeMap::new();
1132 content.insert(
1133 "application/json".to_string(),
1134 MediaType {
1135 schema: Some(schema_ref),
1136 example: None,
1137 },
1138 );
1139
1140 op.responses.insert(
1141 "200".to_string(),
1142 ResponseSpec {
1143 description: "Successful response".to_string(),
1144 content,
1145 headers: BTreeMap::new(),
1146 },
1147 );
1148 }
1149}
1150
1151impl<T: RustApiSchema> RustApiSchema for Json<T> {
1154 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1155 T::schema(ctx)
1156 }
1157}
1158
1159impl<T: RustApiSchema> RustApiSchema for ValidatedJson<T> {
1160 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1161 T::schema(ctx)
1162 }
1163}
1164
1165impl<T: RustApiSchema> RustApiSchema for AsyncValidatedJson<T> {
1166 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1167 T::schema(ctx)
1168 }
1169}
1170
1171impl<T: RustApiSchema> RustApiSchema for Query<T> {
1172 fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1173 T::schema(ctx)
1174 }
1175 fn field_schemas(ctx: &mut SchemaCtx) -> Option<BTreeMap<String, SchemaRef>> {
1176 T::field_schemas(ctx)
1177 }
1178}
1179
1180const DEFAULT_PAGE: u64 = 1;
1184const DEFAULT_PER_PAGE: u64 = 20;
1186const MAX_PER_PAGE: u64 = 100;
1188
1189#[derive(Debug, Clone, Copy)]
1206pub struct Paginate {
1207 pub page: u64,
1209 pub per_page: u64,
1211}
1212
1213impl Paginate {
1214 pub fn new(page: u64, per_page: u64) -> Self {
1216 Self {
1217 page: page.max(1),
1218 per_page: per_page.clamp(1, MAX_PER_PAGE),
1219 }
1220 }
1221
1222 pub fn offset(&self) -> u64 {
1224 (self.page - 1) * self.per_page
1225 }
1226
1227 pub fn limit(&self) -> u64 {
1229 self.per_page
1230 }
1231
1232 pub fn paginate<T>(self, items: Vec<T>, total: u64) -> crate::hateoas::Paginated<T> {
1234 crate::hateoas::Paginated {
1235 items,
1236 page: self.page,
1237 per_page: self.per_page,
1238 total,
1239 }
1240 }
1241}
1242
1243impl Default for Paginate {
1244 fn default() -> Self {
1245 Self {
1246 page: DEFAULT_PAGE,
1247 per_page: DEFAULT_PER_PAGE,
1248 }
1249 }
1250}
1251
1252impl FromRequestParts for Paginate {
1253 fn from_request_parts(req: &Request) -> Result<Self> {
1254 let query = req.query_string().unwrap_or("");
1255
1256 #[derive(serde::Deserialize)]
1257 struct PaginateQuery {
1258 page: Option<u64>,
1259 per_page: Option<u64>,
1260 }
1261
1262 let params: PaginateQuery = serde_urlencoded::from_str(query).unwrap_or(PaginateQuery {
1263 page: None,
1264 per_page: None,
1265 });
1266
1267 Ok(Paginate::new(
1268 params.page.unwrap_or(DEFAULT_PAGE),
1269 params.per_page.unwrap_or(DEFAULT_PER_PAGE),
1270 ))
1271 }
1272}
1273
1274#[derive(Debug, Clone)]
1297pub struct CursorPaginate {
1298 pub cursor: Option<String>,
1300 pub per_page: u64,
1302}
1303
1304impl CursorPaginate {
1305 pub fn new(cursor: Option<String>, per_page: u64) -> Self {
1307 Self {
1308 cursor,
1309 per_page: per_page.clamp(1, MAX_PER_PAGE),
1310 }
1311 }
1312
1313 pub fn after(&self) -> Option<&str> {
1315 self.cursor.as_deref()
1316 }
1317
1318 pub fn limit(&self) -> u64 {
1320 self.per_page
1321 }
1322
1323 pub fn is_first_page(&self) -> bool {
1325 self.cursor.is_none()
1326 }
1327}
1328
1329impl Default for CursorPaginate {
1330 fn default() -> Self {
1331 Self {
1332 cursor: None,
1333 per_page: DEFAULT_PER_PAGE,
1334 }
1335 }
1336}
1337
1338impl FromRequestParts for CursorPaginate {
1339 fn from_request_parts(req: &Request) -> Result<Self> {
1340 let query = req.query_string().unwrap_or("");
1341
1342 #[derive(serde::Deserialize)]
1343 struct CursorQuery {
1344 cursor: Option<String>,
1345 limit: Option<u64>,
1346 }
1347
1348 let params: CursorQuery = serde_urlencoded::from_str(query).unwrap_or(CursorQuery {
1349 cursor: None,
1350 limit: None,
1351 });
1352
1353 Ok(CursorPaginate::new(
1354 params.cursor,
1355 params.limit.unwrap_or(DEFAULT_PER_PAGE),
1356 ))
1357 }
1358}
1359
1360#[cfg(test)]
1361mod tests {
1362 use super::*;
1363 use crate::path_params::PathParams;
1364 use bytes::Bytes;
1365 use http::{Extensions, Method};
1366 use proptest::prelude::*;
1367 use proptest::test_runner::TestCaseError;
1368 use std::sync::Arc;
1369
1370 fn create_test_request_with_headers(
1372 method: Method,
1373 path: &str,
1374 headers: Vec<(&str, &str)>,
1375 ) -> Request {
1376 let uri: http::Uri = path.parse().unwrap();
1377 let mut builder = http::Request::builder().method(method).uri(uri);
1378
1379 for (name, value) in headers {
1380 builder = builder.header(name, value);
1381 }
1382
1383 let req = builder.body(()).unwrap();
1384 let (parts, _) = req.into_parts();
1385
1386 Request::new(
1387 parts,
1388 crate::request::BodyVariant::Buffered(Bytes::new()),
1389 Arc::new(Extensions::new()),
1390 PathParams::new(),
1391 )
1392 }
1393
1394 fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
1396 method: Method,
1397 path: &str,
1398 extension: T,
1399 ) -> Request {
1400 let uri: http::Uri = path.parse().unwrap();
1401 let builder = http::Request::builder().method(method).uri(uri);
1402
1403 let req = builder.body(()).unwrap();
1404 let (mut parts, _) = req.into_parts();
1405 parts.extensions.insert(extension);
1406
1407 Request::new(
1408 parts,
1409 crate::request::BodyVariant::Buffered(Bytes::new()),
1410 Arc::new(Extensions::new()),
1411 PathParams::new(),
1412 )
1413 }
1414
1415 proptest! {
1422 #![proptest_config(ProptestConfig::with_cases(100))]
1423
1424 #[test]
1425 fn prop_headers_extractor_completeness(
1426 headers in prop::collection::vec(
1429 (
1430 "[a-z][a-z0-9-]{0,20}", "[a-zA-Z0-9 ]{1,50}" ),
1433 0..10
1434 )
1435 ) {
1436 let result: Result<(), TestCaseError> = (|| {
1437 let header_tuples: Vec<(&str, &str)> = headers
1439 .iter()
1440 .map(|(k, v)| (k.as_str(), v.as_str()))
1441 .collect();
1442
1443 let request = create_test_request_with_headers(
1445 Method::GET,
1446 "/test",
1447 header_tuples.clone(),
1448 );
1449
1450 let extracted = Headers::from_request_parts(&request)
1452 .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;
1453
1454 for (name, value) in &headers {
1457 let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
1459 prop_assert!(
1460 !all_values.is_empty(),
1461 "Header '{}' not found",
1462 name
1463 );
1464
1465 let value_found = all_values.iter().any(|v| {
1467 v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
1468 });
1469
1470 prop_assert!(
1471 value_found,
1472 "Header '{}' value '{}' not found in extracted values",
1473 name,
1474 value
1475 );
1476 }
1477
1478 Ok(())
1479 })();
1480 result?;
1481 }
1482 }
1483
1484 proptest! {
1491 #![proptest_config(ProptestConfig::with_cases(100))]
1492
1493 #[test]
1494 fn prop_header_value_extractor_correctness(
1495 header_name in "[a-z][a-z0-9-]{0,20}",
1496 header_value in "[a-zA-Z0-9 ]{1,50}",
1497 has_header in prop::bool::ANY,
1498 ) {
1499 let result: Result<(), TestCaseError> = (|| {
1500 let headers = if has_header {
1501 vec![(header_name.as_str(), header_value.as_str())]
1502 } else {
1503 vec![]
1504 };
1505
1506 let _request = create_test_request_with_headers(Method::GET, "/test", headers);
1507
1508 let test_header = "x-test-header";
1511 let request_with_known_header = if has_header {
1512 create_test_request_with_headers(
1513 Method::GET,
1514 "/test",
1515 vec![(test_header, header_value.as_str())],
1516 )
1517 } else {
1518 create_test_request_with_headers(Method::GET, "/test", vec![])
1519 };
1520
1521 let result = HeaderValue::extract(&request_with_known_header, test_header);
1522
1523 if has_header {
1524 let extracted = result
1525 .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
1526 prop_assert_eq!(
1527 extracted.value(),
1528 header_value.as_str(),
1529 "Header value mismatch"
1530 );
1531 } else {
1532 prop_assert!(
1533 result.is_err(),
1534 "Expected error when header is missing"
1535 );
1536 }
1537
1538 Ok(())
1539 })();
1540 result?;
1541 }
1542 }
1543
1544 proptest! {
1551 #![proptest_config(ProptestConfig::with_cases(100))]
1552
1553 #[test]
1554 fn prop_client_ip_extractor_with_forwarding(
1555 forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1557 .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
1558 socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1559 .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
1560 has_forwarded_header in prop::bool::ANY,
1561 trust_proxy in prop::bool::ANY,
1562 ) {
1563 let result: Result<(), TestCaseError> = (|| {
1564 let headers = if has_forwarded_header {
1565 vec![("x-forwarded-for", forwarded_ip.as_str())]
1566 } else {
1567 vec![]
1568 };
1569
1570 let uri: http::Uri = "/test".parse().unwrap();
1572 let mut builder = http::Request::builder().method(Method::GET).uri(uri);
1573 for (name, value) in &headers {
1574 builder = builder.header(*name, *value);
1575 }
1576 let req = builder.body(()).unwrap();
1577 let (mut parts, _) = req.into_parts();
1578
1579 let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
1581 parts.extensions.insert(socket_addr);
1582
1583 let request = Request::new(
1584 parts,
1585 crate::request::BodyVariant::Buffered(Bytes::new()),
1586 Arc::new(Extensions::new()),
1587 PathParams::new(),
1588 );
1589
1590 let extracted = ClientIp::extract_with_config(&request, trust_proxy)
1591 .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;
1592
1593 if trust_proxy && has_forwarded_header {
1594 let expected_ip: std::net::IpAddr = forwarded_ip.parse()
1596 .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
1597 prop_assert_eq!(
1598 extracted.0,
1599 expected_ip,
1600 "Should use X-Forwarded-For IP when trust_proxy is enabled"
1601 );
1602 } else {
1603 prop_assert_eq!(
1605 extracted.0,
1606 socket_ip,
1607 "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
1608 );
1609 }
1610
1611 Ok(())
1612 })();
1613 result?;
1614 }
1615 }
1616
1617 proptest! {
1624 #![proptest_config(ProptestConfig::with_cases(100))]
1625
1626 #[test]
1627 fn prop_extension_extractor_retrieval(
1628 value in any::<i64>(),
1629 has_extension in prop::bool::ANY,
1630 ) {
1631 let result: Result<(), TestCaseError> = (|| {
1632 #[derive(Clone, Debug, PartialEq)]
1634 struct TestExtension(i64);
1635
1636 let uri: http::Uri = "/test".parse().unwrap();
1637 let builder = http::Request::builder().method(Method::GET).uri(uri);
1638 let req = builder.body(()).unwrap();
1639 let (mut parts, _) = req.into_parts();
1640
1641 if has_extension {
1642 parts.extensions.insert(TestExtension(value));
1643 }
1644
1645 let request = Request::new(
1646 parts,
1647 crate::request::BodyVariant::Buffered(Bytes::new()),
1648 Arc::new(Extensions::new()),
1649 PathParams::new(),
1650 );
1651
1652 let result = Extension::<TestExtension>::from_request_parts(&request);
1653
1654 if has_extension {
1655 let extracted = result
1656 .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
1657 prop_assert_eq!(
1658 extracted.0,
1659 TestExtension(value),
1660 "Extension value mismatch"
1661 );
1662 } else {
1663 prop_assert!(
1664 result.is_err(),
1665 "Expected error when extension is missing"
1666 );
1667 }
1668
1669 Ok(())
1670 })();
1671 result?;
1672 }
1673 }
1674
1675 #[test]
1678 fn test_headers_extractor_basic() {
1679 let request = create_test_request_with_headers(
1680 Method::GET,
1681 "/test",
1682 vec![
1683 ("content-type", "application/json"),
1684 ("accept", "text/html"),
1685 ],
1686 );
1687
1688 let headers = Headers::from_request_parts(&request).unwrap();
1689
1690 assert!(headers.contains("content-type"));
1691 assert!(headers.contains("accept"));
1692 assert!(!headers.contains("x-custom"));
1693 assert_eq!(headers.len(), 2);
1694 }
1695
1696 #[test]
1697 fn test_header_value_extractor_present() {
1698 let request = create_test_request_with_headers(
1699 Method::GET,
1700 "/test",
1701 vec![("authorization", "Bearer token123")],
1702 );
1703
1704 let result = HeaderValue::extract(&request, "authorization");
1705 assert!(result.is_ok());
1706 assert_eq!(result.unwrap().value(), "Bearer token123");
1707 }
1708
1709 #[test]
1710 fn test_header_value_extractor_missing() {
1711 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1712
1713 let result = HeaderValue::extract(&request, "authorization");
1714 assert!(result.is_err());
1715 }
1716
1717 #[test]
1718 fn test_client_ip_from_forwarded_header() {
1719 let request = create_test_request_with_headers(
1720 Method::GET,
1721 "/test",
1722 vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
1723 );
1724
1725 let ip = ClientIp::extract_with_config(&request, true).unwrap();
1726 assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
1727 }
1728
1729 #[test]
1730 fn test_client_ip_ignores_forwarded_when_not_trusted() {
1731 let uri: http::Uri = "/test".parse().unwrap();
1732 let builder = http::Request::builder()
1733 .method(Method::GET)
1734 .uri(uri)
1735 .header("x-forwarded-for", "192.168.1.100");
1736 let req = builder.body(()).unwrap();
1737 let (mut parts, _) = req.into_parts();
1738
1739 let socket_addr = std::net::SocketAddr::new(
1740 std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
1741 8080,
1742 );
1743 parts.extensions.insert(socket_addr);
1744
1745 let request = Request::new(
1746 parts,
1747 crate::request::BodyVariant::Buffered(Bytes::new()),
1748 Arc::new(Extensions::new()),
1749 PathParams::new(),
1750 );
1751
1752 let ip = ClientIp::extract_with_config(&request, false).unwrap();
1753 assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
1754 }
1755
1756 #[test]
1757 fn test_extension_extractor_present() {
1758 #[derive(Clone, Debug, PartialEq)]
1759 struct MyData(String);
1760
1761 let request =
1762 create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));
1763
1764 let result = Extension::<MyData>::from_request_parts(&request);
1765 assert!(result.is_ok());
1766 assert_eq!(result.unwrap().0, MyData("hello".to_string()));
1767 }
1768
1769 #[test]
1770 fn test_extension_extractor_missing() {
1771 #[derive(Clone, Debug)]
1772 #[allow(dead_code)]
1773 struct MyData(String);
1774
1775 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1776
1777 let result = Extension::<MyData>::from_request_parts(&request);
1778 assert!(result.is_err());
1779 }
1780
1781 #[cfg(feature = "cookies")]
1783 mod cookies_tests {
1784 use super::*;
1785
1786 proptest! {
1794 #![proptest_config(ProptestConfig::with_cases(100))]
1795
1796 #[test]
1797 fn prop_cookies_extractor_parsing(
1798 cookies in prop::collection::vec(
1801 (
1802 "[a-zA-Z][a-zA-Z0-9_]{0,15}", "[a-zA-Z0-9]{1,30}" ),
1805 0..5
1806 )
1807 ) {
1808 let result: Result<(), TestCaseError> = (|| {
1809 let cookie_header = cookies
1811 .iter()
1812 .map(|(name, value)| format!("{}={}", name, value))
1813 .collect::<Vec<_>>()
1814 .join("; ");
1815
1816 let headers = if !cookies.is_empty() {
1817 vec![("cookie", cookie_header.as_str())]
1818 } else {
1819 vec![]
1820 };
1821
1822 let request = create_test_request_with_headers(Method::GET, "/test", headers);
1823
1824 let extracted = Cookies::from_request_parts(&request)
1826 .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;
1827
1828 let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
1830 for (name, value) in &cookies {
1831 expected_cookies.insert(name.as_str(), value.as_str());
1832 }
1833
1834 for (name, expected_value) in &expected_cookies {
1836 let cookie = extracted.get(name)
1837 .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;
1838
1839 prop_assert_eq!(
1840 cookie.value(),
1841 *expected_value,
1842 "Cookie '{}' value mismatch",
1843 name
1844 );
1845 }
1846
1847 let extracted_count = extracted.iter().count();
1849 prop_assert_eq!(
1850 extracted_count,
1851 expected_cookies.len(),
1852 "Expected {} unique cookies, got {}",
1853 expected_cookies.len(),
1854 extracted_count
1855 );
1856
1857 Ok(())
1858 })();
1859 result?;
1860 }
1861 }
1862
1863 #[test]
1864 fn test_cookies_extractor_basic() {
1865 let request = create_test_request_with_headers(
1866 Method::GET,
1867 "/test",
1868 vec![("cookie", "session=abc123; user=john")],
1869 );
1870
1871 let cookies = Cookies::from_request_parts(&request).unwrap();
1872
1873 assert!(cookies.contains("session"));
1874 assert!(cookies.contains("user"));
1875 assert!(!cookies.contains("other"));
1876
1877 assert_eq!(cookies.get("session").unwrap().value(), "abc123");
1878 assert_eq!(cookies.get("user").unwrap().value(), "john");
1879 }
1880
1881 #[test]
1882 fn test_cookies_extractor_empty() {
1883 let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1884
1885 let cookies = Cookies::from_request_parts(&request).unwrap();
1886 assert_eq!(cookies.iter().count(), 0);
1887 }
1888
1889 #[test]
1890 fn test_cookies_extractor_single() {
1891 let request = create_test_request_with_headers(
1892 Method::GET,
1893 "/test",
1894 vec![("cookie", "token=xyz789")],
1895 );
1896
1897 let cookies = Cookies::from_request_parts(&request).unwrap();
1898 assert_eq!(cookies.iter().count(), 1);
1899 assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
1900 }
1901 }
1902
1903 #[tokio::test]
1904 async fn test_async_validated_json_with_state_context() {
1905 use async_trait::async_trait;
1906 use rustapi_validate::prelude::*;
1907 use rustapi_validate::v2::{
1908 AsyncValidationRule, DatabaseValidator, ValidationContextBuilder,
1909 };
1910 use serde::{Deserialize, Serialize};
1911
1912 struct MockDbValidator {
1913 unique_values: Vec<String>,
1914 }
1915
1916 #[async_trait]
1917 impl DatabaseValidator for MockDbValidator {
1918 async fn exists(
1919 &self,
1920 _table: &str,
1921 _column: &str,
1922 _value: &str,
1923 ) -> Result<bool, String> {
1924 Ok(true)
1925 }
1926 async fn is_unique(
1927 &self,
1928 _table: &str,
1929 _column: &str,
1930 value: &str,
1931 ) -> Result<bool, String> {
1932 Ok(!self.unique_values.contains(&value.to_string()))
1933 }
1934 async fn is_unique_except(
1935 &self,
1936 _table: &str,
1937 _column: &str,
1938 value: &str,
1939 _except_id: &str,
1940 ) -> Result<bool, String> {
1941 Ok(!self.unique_values.contains(&value.to_string()))
1942 }
1943 }
1944
1945 #[derive(Debug, Deserialize, Serialize)]
1946 struct TestUser {
1947 email: String,
1948 }
1949
1950 impl Validate for TestUser {
1951 fn validate_with_group(
1952 &self,
1953 _group: rustapi_validate::v2::ValidationGroup,
1954 ) -> Result<(), rustapi_validate::v2::ValidationErrors> {
1955 Ok(())
1956 }
1957 }
1958
1959 #[async_trait]
1960 impl AsyncValidate for TestUser {
1961 async fn validate_async_with_group(
1962 &self,
1963 ctx: &ValidationContext,
1964 _group: rustapi_validate::v2::ValidationGroup,
1965 ) -> Result<(), rustapi_validate::v2::ValidationErrors> {
1966 let mut errors = rustapi_validate::v2::ValidationErrors::new();
1967
1968 let rule = AsyncUniqueRule::new("users", "email");
1969 if let Err(e) = rule.validate_async(&self.email, ctx).await {
1970 errors.add("email", e);
1971 }
1972
1973 errors.into_result()
1974 }
1975 }
1976
1977 let uri: http::Uri = "/test".parse().unwrap();
1979 let user = TestUser {
1980 email: "new@example.com".to_string(),
1981 };
1982 let body_bytes = serde_json::to_vec(&user).unwrap();
1983
1984 let builder = http::Request::builder()
1985 .method(Method::POST)
1986 .uri(uri.clone())
1987 .header("content-type", "application/json");
1988 let req = builder.body(()).unwrap();
1989 let (parts, _) = req.into_parts();
1990
1991 let mut request = Request::new(
1993 parts,
1994 crate::request::BodyVariant::Buffered(Bytes::from(body_bytes.clone())),
1995 Arc::new(Extensions::new()),
1996 PathParams::new(),
1997 );
1998
1999 let result = AsyncValidatedJson::<TestUser>::from_request(&mut request).await;
2000
2001 assert!(result.is_err(), "Expected error when validator is missing");
2002 let err = result.unwrap_err();
2003 let err_str = format!("{:?}", err);
2004 assert!(
2005 err_str.contains("Database validator not configured")
2006 || err_str.contains("async_unique"),
2007 "Error should mention missing configuration or rule: {:?}",
2008 err_str
2009 );
2010
2011 let db_validator = MockDbValidator {
2013 unique_values: vec!["taken@example.com".to_string()],
2014 };
2015 let ctx = ValidationContextBuilder::new()
2016 .database(db_validator)
2017 .build();
2018
2019 let mut extensions = Extensions::new();
2020 extensions.insert(ctx);
2021
2022 let builder = http::Request::builder()
2023 .method(Method::POST)
2024 .uri(uri.clone())
2025 .header("content-type", "application/json");
2026 let req = builder.body(()).unwrap();
2027 let (parts, _) = req.into_parts();
2028
2029 let mut request = Request::new(
2030 parts,
2031 crate::request::BodyVariant::Buffered(Bytes::from(body_bytes.clone())),
2032 Arc::new(extensions),
2033 PathParams::new(),
2034 );
2035
2036 let result = AsyncValidatedJson::<TestUser>::from_request(&mut request).await;
2037 assert!(
2038 result.is_ok(),
2039 "Expected success when validator is present and value is unique. Error: {:?}",
2040 result.err()
2041 );
2042
2043 let user_taken = TestUser {
2045 email: "taken@example.com".to_string(),
2046 };
2047 let body_taken = serde_json::to_vec(&user_taken).unwrap();
2048
2049 let db_validator = MockDbValidator {
2050 unique_values: vec!["taken@example.com".to_string()],
2051 };
2052 let ctx = ValidationContextBuilder::new()
2053 .database(db_validator)
2054 .build();
2055
2056 let mut extensions = Extensions::new();
2057 extensions.insert(ctx);
2058
2059 let builder = http::Request::builder()
2060 .method(Method::POST)
2061 .uri("/test")
2062 .header("content-type", "application/json");
2063 let req = builder.body(()).unwrap();
2064 let (parts, _) = req.into_parts();
2065
2066 let mut request = Request::new(
2067 parts,
2068 crate::request::BodyVariant::Buffered(Bytes::from(body_taken)),
2069 Arc::new(extensions),
2070 PathParams::new(),
2071 );
2072
2073 let result = AsyncValidatedJson::<TestUser>::from_request(&mut request).await;
2074 assert!(result.is_err(), "Expected validation error for taken email");
2075 }
2076}