Skip to main content

wae_https/extract/
mod.rs

1#![doc = include_str!("readme.md")]
2
3use http::{Method, Uri, Version, header::HeaderMap};
4use serde::de::DeserializeOwned;
5use std::fmt;
6
7/// 请求上下文
8///
9/// 封装 HTTP 请求的原始状态,用于提取器访问请求数据。
10#[derive(Debug)]
11pub struct RequestParts {
12    /// HTTP 方法
13    pub method: Method,
14    /// 请求 URI
15    pub uri: Uri,
16    /// HTTP 版本
17    pub version: Version,
18    /// 请求头
19    pub headers: HeaderMap,
20    /// 路径参数
21    pub path_params: Vec<(String, String)>,
22}
23
24impl RequestParts {
25    /// 创建新的请求上下文
26    pub fn new(method: Method, uri: Uri, version: Version, headers: HeaderMap) -> Self {
27        Self { method, uri, version, headers, path_params: Vec::new() }
28    }
29}
30
31/// 从请求中提取数据的 trait
32///
33/// 类似于 Axum 的 FromRequestParts trait,用于从 HTTP 请求中提取数据。
34pub trait FromRequestParts<S>: Sized {
35    /// 提取过程中可能发生的错误
36    type Error;
37
38    /// 从请求中提取数据
39    ///
40    /// # 参数
41    ///
42    /// * `parts` - 请求上下文
43    /// * `state` - 应用状态
44    fn from_request_parts(parts: &RequestParts, state: &S) -> Result<Self, Self::Error>;
45}
46
/// Extractor error type.
///
/// Wraps every error an extractor can produce in a single enum so that callers
/// can handle failures and build responses uniformly. Each variant carries a
/// human-readable message.
#[derive(Debug)]
pub enum ExtractorError {
    /// Path parameter extraction failed.
    PathRejection(String),

    /// Query parameter extraction failed.
    QueryRejection(String),

    /// JSON extraction failed.
    JsonRejection(String),

    /// Form data extraction failed.
    FormRejection(String),

    /// Extension data extraction failed.
    ExtensionRejection(String),

    /// Host extraction failed.
    HostRejection(String),

    /// TypedHeader extraction failed.
    TypedHeaderRejection(String),

    /// Cookie extraction failed.
    CookieRejection(String),

    /// Multipart extraction failed.
    MultipartRejection(String),

    /// WebSocketUpgrade extraction failed.
    WebSocketUpgradeRejection(String),

    /// Bytes extraction failed.
    BytesRejection(String),

    /// Stream extraction failed.
    StreamRejection(String),

    /// Custom error message.
    Custom(String),
}

impl fmt::Display for ExtractorError {
    /// Writes `"<Kind> extraction failed: <message>"` for the rejection variants
    /// and `"Extractor error: <message>"` for [`ExtractorError::Custom`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::PathRejection(reason) => write!(f, "Path extraction failed: {}", reason),
            Self::QueryRejection(reason) => write!(f, "Query extraction failed: {}", reason),
            Self::JsonRejection(reason) => write!(f, "Json extraction failed: {}", reason),
            Self::FormRejection(reason) => write!(f, "Form extraction failed: {}", reason),
            Self::ExtensionRejection(reason) => write!(f, "Extension extraction failed: {}", reason),
            Self::HostRejection(reason) => write!(f, "Host extraction failed: {}", reason),
            Self::TypedHeaderRejection(reason) => write!(f, "TypedHeader extraction failed: {}", reason),
            Self::CookieRejection(reason) => write!(f, "Cookie extraction failed: {}", reason),
            Self::MultipartRejection(reason) => write!(f, "Multipart extraction failed: {}", reason),
            Self::WebSocketUpgradeRejection(reason) => {
                write!(f, "WebSocketUpgrade extraction failed: {}", reason)
            }
            Self::BytesRejection(reason) => write!(f, "Bytes extraction failed: {}", reason),
            Self::StreamRejection(reason) => write!(f, "Stream extraction failed: {}", reason),
            Self::Custom(message) => write!(f, "Extractor error: {}", message),
        }
    }
}

impl std::error::Error for ExtractorError {}
113
/// Path parameter extractor.
///
/// Wraps a value parsed from a parameter captured in the URL path.
///
/// # Examples
///
/// ```ignore
/// use wae_https::extract::Path;
///
/// async fn handler(Path(user_id): Path<u64>) -> String {
///     format!("User ID: {}", user_id)
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Path<T>(pub T);
129
130impl<T, S> FromRequestParts<S> for Path<T>
131where
132    T: std::str::FromStr,
133    T::Err: std::fmt::Display,
134{
135    type Error = ExtractorError;
136
137    fn from_request_parts(parts: &RequestParts, _state: &S) -> Result<Self, Self::Error> {
138        if let Some((_, value)) = parts.path_params.first() {
139            T::from_str(value)
140                .map(Path)
141                .map_err(|e| ExtractorError::PathRejection(format!("Failed to parse path parameter: {}", e)))
142        }
143        else {
144            Err(ExtractorError::PathRejection("No path parameters found".to_string()))
145        }
146    }
147}
148
149/// 查询参数提取器
150///
151/// 用于从 URL 查询字符串中提取参数。
152///
153/// # 示例
154///
155/// ```ignore
156/// use wae_https::extract::Query;
157/// use serde::Deserialize;
158///
159/// #[derive(Deserialize)]
160/// struct Pagination {
161///     page: u32,
162///     limit: u32,
163/// }
164///
165/// async fn handler(Query(pagination): Query<Pagination>) -> String {
166///     format!("Page: {}, Limit: {}", pagination.page, pagination.limit)
167/// }
168/// ```
169#[derive(Debug, Clone, serde::Deserialize)]
170pub struct Query<T>(pub T);
171
172impl<T, S> FromRequestParts<S> for Query<T>
173where
174    T: DeserializeOwned,
175{
176    type Error = ExtractorError;
177
178    fn from_request_parts(parts: &RequestParts, _state: &S) -> Result<Self, Self::Error> {
179        let query = parts.uri.query().unwrap_or_default();
180        serde_urlencoded::from_str(query)
181            .map(Query)
182            .map_err(|e| ExtractorError::QueryRejection(format!("Failed to parse query parameters: {}", e)))
183    }
184}
185
186/// JSON 请求体提取器
187///
188/// 用于从 JSON 格式的请求体中提取数据。
189///
190/// # 示例
191///
192/// ```ignore
193/// use wae_https::extract::Json;
194/// use serde::Deserialize;
195///
196/// #[derive(Deserialize)]
197/// struct User {
198///     name: String,
199///     age: u32,
200/// }
201///
202/// async fn handler(Json(user): Json<User>) -> String {
203///     format!("Name: {}, Age: {}", user.name, user.age)
204/// }
205/// ```
206#[derive(Debug, Clone, serde::Deserialize)]
207pub struct Json<T>(pub T);
208
209/// 表单数据提取器
210///
211/// 用于从表单格式的请求体中提取数据。
212///
213/// # 示例
214///
215/// ```ignore
216/// use wae_https::extract::Form;
217/// use serde::Deserialize;
218///
219/// #[derive(Deserialize)]
220/// struct LoginForm {
221///     username: String,
222///     password: String,
223/// }
224///
225/// async fn handler(Form(form): Form<LoginForm>) -> String {
226///     format!("Username: {}", form.username)
227/// }
228/// ```
229#[derive(Debug, Clone, serde::Deserialize)]
230pub struct Form<T>(pub T);
231
/// Request header extractor.
///
/// Wraps the value of a request header.
///
/// NOTE(review): no extraction impl for this type appears in this part of the
/// file, so how the header name is selected cannot be confirmed from here.
///
/// # Examples
///
/// ```ignore
/// use wae_https::extract::Header;
///
/// async fn handler(Header(content_type): Header<String>) -> String {
///     content_type
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Header<T>(pub T);
247
248/// HTTP 方法提取器
249///
250/// 获取当前请求的 HTTP 方法(GET、POST、PUT、DELETE 等)。
251pub type HttpMethod = Method;
252
253impl<S> FromRequestParts<S> for HttpMethod {
254    type Error = ExtractorError;
255
256    fn from_request_parts(parts: &RequestParts, _state: &S) -> Result<Self, Self::Error> {
257        Ok(parts.method.clone())
258    }
259}
260
261/// 请求 URI 提取器
262///
263/// 获取当前请求的完整 URI。
264pub type RequestUri = Uri;
265
266impl<S> FromRequestParts<S> for RequestUri {
267    type Error = ExtractorError;
268
269    fn from_request_parts(parts: &RequestParts, _state: &S) -> Result<Self, Self::Error> {
270        Ok(parts.uri.clone())
271    }
272}
273
274impl<S> FromRequestParts<S> for RequestParts {
275    type Error = ExtractorError;
276
277    fn from_request_parts(parts: &RequestParts, _state: &S) -> Result<Self, Self::Error> {
278        Ok(parts.clone())
279    }
280}
281
282impl Clone for RequestParts {
283    fn clone(&self) -> Self {
284        Self {
285            method: self.method.clone(),
286            uri: self.uri.clone(),
287            version: self.version,
288            headers: self.headers.clone(),
289            path_params: self.path_params.clone(),
290        }
291    }
292}
293
294/// HTTP 版本提取器
295///
296/// 获取当前请求的 HTTP 版本(HTTP/1.0、HTTP/1.1、HTTP/2.0 等)。
297pub type HttpVersion = Version;
298
299impl<S> FromRequestParts<S> for HttpVersion {
300    type Error = ExtractorError;
301
302    fn from_request_parts(parts: &RequestParts, _state: &S) -> Result<Self, Self::Error> {
303        Ok(parts.version)
304    }
305}
306
307/// 请求头映射提取器
308///
309/// 获取所有请求头的键值对映射。
310pub type Headers = HeaderMap;
311
312impl<S> FromRequestParts<S> for Headers {
313    type Error = ExtractorError;
314
315    fn from_request_parts(parts: &RequestParts, _state: &S) -> Result<Self, Self::Error> {
316        Ok(parts.headers.clone())
317    }
318}
319
/// Extension data extractor.
///
/// Newtype wrapper around a value of type `T`.
///
/// NOTE(review): no extraction impl for this type appears in this part of the file.
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);
323
/// Multipart form data extractor.
///
/// Declared here as a unit struct with no fields.
///
/// NOTE(review): no extraction impl for this type appears in this part of the file.
#[derive(Debug, Clone)]
pub struct Multipart;
327
328/// 原始 URI 提取器
329pub type OriginalUri = http::Uri;
330
/// State extractor.
///
/// Wraps a value taken from the application state.
///
/// # Examples
///
/// ```ignore
/// use wae_https::extract::State;
///
/// async fn handler(State(db_pool): State<DatabasePool>) -> String {
///     // use db_pool
/// }
/// ```
#[derive(Debug, Clone)]
pub struct State<T>(pub T);
346
347impl<T, S> FromRequestParts<S> for State<T>
348where
349    T: Clone + Send + Sync + 'static,
350    S: std::ops::Deref<Target = T>,
351{
352    type Error = ExtractorError;
353
354    fn from_request_parts(_parts: &RequestParts, state: &S) -> Result<Self, Self::Error> {
355        Ok(State(state.deref().clone()))
356    }
357}
358
/// WebSocket upgrade extractor.
///
/// Declared here as a unit struct with no fields.
///
/// NOTE(review): no extraction impl for this type appears in this part of the file.
#[derive(Debug, Clone)]
pub struct WebSocketUpgrade;
362
363/// 流式请求体提取器
364///
365/// 用于从请求中提取流式数据。
366pub type Stream = crate::Body;
367
/// Tuple extractor support (used for up to 16 elements).
///
/// Given a comma-separated list of type parameter names, implements
/// `FromRequestParts<S>` for the tuple of those types. Every element must itself
/// implement `FromRequestParts<S>` with `Error = ExtractorError`. Elements are
/// extracted left to right, and the first failure is returned immediately.
macro_rules! impl_from_request_parts_tuple {
    ($($ty:ident),*) => {
        // `unused_variables`: the zero-element expansion reads neither `parts` nor
        // `state`. `non_snake_case` is carried over from the original attribute list.
        #[allow(non_snake_case, unused_variables)]
        impl<S, $($ty,)*> FromRequestParts<S> for ($($ty,)*)
        where
            $($ty: FromRequestParts<S, Error = ExtractorError>,)*
        {
            type Error = ExtractorError;

            fn from_request_parts(parts: &RequestParts, state: &S) -> Result<Self, Self::Error> {
                Ok(($(
                    $ty::from_request_parts(parts, state)?,
                )*))
            }
        }
    };
}
386
387impl_from_request_parts_tuple!();
388impl_from_request_parts_tuple!(T1);
389impl_from_request_parts_tuple!(T1, T2);
390impl_from_request_parts_tuple!(T1, T2, T3);
391impl_from_request_parts_tuple!(T1, T2, T3, T4);
392impl_from_request_parts_tuple!(T1, T2, T3, T4, T5);
393impl_from_request_parts_tuple!(T1, T2, T3, T4, T5, T6);
394impl_from_request_parts_tuple!(T1, T2, T3, T4, T5, T6, T7);
395impl_from_request_parts_tuple!(T1, T2, T3, T4, T5, T6, T7, T8);
396impl_from_request_parts_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
397impl_from_request_parts_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
398impl_from_request_parts_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
399impl_from_request_parts_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
400impl_from_request_parts_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
401impl_from_request_parts_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
402impl_from_request_parts_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
403impl_from_request_parts_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);