silent/core/
request.rs

1#[cfg(feature = "multipart")]
2use crate::core::form::{FilePart, FormData};
3use crate::core::path_param::PathParam;
4use crate::core::req_body::ReqBody;
5#[cfg(feature = "multipart")]
6use crate::core::serde::from_str_multi_val;
7use crate::core::socket_addr::SocketAddr;
8use crate::header::CONTENT_TYPE;
9use crate::{Configs, Result, SilentError};
10use bytes::Bytes;
11use http::request::Parts;
12use http::{Extensions, HeaderMap, HeaderValue, Method, Uri, Version};
13use http::{Request as BaseRequest, StatusCode};
14use http_body_util::BodyExt;
15use mime::Mime;
16use serde::de::StdError;
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use std::collections::HashMap;
20use tokio::sync::OnceCell;
21use url::form_urlencoded;
22
23/// 请求体
24/// ```
25/// use silent::Request;
26/// let req = Request::empty();
27/// ```
28#[derive(Debug)]
29pub struct Request {
30    // req: BaseRequest<ReqBody>,
31    parts: Parts,
32    path_params: HashMap<String, PathParam>,
33    params: HashMap<String, String>,
34    body: ReqBody,
35    #[cfg(feature = "multipart")]
36    form_data: OnceCell<FormData>,
37    json_data: OnceCell<Value>,
38    pub(crate) configs: Configs,
39}
40
41impl Request {
42    /// 从http请求体创建请求
43    pub fn into_http(self) -> http::Request<ReqBody> {
44        http::Request::from_parts(self.parts, self.body)
45    }
46    /// Strip the request to [`hyper::Request`].
47    #[doc(hidden)]
48    pub fn strip_to_hyper<QB>(&mut self) -> Result<hyper::Request<QB>>
49    where
50        QB: TryFrom<ReqBody>,
51        <QB as TryFrom<ReqBody>>::Error: StdError + Send + Sync + 'static,
52    {
53        let mut builder = http::request::Builder::new()
54            .method(self.method().clone())
55            .uri(self.uri().clone())
56            .version(self.version());
57        if let Some(headers) = builder.headers_mut() {
58            *headers = std::mem::take(self.headers_mut());
59        }
60        if let Some(extensions) = builder.extensions_mut() {
61            *extensions = std::mem::take(self.extensions_mut());
62        }
63
64        let body = self.take_body();
65        builder
66            .body(body.try_into().map_err(|e| {
67                SilentError::business_error(
68                    StatusCode::INTERNAL_SERVER_ERROR,
69                    format!("request strip to hyper failed: {e}"),
70                )
71            })?)
72            .map_err(|e| SilentError::business_error(StatusCode::BAD_REQUEST, e.to_string()))
73    }
74    /// Strip the request to [`hyper::Request`].
75    #[doc(hidden)]
76    pub async fn strip_to_bytes_hyper(&mut self) -> Result<hyper::Request<Bytes>> {
77        let mut builder = http::request::Builder::new()
78            .method(self.method().clone())
79            .uri(self.uri().clone())
80            .version(self.version());
81        if let Some(headers) = builder.headers_mut() {
82            *headers = std::mem::take(self.headers_mut());
83        }
84        if let Some(extensions) = builder.extensions_mut() {
85            *extensions = std::mem::take(self.extensions_mut());
86        }
87
88        let mut body = self.take_body();
89        builder
90            .body(body.frame().await.unwrap()?.into_data().unwrap())
91            .map_err(|e| SilentError::business_error(StatusCode::BAD_REQUEST, e.to_string()))
92    }
93}
94
95impl Default for Request {
96    fn default() -> Self {
97        Self::empty()
98    }
99}
100
101impl Request {
102    /// 创建空请求体
103    pub fn empty() -> Self {
104        let (parts, _) = BaseRequest::builder()
105            .method("GET")
106            .body(())
107            .unwrap()
108            .into_parts();
109        Self {
110            // req: BaseRequest::builder()
111            //     .method("GET")
112            //     .body(().into())
113            //     .unwrap(),
114            parts,
115            path_params: HashMap::new(),
116            params: HashMap::new(),
117            body: ReqBody::Empty,
118            #[cfg(feature = "multipart")]
119            form_data: OnceCell::new(),
120            json_data: OnceCell::new(),
121            configs: Configs::default(),
122        }
123    }
124
125    /// 从请求体创建请求
126    #[inline]
127    pub fn from_parts(parts: Parts, body: ReqBody) -> Self {
128        Self {
129            parts,
130            body,
131            ..Self::default()
132        }
133    }
134
135    /// 获取访问真实地址
136    #[inline]
137    pub fn remote(&self) -> SocketAddr {
138        self.headers()
139            .get("x-real-ip")
140            .and_then(|h| h.to_str().ok())
141            .unwrap()
142            .parse()
143            .unwrap()
144    }
145
146    /// 设置访问真实地址
147    #[inline]
148    pub fn set_remote(&mut self, remote_addr: SocketAddr) {
149        if self.headers().get("x-real-ip").is_none() {
150            self.headers_mut()
151                .insert("x-real-ip", remote_addr.to_string().parse().unwrap());
152        }
153    }
154
155    /// 获取请求方法
156    #[inline]
157    pub fn method(&self) -> &Method {
158        &self.parts.method
159    }
160
161    /// 获取请求方法
162    #[inline]
163    pub fn method_mut(&mut self) -> &mut Method {
164        &mut self.parts.method
165    }
166    /// 获取请求uri
167    #[inline]
168    pub fn uri(&self) -> &Uri {
169        &self.parts.uri
170    }
171    /// 获取请求uri
172    #[inline]
173    pub fn uri_mut(&mut self) -> &mut Uri {
174        &mut self.parts.uri
175    }
176    /// 获取请求版本
177    #[inline]
178    pub fn version(&self) -> Version {
179        self.parts.version
180    }
181    /// 获取请求版本
182    #[inline]
183    pub fn version_mut(&mut self) -> &mut Version {
184        &mut self.parts.version
185    }
186    /// 获取请求头
187    #[inline]
188    pub fn headers(&self) -> &HeaderMap<HeaderValue> {
189        &self.parts.headers
190    }
191    /// 获取请求头
192    #[inline]
193    pub fn headers_mut(&mut self) -> &mut HeaderMap<HeaderValue> {
194        &mut self.parts.headers
195    }
196    /// 获取请求拓展
197    #[inline]
198    pub fn extensions(&self) -> &Extensions {
199        &self.parts.extensions
200    }
201    /// 获取请求拓展
202    #[inline]
203    pub fn extensions_mut(&mut self) -> &mut Extensions {
204        &mut self.parts.extensions
205    }
206    pub(crate) fn set_path_params(&mut self, key: String, value: PathParam) {
207        self.path_params.insert(key, value);
208    }
209
210    /// 获取配置
211    #[inline]
212    pub fn get_config<T: Send + Sync + 'static>(&self) -> Result<&T> {
213        self.configs.get::<T>().ok_or(SilentError::ConfigNotFound)
214    }
215
216    /// 获取配置(Uncheck)
217    #[inline]
218    pub fn get_config_uncheck<T: Send + Sync + 'static>(&self) -> &T {
219        self.configs.get::<T>().unwrap()
220    }
221
222    /// 获取全局配置
223    #[inline]
224    pub fn configs(&self) -> Configs {
225        self.configs.clone()
226    }
227
228    /// 获取可变全局配置
229    #[inline]
230    pub fn configs_mut(&mut self) -> &mut Configs {
231        &mut self.configs
232    }
233
234    /// 获取路径参数集合
235    pub fn path_params(&self) -> &HashMap<String, PathParam> {
236        &self.path_params
237    }
238
239    /// 获取路径参数
240    pub fn get_path_params<'a, T>(&'a self, key: &'a str) -> Result<T>
241    where
242        T: TryFrom<&'a PathParam, Error = SilentError>,
243    {
244        match self.path_params.get(key) {
245            Some(value) => value.try_into(),
246            None => Err(SilentError::ParamsNotFound),
247        }
248    }
249
250    /// 获取query参数
251    pub fn params(&mut self) -> &HashMap<String, String> {
252        if let Some(query) = self.uri().query() {
253            let params = form_urlencoded::parse(query.as_bytes())
254                .into_owned()
255                .collect::<HashMap<String, String>>();
256            self.params = params;
257        };
258        &self.params
259    }
260
261    /// 转换query参数
262    pub fn params_parse<T>(&mut self) -> Result<T>
263    where
264        for<'de> T: Deserialize<'de>,
265    {
266        let query = self.uri().query().unwrap_or("");
267        let params = serde_html_form::from_str(query)?;
268        Ok(params)
269    }
270
271    /// 获取请求body
272    #[inline]
273    pub fn replace_body(&mut self, body: ReqBody) -> ReqBody {
274        std::mem::replace(&mut self.body, body)
275    }
276
277    /// 获取请求body
278    #[inline]
279    pub fn take_body(&mut self) -> ReqBody {
280        self.replace_body(ReqBody::Empty)
281    }
282
283    /// 获取请求content_type
284    #[inline]
285    pub fn content_type(&self) -> Option<Mime> {
286        self.headers()
287            .get(CONTENT_TYPE)
288            .and_then(|h| h.to_str().ok())
289            .and_then(|v| v.parse().ok())
290    }
291
292    /// 获取请求form_data
293    #[cfg(feature = "multipart")]
294    #[inline]
295    pub async fn form_data(&mut self) -> Result<&FormData> {
296        let content_type = self
297            .content_type()
298            .ok_or(SilentError::ContentTypeMissingError)?;
299        if content_type.subtype() != mime::FORM_DATA {
300            return Err(SilentError::ContentTypeError);
301        }
302        let body = self.take_body();
303        let headers = self.headers();
304        self.form_data
305            .get_or_try_init(|| async { FormData::read(headers, body).await })
306            .await
307    }
308
309    /// 解析表单数据(支持 multipart/form-data 和 application/x-www-form-urlencoded)
310    pub async fn form_parse<T>(&mut self) -> Result<T>
311    where
312        for<'de> T: Deserialize<'de> + Serialize,
313    {
314        let content_type = self
315            .content_type()
316            .ok_or(SilentError::ContentTypeMissingError)?;
317
318        match content_type.subtype() {
319            #[cfg(feature = "multipart")]
320            mime::FORM_DATA => {
321                // 复用 form_data 的缓存机制
322                let form_data = self.form_data().await?;
323                let value =
324                    serde_json::to_value(form_data.fields.clone()).map_err(SilentError::from)?;
325                serde_json::from_value(value).map_err(Into::into)
326            }
327            mime::WWW_FORM_URLENCODED => {
328                // 检查是否已缓存到 json_data
329                if let Some(cached_value) = self.json_data.get() {
330                    return serde_json::from_value(cached_value.clone()).map_err(Into::into);
331                }
332
333                // 解析 form-urlencoded 数据并缓存到 json_data
334                let body = self.take_body();
335                let bytes = match body {
336                    ReqBody::Incoming(body) => body
337                        .collect()
338                        .await
339                        .or(Err(SilentError::BodyEmpty))?
340                        .to_bytes(),
341                    ReqBody::Once(bytes) => bytes,
342                    ReqBody::Empty => return Err(SilentError::BodyEmpty),
343                };
344
345                if bytes.is_empty() {
346                    return Err(SilentError::BodyEmpty);
347                }
348
349                // 先解析为目标类型
350                let parsed_data: T =
351                    serde_html_form::from_bytes(&bytes).map_err(SilentError::from)?;
352
353                // 转换为 Value 并缓存(需要重新导入 Serialize)
354                let value = serde_json::to_value(&parsed_data).map_err(SilentError::from)?;
355                let _ = self.json_data.set(value.clone());
356
357                // 直接返回已解析的数据
358                Ok(parsed_data)
359            }
360            _ => Err(SilentError::ContentTypeError),
361        }
362    }
363
364    /// 转换body参数
365    #[cfg(feature = "multipart")]
366    pub async fn form_field<T>(&mut self, key: &str) -> Option<T>
367    where
368        for<'de> T: Deserialize<'de>,
369    {
370        self.form_data()
371            .await
372            .ok()
373            .and_then(|ps| ps.fields.get_vec(key))
374            .and_then(|vs| from_str_multi_val(vs).ok())
375    }
376
377    /// 获取上传的文件
378    #[cfg(feature = "multipart")]
379    #[inline]
380    pub async fn files<'a>(&'a mut self, key: &'a str) -> Option<&'a Vec<FilePart>> {
381        self.form_data()
382            .await
383            .ok()
384            .and_then(|ps| ps.files.get_vec(key))
385    }
386
387    /// 解析 JSON 数据(仅支持 application/json)
388    pub async fn json_parse<T>(&mut self) -> Result<T>
389    where
390        for<'de> T: Deserialize<'de>,
391    {
392        // 检查是否已缓存
393        if let Some(cached_value) = self.json_data.get() {
394            return serde_json::from_value(cached_value.clone()).map_err(Into::into);
395        }
396
397        let content_type = self
398            .content_type()
399            .ok_or(SilentError::ContentTypeMissingError)?;
400
401        if content_type.subtype() != mime::JSON {
402            return Err(SilentError::ContentTypeError);
403        }
404
405        let body = self.take_body();
406        let bytes = match body {
407            ReqBody::Incoming(body) => body
408                .collect()
409                .await
410                .or(Err(SilentError::JsonEmpty))?
411                .to_bytes(),
412            ReqBody::Once(bytes) => bytes,
413            ReqBody::Empty => return Err(SilentError::JsonEmpty),
414        };
415
416        if bytes.is_empty() {
417            return Err(SilentError::JsonEmpty);
418        }
419
420        let value: Value = serde_json::from_slice(&bytes).map_err(SilentError::from)?;
421
422        // 缓存结果
423        let _ = self.json_data.set(value.clone());
424
425        serde_json::from_value(value).map_err(Into::into)
426    }
427
428    /// 转换body参数按Json匹配
429    pub async fn json_field<T>(&mut self, key: &str) -> Result<T>
430    where
431        for<'de> T: Deserialize<'de>,
432    {
433        let value: Value = self.json_parse().await?;
434        serde_json::from_value(
435            value
436                .get(key)
437                .ok_or(SilentError::ParamsNotFound)?
438                .to_owned(),
439        )
440        .map_err(Into::into)
441    }
442
443    /// 获取请求body
444    #[inline]
445    pub fn replace_extensions(&mut self, extensions: Extensions) -> Extensions {
446        std::mem::replace(self.extensions_mut(), extensions)
447    }
448
449    /// 获取请求body
450    #[inline]
451    pub fn take_extensions(&mut self) -> Extensions {
452        self.replace_extensions(Extensions::default())
453    }
454
455    /// 分割请求体与url
456    pub(crate) fn split_url(self) -> (Self, String) {
457        let url = self.uri().path().to_string();
458        (self, url)
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Deserialize, Debug, PartialEq)]
    struct TestStruct {
        a: i32,
        b: String,
        #[serde(default, alias = "c[]")]
        c: Vec<String>,
    }

    fn expected_struct() -> TestStruct {
        TestStruct {
            a: 1,
            b: "2".to_string(),
            c: vec!["3".to_string(), "4".to_string()],
        }
    }

    #[test]
    fn test_query_parse_alias() {
        let mut req = Request::empty();
        *req.uri_mut() = Uri::from_static("http://localhost:8080/test?a=1&b=2&c[]=3&c[]=4");
        let parsed = req.params_parse::<TestStruct>().unwrap();
        assert_eq!(parsed, expected_struct());
    }

    #[test]
    fn test_query_parse() {
        let mut req = Request::empty();
        *req.uri_mut() = Uri::from_static("http://localhost:8080/test?a=1&b=2&c=3&c=4");
        let parsed = req.params_parse::<TestStruct>().unwrap();
        assert_eq!(parsed, expected_struct());
    }

    #[test]
    fn test_params_map() {
        let mut req = Request::empty();
        *req.uri_mut() = Uri::from_static("http://localhost:8080/test?a=1&b=hello%20world");
        let params = req.params();
        assert_eq!(params.get("a").map(String::as_str), Some("1"));
        assert_eq!(params.get("b").map(String::as_str), Some("hello world"));
    }

    /// json_parse and form_parse each accept only their own content type.
    #[tokio::test]
    async fn test_methods_semantic_separation() {
        // form_parse requires both Serialize and Deserialize.
        #[derive(Deserialize, Serialize, Debug, PartialEq)]
        struct TestData {
            name: String,
            age: u32,
        }

        let test_data = TestData {
            name: "Alice".to_string(),
            age: 25,
        };

        // 1. json_parse handles JSON data.
        let json_body = r#"{"name":"Alice","age":25}"#.as_bytes().to_vec();
        let mut req = create_request_with_body("application/json", json_body);

        let parsed_data = req
            .json_parse::<TestData>()
            .await
            .expect("json_parse should successfully parse JSON data");
        assert_eq!(parsed_data, test_data);

        // 2. form_parse handles form-urlencoded data.
        let form_body = "name=Alice&age=25".as_bytes().to_vec();
        let mut req = create_request_with_body("application/x-www-form-urlencoded", form_body);

        let parsed_data = req
            .form_parse::<TestData>()
            .await
            .expect("form_parse should successfully parse form-urlencoded data");
        assert_eq!(parsed_data, test_data);

        // 3. json_parse rejects form-urlencoded data.
        let form_body = "name=Alice&age=25".as_bytes().to_vec();
        let mut req = create_request_with_body("application/x-www-form-urlencoded", form_body);

        let result = req.json_parse::<TestData>().await;
        assert!(
            result.is_err(),
            "json_parse should reject form-urlencoded data"
        );

        // 4. form_parse rejects JSON data.
        let json_body = r#"{"name":"Alice","age":25}"#.as_bytes().to_vec();
        let mut req = create_request_with_body("application/json", json_body);

        let result = req.form_parse::<TestData>().await;
        assert!(result.is_err(), "form_parse should reject JSON data");
    }

    /// json_parse reports a missing content type and an empty body with distinct errors.
    #[tokio::test]
    async fn test_json_parse_error_paths() {
        let mut req = Request::empty();
        assert!(matches!(
            req.json_parse::<Value>().await,
            Err(SilentError::ContentTypeMissingError)
        ));

        let mut req = create_request_with_body("application/json", Vec::new());
        assert!(matches!(
            req.json_parse::<Value>().await,
            Err(SilentError::JsonEmpty)
        ));
    }

    /// json_field extracts individual fields and reuses the cached JSON value.
    #[tokio::test]
    async fn test_json_field() {
        let json_body = br#"{"name":"Alice","age":25}"#.to_vec();
        let mut req = create_request_with_body("application/json", json_body);

        let age: u32 = req.json_field("age").await.expect("age should parse");
        assert_eq!(age, 25);

        // The body was consumed by the first call; this one must be served from the cache.
        let name: String = req.json_field("name").await.expect("name should parse");
        assert_eq!(name, "Alice");

        assert!(req.json_field::<String>("missing").await.is_err());
    }

    /// Form-urlencoded data is cached in the `json_data` field.
    #[tokio::test]
    async fn test_form_urlencoded_caches_to_json_data() {
        #[derive(Deserialize, Serialize, Debug, PartialEq)]
        struct TestData {
            name: String,
            age: u32,
        }

        let form_body = "name=Alice&age=25".as_bytes().to_vec();
        let mut req = create_request_with_body("application/x-www-form-urlencoded", form_body);

        // The first call parses the body and fills the cache.
        let first_result = req
            .form_parse::<TestData>()
            .await
            .expect("First form_parse call should succeed");

        assert!(
            req.json_data.get().is_some(),
            "json_data should be cached after form_parse"
        );

        // The second call must come from the cache (the body has already been consumed).
        let second_result = req
            .form_parse::<TestData>()
            .await
            .expect("Second form_parse call should use cached data");

        assert_eq!(first_result, second_result);
        assert_eq!(first_result.name, "Alice");
        assert_eq!(first_result.age, 25);
    }

    /// form_parse goes through form_data() for multipart bodies (sharing its cache).
    #[cfg(feature = "multipart")]
    #[tokio::test]
    async fn test_shared_cache_mechanism() {
        let mut req = Request::empty();
        req.headers_mut().insert(
            "content-type",
            HeaderValue::from_str("multipart/form-data; boundary=----formdata").unwrap(),
        );

        // An empty body avoids real multipart parsing; only the code path is exercised.
        req.body = ReqBody::Empty;

        #[derive(Deserialize, Serialize, Debug)]
        struct TestData {
            name: String,
        }

        let result = req.form_parse::<TestData>().await;
        // Expected to fail because no multipart data was supplied.
        assert!(
            result.is_err(),
            "Should fail due to empty body, but went through correct code path"
        );
    }

    /// Builds a request with the given `Content-Type` header and an in-memory body.
    fn create_request_with_body(content_type: &str, body: Vec<u8>) -> Request {
        let mut req = Request::empty();
        req.headers_mut()
            .insert("content-type", HeaderValue::from_str(content_type).unwrap());
        req.body = ReqBody::Once(body.into());
        req
    }
}