1#[cfg(feature = "multipart")]
2use crate::core::form::{FilePart, FormData};
3use crate::core::path_param::PathParam;
4use crate::core::req_body::ReqBody;
5#[cfg(feature = "multipart")]
6use crate::core::serde::from_str_multi_val;
7use crate::core::socket_addr::SocketAddr;
8use crate::header::CONTENT_TYPE;
9use crate::{Configs, Result, SilentError};
10use bytes::Bytes;
11use http::request::Parts;
12use http::{Extensions, HeaderMap, HeaderValue, Method, Uri, Version};
13use http::{Request as BaseRequest, StatusCode};
14use http_body_util::BodyExt;
15use mime::Mime;
16use serde::de::StdError;
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use std::collections::HashMap;
20use tokio::sync::OnceCell;
21use url::form_urlencoded;
22
/// An HTTP request together with the per-request state the framework attaches to it:
/// route path parameters, query parameters, lazily parsed body caches and configs.
#[derive(Debug)]
pub struct Request {
    /// Request head: method, URI, version, headers and extensions.
    parts: Parts,
    /// Parameters captured from the route path, inserted via `set_path_params`.
    path_params: HashMap<String, PathParam>,
    /// Query-string parameters, populated from the URI by `params()`.
    params: HashMap<String, String>,
    /// Request body; moved out by `take_body` and by the body-parsing helpers.
    body: ReqBody,
    /// Multipart form parsed on the first successful `form_data()` call and reused afterwards.
    #[cfg(feature = "multipart")]
    form_data: OnceCell<FormData>,
    /// Parsed body cache shared by `json_parse` and the urlencoded branch of `form_parse`.
    json_data: OnceCell<Value>,
    /// Configuration values readable through `get_config` / `get_config_uncheck`.
    pub(crate) configs: Configs,
}
40
41impl Request {
42 pub fn into_http(self) -> http::Request<ReqBody> {
44 http::Request::from_parts(self.parts, self.body)
45 }
46 #[doc(hidden)]
48 pub fn strip_to_hyper<QB>(&mut self) -> Result<hyper::Request<QB>>
49 where
50 QB: TryFrom<ReqBody>,
51 <QB as TryFrom<ReqBody>>::Error: StdError + Send + Sync + 'static,
52 {
53 let mut builder = http::request::Builder::new()
54 .method(self.method().clone())
55 .uri(self.uri().clone())
56 .version(self.version());
57 if let Some(headers) = builder.headers_mut() {
58 *headers = std::mem::take(self.headers_mut());
59 }
60 if let Some(extensions) = builder.extensions_mut() {
61 *extensions = std::mem::take(self.extensions_mut());
62 }
63
64 let body = self.take_body();
65 builder
66 .body(body.try_into().map_err(|e| {
67 SilentError::business_error(
68 StatusCode::INTERNAL_SERVER_ERROR,
69 format!("request strip to hyper failed: {e}"),
70 )
71 })?)
72 .map_err(|e| SilentError::business_error(StatusCode::BAD_REQUEST, e.to_string()))
73 }
74 #[doc(hidden)]
76 pub async fn strip_to_bytes_hyper(&mut self) -> Result<hyper::Request<Bytes>> {
77 let mut builder = http::request::Builder::new()
78 .method(self.method().clone())
79 .uri(self.uri().clone())
80 .version(self.version());
81 if let Some(headers) = builder.headers_mut() {
82 *headers = std::mem::take(self.headers_mut());
83 }
84 if let Some(extensions) = builder.extensions_mut() {
85 *extensions = std::mem::take(self.extensions_mut());
86 }
87
88 let mut body = self.take_body();
89 builder
90 .body(body.frame().await.unwrap()?.into_data().unwrap())
91 .map_err(|e| SilentError::business_error(StatusCode::BAD_REQUEST, e.to_string()))
92 }
93}
94
95impl Default for Request {
96 fn default() -> Self {
97 Self::empty()
98 }
99}
100
101impl Request {
102 pub fn empty() -> Self {
104 let (parts, _) = BaseRequest::builder()
105 .method("GET")
106 .body(())
107 .unwrap()
108 .into_parts();
109 Self {
110 parts,
115 path_params: HashMap::new(),
116 params: HashMap::new(),
117 body: ReqBody::Empty,
118 #[cfg(feature = "multipart")]
119 form_data: OnceCell::new(),
120 json_data: OnceCell::new(),
121 configs: Configs::default(),
122 }
123 }
124
125 #[inline]
127 pub fn from_parts(parts: Parts, body: ReqBody) -> Self {
128 Self {
129 parts,
130 body,
131 ..Self::default()
132 }
133 }
134
135 #[inline]
137 pub fn remote(&self) -> SocketAddr {
138 self.headers()
139 .get("x-real-ip")
140 .and_then(|h| h.to_str().ok())
141 .unwrap()
142 .parse()
143 .unwrap()
144 }
145
146 #[inline]
148 pub fn set_remote(&mut self, remote_addr: SocketAddr) {
149 if self.headers().get("x-real-ip").is_none() {
150 self.headers_mut()
151 .insert("x-real-ip", remote_addr.to_string().parse().unwrap());
152 }
153 }
154
155 #[inline]
157 pub fn method(&self) -> &Method {
158 &self.parts.method
159 }
160
161 #[inline]
163 pub fn method_mut(&mut self) -> &mut Method {
164 &mut self.parts.method
165 }
166 #[inline]
168 pub fn uri(&self) -> &Uri {
169 &self.parts.uri
170 }
171 #[inline]
173 pub fn uri_mut(&mut self) -> &mut Uri {
174 &mut self.parts.uri
175 }
176 #[inline]
178 pub fn version(&self) -> Version {
179 self.parts.version
180 }
181 #[inline]
183 pub fn version_mut(&mut self) -> &mut Version {
184 &mut self.parts.version
185 }
186 #[inline]
188 pub fn headers(&self) -> &HeaderMap<HeaderValue> {
189 &self.parts.headers
190 }
191 #[inline]
193 pub fn headers_mut(&mut self) -> &mut HeaderMap<HeaderValue> {
194 &mut self.parts.headers
195 }
196 #[inline]
198 pub fn extensions(&self) -> &Extensions {
199 &self.parts.extensions
200 }
201 #[inline]
203 pub fn extensions_mut(&mut self) -> &mut Extensions {
204 &mut self.parts.extensions
205 }
206 pub(crate) fn set_path_params(&mut self, key: String, value: PathParam) {
207 self.path_params.insert(key, value);
208 }
209
210 #[inline]
212 pub fn get_config<T: Send + Sync + 'static>(&self) -> Result<&T> {
213 self.configs.get::<T>().ok_or(SilentError::ConfigNotFound)
214 }
215
216 #[inline]
218 pub fn get_config_uncheck<T: Send + Sync + 'static>(&self) -> &T {
219 self.configs.get::<T>().unwrap()
220 }
221
222 #[inline]
224 pub fn configs(&self) -> Configs {
225 self.configs.clone()
226 }
227
228 #[inline]
230 pub fn configs_mut(&mut self) -> &mut Configs {
231 &mut self.configs
232 }
233
234 pub fn path_params(&self) -> &HashMap<String, PathParam> {
236 &self.path_params
237 }
238
239 pub fn get_path_params<'a, T>(&'a self, key: &'a str) -> Result<T>
241 where
242 T: TryFrom<&'a PathParam, Error = SilentError>,
243 {
244 match self.path_params.get(key) {
245 Some(value) => value.try_into(),
246 None => Err(SilentError::ParamsNotFound),
247 }
248 }
249
250 pub fn params(&mut self) -> &HashMap<String, String> {
252 if let Some(query) = self.uri().query() {
253 let params = form_urlencoded::parse(query.as_bytes())
254 .into_owned()
255 .collect::<HashMap<String, String>>();
256 self.params = params;
257 };
258 &self.params
259 }
260
261 pub fn params_parse<T>(&mut self) -> Result<T>
263 where
264 for<'de> T: Deserialize<'de>,
265 {
266 let query = self.uri().query().unwrap_or("");
267 let params = serde_html_form::from_str(query)?;
268 Ok(params)
269 }
270
271 #[inline]
273 pub fn replace_body(&mut self, body: ReqBody) -> ReqBody {
274 std::mem::replace(&mut self.body, body)
275 }
276
277 #[inline]
279 pub fn take_body(&mut self) -> ReqBody {
280 self.replace_body(ReqBody::Empty)
281 }
282
283 #[inline]
285 pub fn content_type(&self) -> Option<Mime> {
286 self.headers()
287 .get(CONTENT_TYPE)
288 .and_then(|h| h.to_str().ok())
289 .and_then(|v| v.parse().ok())
290 }
291
292 #[cfg(feature = "multipart")]
294 #[inline]
295 pub async fn form_data(&mut self) -> Result<&FormData> {
296 let content_type = self
297 .content_type()
298 .ok_or(SilentError::ContentTypeMissingError)?;
299 if content_type.subtype() != mime::FORM_DATA {
300 return Err(SilentError::ContentTypeError);
301 }
302 let body = self.take_body();
303 let headers = self.headers();
304 self.form_data
305 .get_or_try_init(|| async { FormData::read(headers, body).await })
306 .await
307 }
308
309 pub async fn form_parse<T>(&mut self) -> Result<T>
311 where
312 for<'de> T: Deserialize<'de> + Serialize,
313 {
314 let content_type = self
315 .content_type()
316 .ok_or(SilentError::ContentTypeMissingError)?;
317
318 match content_type.subtype() {
319 #[cfg(feature = "multipart")]
320 mime::FORM_DATA => {
321 let form_data = self.form_data().await?;
323 let value =
324 serde_json::to_value(form_data.fields.clone()).map_err(SilentError::from)?;
325 serde_json::from_value(value).map_err(Into::into)
326 }
327 mime::WWW_FORM_URLENCODED => {
328 if let Some(cached_value) = self.json_data.get() {
330 return serde_json::from_value(cached_value.clone()).map_err(Into::into);
331 }
332
333 let body = self.take_body();
335 let bytes = match body {
336 ReqBody::Incoming(body) => body
337 .collect()
338 .await
339 .or(Err(SilentError::BodyEmpty))?
340 .to_bytes(),
341 ReqBody::Once(bytes) => bytes,
342 ReqBody::Empty => return Err(SilentError::BodyEmpty),
343 };
344
345 if bytes.is_empty() {
346 return Err(SilentError::BodyEmpty);
347 }
348
349 let parsed_data: T =
351 serde_html_form::from_bytes(&bytes).map_err(SilentError::from)?;
352
353 let value = serde_json::to_value(&parsed_data).map_err(SilentError::from)?;
355 let _ = self.json_data.set(value.clone());
356
357 Ok(parsed_data)
359 }
360 _ => Err(SilentError::ContentTypeError),
361 }
362 }
363
364 #[cfg(feature = "multipart")]
366 pub async fn form_field<T>(&mut self, key: &str) -> Option<T>
367 where
368 for<'de> T: Deserialize<'de>,
369 {
370 self.form_data()
371 .await
372 .ok()
373 .and_then(|ps| ps.fields.get_vec(key))
374 .and_then(|vs| from_str_multi_val(vs).ok())
375 }
376
377 #[cfg(feature = "multipart")]
379 #[inline]
380 pub async fn files<'a>(&'a mut self, key: &'a str) -> Option<&'a Vec<FilePart>> {
381 self.form_data()
382 .await
383 .ok()
384 .and_then(|ps| ps.files.get_vec(key))
385 }
386
387 pub async fn json_parse<T>(&mut self) -> Result<T>
389 where
390 for<'de> T: Deserialize<'de>,
391 {
392 if let Some(cached_value) = self.json_data.get() {
394 return serde_json::from_value(cached_value.clone()).map_err(Into::into);
395 }
396
397 let content_type = self
398 .content_type()
399 .ok_or(SilentError::ContentTypeMissingError)?;
400
401 if content_type.subtype() != mime::JSON {
402 return Err(SilentError::ContentTypeError);
403 }
404
405 let body = self.take_body();
406 let bytes = match body {
407 ReqBody::Incoming(body) => body
408 .collect()
409 .await
410 .or(Err(SilentError::JsonEmpty))?
411 .to_bytes(),
412 ReqBody::Once(bytes) => bytes,
413 ReqBody::Empty => return Err(SilentError::JsonEmpty),
414 };
415
416 if bytes.is_empty() {
417 return Err(SilentError::JsonEmpty);
418 }
419
420 let value: Value = serde_json::from_slice(&bytes).map_err(SilentError::from)?;
421
422 let _ = self.json_data.set(value.clone());
424
425 serde_json::from_value(value).map_err(Into::into)
426 }
427
428 pub async fn json_field<T>(&mut self, key: &str) -> Result<T>
430 where
431 for<'de> T: Deserialize<'de>,
432 {
433 let value: Value = self.json_parse().await?;
434 serde_json::from_value(
435 value
436 .get(key)
437 .ok_or(SilentError::ParamsNotFound)?
438 .to_owned(),
439 )
440 .map_err(Into::into)
441 }
442
443 #[inline]
445 pub fn replace_extensions(&mut self, extensions: Extensions) -> Extensions {
446 std::mem::replace(self.extensions_mut(), extensions)
447 }
448
449 #[inline]
451 pub fn take_extensions(&mut self) -> Extensions {
452 self.replace_extensions(Extensions::default())
453 }
454
455 pub(crate) fn split_url(self) -> (Self, String) {
457 let url = self.uri().path().to_string();
458 (self, url)
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Deserialize, Debug, PartialEq)]
    struct TestStruct {
        a: i32,
        b: String,
        #[serde(default, alias = "c[]")]
        c: Vec<String>,
    }

    /// The struct both query-parsing tests should produce.
    fn expected_query() -> TestStruct {
        TestStruct {
            a: 1,
            b: "2".to_string(),
            c: vec!["3".to_string(), "4".to_string()],
        }
    }

    #[test]
    fn test_empty_request_defaults() {
        let req = Request::empty();
        assert_eq!(*req.method(), Method::GET);
        assert_eq!(req.uri(), "/");
        assert_eq!(req.version(), Version::HTTP_11);
        assert!(req.headers().is_empty());
    }

    #[test]
    fn test_query_parse_alias() {
        let mut req = Request::empty();
        *req.uri_mut() = Uri::from_static("http://localhost:8080/test?a=1&b=2&c[]=3&c[]=4");
        let parsed = req.params_parse::<TestStruct>().unwrap();
        assert_eq!(parsed, expected_query());
    }

    #[test]
    fn test_query_parse() {
        let mut req = Request::empty();
        *req.uri_mut() = Uri::from_static("http://localhost:8080/test?a=1&b=2&c=3&c=4");
        let parsed = req.params_parse::<TestStruct>().unwrap();
        assert_eq!(parsed, expected_query());
    }

    #[test]
    fn test_params_from_query() {
        let mut req = Request::empty();
        *req.uri_mut() = Uri::from_static("http://localhost:8080/test?a=1&b=x%20y");
        let params = req.params();
        assert_eq!(params.get("a").map(String::as_str), Some("1"));
        assert_eq!(params.get("b").map(String::as_str), Some("x y"));
    }

    #[tokio::test]
    async fn test_methods_semantic_separation() {
        #[derive(Deserialize, Serialize, Debug, PartialEq)]
        struct TestData {
            name: String,
            age: u32,
        }

        let test_data = TestData {
            name: "Alice".to_string(),
            age: 25,
        };

        // json_parse accepts JSON.
        let json_body = r#"{"name":"Alice","age":25}"#.as_bytes().to_vec();
        let mut req = create_request_with_body("application/json", json_body);
        let parsed_data = req
            .json_parse::<TestData>()
            .await
            .expect("json_parse should successfully parse JSON data");
        assert_eq!(parsed_data, test_data);

        // form_parse accepts form-urlencoded.
        let form_body = "name=Alice&age=25".as_bytes().to_vec();
        let mut req = create_request_with_body("application/x-www-form-urlencoded", form_body);
        let parsed_data = req
            .form_parse::<TestData>()
            .await
            .expect("form_parse should successfully parse form-urlencoded data");
        assert_eq!(parsed_data, test_data);

        // json_parse rejects form-urlencoded on a fresh request.
        let form_body = "name=Alice&age=25".as_bytes().to_vec();
        let mut req = create_request_with_body("application/x-www-form-urlencoded", form_body);
        let result = req.json_parse::<TestData>().await;
        assert!(
            matches!(result, Err(SilentError::ContentTypeError)),
            "json_parse should reject form-urlencoded data"
        );

        // form_parse rejects JSON.
        let json_body = r#"{"name":"Alice","age":25}"#.as_bytes().to_vec();
        let mut req = create_request_with_body("application/json", json_body);
        let result = req.form_parse::<TestData>().await;
        assert!(
            matches!(result, Err(SilentError::ContentTypeError)),
            "form_parse should reject JSON data"
        );
    }

    #[tokio::test]
    async fn test_form_urlencoded_caches_to_json_data() {
        #[derive(Deserialize, Serialize, Debug, PartialEq)]
        struct TestData {
            name: String,
            age: u32,
        }

        let form_body = "name=Alice&age=25".as_bytes().to_vec();
        let mut req = create_request_with_body("application/x-www-form-urlencoded", form_body);

        let first_result = req
            .form_parse::<TestData>()
            .await
            .expect("First form_parse call should succeed");

        assert!(
            req.json_data.get().is_some(),
            "json_data should be cached after form_parse"
        );

        // The body is drained by now, so this can only succeed through the cache.
        let second_result = req
            .form_parse::<TestData>()
            .await
            .expect("Second form_parse call should use cached data");

        assert_eq!(first_result, second_result);
        assert_eq!(first_result.name, "Alice");
        assert_eq!(first_result.age, 25);
    }

    #[tokio::test]
    async fn test_json_parse_is_cached() {
        let body = br#"{"name":"Alice","age":25}"#.to_vec();
        let mut req = create_request_with_body("application/json", body);

        let first: Value = req
            .json_parse()
            .await
            .expect("first json_parse should succeed");
        // The body has been drained, so this can only succeed via the cache.
        let second: Value = req
            .json_parse()
            .await
            .expect("second json_parse should hit the cache");
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn test_json_field() {
        let body = br#"{"name":"Alice","age":25}"#.to_vec();
        let mut req = create_request_with_body("application/json", body);

        let age: u32 = req
            .json_field("age")
            .await
            .expect("age field should be present");
        assert_eq!(age, 25);

        let missing = req.json_field::<String>("email").await;
        assert!(matches!(missing, Err(SilentError::ParamsNotFound)));
    }

    #[cfg(feature = "multipart")]
    #[tokio::test]
    async fn test_shared_cache_mechanism() {
        let mut req = Request::empty();
        req.headers_mut().insert(
            "content-type",
            HeaderValue::from_str("multipart/form-data; boundary=----formdata").unwrap(),
        );
        req.body = ReqBody::Empty;

        #[derive(Deserialize, Serialize, Debug)]
        struct TestData {
            name: String,
        }

        // Multipart goes through `form_data`; an empty body must surface as an error.
        let result = req.form_parse::<TestData>().await;
        assert!(
            result.is_err(),
            "Should fail due to empty body, but went through correct code path"
        );
    }

    /// Builds a request with the given `Content-Type` and an in-memory body.
    fn create_request_with_body(content_type: &str, body: Vec<u8>) -> Request {
        let mut req = Request::empty();
        req.headers_mut()
            .insert("content-type", HeaderValue::from_str(content_type).unwrap());
        req.body = ReqBody::Once(body.into());
        req
    }
}