Skip to main content

typeway_server/
extract.rs

1//! Request extraction traits and built-in extractors.
2//!
3//! Extractors pull typed data from incoming HTTP requests. Each handler
4//! argument is an extractor that implements [`FromRequestParts`] (for
5//! metadata like path captures, headers, query strings) or [`FromRequest`]
6//! (for the request body).
7
8use bytes::Bytes;
9use http::request::Parts;
10use http::StatusCode;
11use serde::de::DeserializeOwned;
12
13use typeway_core::{ExtractPath, PathSpec};
14
15use crate::response::IntoResponse;
16
17/// Extract a value from request metadata (URI, headers, extensions).
18///
19/// Implementors can be used as handler arguments. Multiple `FromRequestParts`
20/// extractors can appear in a single handler since they don't consume the body.
21#[diagnostic::on_unimplemented(
22    message = "`{Self}` cannot be extracted from request metadata",
23    label = "does not implement `FromRequestParts`",
24    note = "valid extractors: `Path<P>`, `State<T>`, `Query<T>`, `HeaderMap`"
25)]
26pub trait FromRequestParts: Sized + Send {
27    /// The error type returned when extraction fails.
28    type Error: IntoResponse;
29
30    /// Extract this type from the request parts.
31    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error>;
32}
33
34/// Extract a value by consuming the request body.
35///
36/// At most one `FromRequest` extractor can appear per handler (as the last
37/// argument), since it consumes the body.
38#[diagnostic::on_unimplemented(
39    message = "`{Self}` cannot be extracted from the request body",
40    label = "does not implement `FromRequest`",
41    note = "valid body extractors: `Json<T>`, `Bytes`, `String`, `()`"
42)]
43pub trait FromRequest: Sized + Send {
44    /// The error type returned when extraction fails.
45    type Error: IntoResponse;
46
47    /// Extract this type from the request parts and pre-collected body bytes.
48    ///
49    /// This is async for interface consistency, though body bytes are already
50    /// collected by the router before dispatch.
51    fn from_request(
52        parts: &Parts,
53        body: bytes::Bytes,
54    ) -> impl std::future::Future<Output = Result<Self, Self::Error>> + Send;
55}
56
57// ---------------------------------------------------------------------------
58// Path extractor
59// ---------------------------------------------------------------------------
60
61/// Extracts typed path captures from the URL.
62///
63/// The extractor reads `parts.uri.path()` directly and splits it on demand,
64/// so there's no per-request `Vec<String>` allocation. When the router has
65/// a configured prefix, the byte length of that prefix is stored in
66/// extensions as [`PathPrefixOffset`] so this extractor can skip past it.
67///
68/// # Example
69///
70/// ```ignore
71/// async fn get_user(Path((id,)): Path<path!("users" / u32)>) -> Json<User> {
72///     // id: u32, extracted from /users/42
73/// }
74/// ```
75pub struct Path<P: PathSpec>(pub P::Captures);
76
/// Number of leading bytes of `parts.uri.path()` that belong to the router's
/// mount prefix and must be skipped before captures are matched.
///
/// The router inserts this into the request extensions only when a prefix is
/// configured. When it is absent, [`Path`] treats the offset as `0`, which
/// keeps the common no-prefix case from touching the extensions map.
#[derive(Clone, Copy)]
pub struct PathPrefixOffset(pub usize);
84
85impl<P> FromRequestParts for Path<P>
86where
87    P: PathSpec + ExtractPath + Send,
88    P::Captures: Send,
89{
90    type Error = (StatusCode, String);
91
92    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
93        let full_path = parts.uri.path();
94        let offset = parts
95            .extensions
96            .get::<PathPrefixOffset>()
97            .map_or(0, |o| o.0);
98        let path = if offset <= full_path.len() {
99            &full_path[offset..]
100        } else {
101            ""
102        };
103        let segs: smallvec::SmallVec<[&str; 8]> =
104            path.split('/').filter(|s| !s.is_empty()).collect();
105        P::extract(&segs).map(Path).ok_or_else(|| {
106            (
107                StatusCode::BAD_REQUEST,
108                format!(
109                    "failed to parse path segments for pattern: {}",
110                    P::pattern()
111                ),
112            )
113        })
114    }
115}
116
117// ---------------------------------------------------------------------------
118// State extractor
119// ---------------------------------------------------------------------------
120
121/// Extracts shared application state.
122///
123/// State must be added to the server via [`Server::with_state`](crate::server::Server::with_state)
124/// and is injected into request extensions.
125///
126/// # Example
127///
128/// ```
129/// use typeway_server::State;
130///
131/// #[derive(Clone)]
132/// struct DbPool;
133///
134/// async fn list_users(State(db): State<DbPool>) -> &'static str {
135///     let _ = db;
136///     "users"
137/// }
138/// ```
139pub struct State<T>(pub T);
140
141impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
142    type Error = (StatusCode, String);
143
144    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
145        parts
146            .extensions
147            .get::<T>()
148            .cloned()
149            .map(State)
150            .ok_or_else(|| {
151                (
152                    StatusCode::INTERNAL_SERVER_ERROR,
153                    format!(
154                        "state of type `{}` not found — did you call .with_state()?",
155                        std::any::type_name::<T>()
156                    ),
157                )
158            })
159    }
160}
161
162// ---------------------------------------------------------------------------
163// Query extractor
164// ---------------------------------------------------------------------------
165
166/// Extracts typed query string parameters.
167///
168/// # Example
169///
170/// ```
171/// use typeway_server::Query;
172///
173/// #[derive(serde::Deserialize)]
174/// struct Pagination { page: u32, per_page: u32 }
175///
176/// async fn list_users(Query(p): Query<Pagination>) -> String {
177///     format!("page={}, per_page={}", p.page, p.per_page)
178/// }
179/// ```
180pub struct Query<T>(pub T);
181
182impl<T: DeserializeOwned + Send> FromRequestParts for Query<T> {
183    type Error = (StatusCode, String);
184
185    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
186        let query = parts.uri.query().unwrap_or("");
187        serde_urlencoded::from_str::<T>(query)
188            .map(Query)
189            .map_err(|e| {
190                (
191                    StatusCode::BAD_REQUEST,
192                    format!("failed to parse query string: {e}"),
193                )
194            })
195    }
196}
197
198// ---------------------------------------------------------------------------
199// HeaderMap extractor
200// ---------------------------------------------------------------------------
201
202impl FromRequestParts for http::HeaderMap {
203    type Error = (StatusCode, String);
204
205    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
206        Ok(parts.headers.clone())
207    }
208}
209
210// ---------------------------------------------------------------------------
211// Extension extractor
212// ---------------------------------------------------------------------------
213
214/// Extracts a value from request extensions.
215///
216/// Use this to access arbitrary types injected by middleware or other
217/// infrastructure. Unlike [`State`], extensions are per-request.
218///
219/// # Example
220///
221/// ```
222/// use typeway_server::Extension;
223///
224/// #[derive(Clone)]
225/// struct RequestId(String);
226///
227/// async fn handler(Extension(id): Extension<RequestId>) -> String {
228///     format!("Request: {}", id.0)
229/// }
230/// ```
231pub struct Extension<T>(pub T);
232
233impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
234    type Error = (StatusCode, String);
235
236    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
237        parts
238            .extensions
239            .get::<T>()
240            .cloned()
241            .map(Extension)
242            .ok_or_else(|| {
243                (
244                    StatusCode::INTERNAL_SERVER_ERROR,
245                    format!(
246                        "extension of type `{}` not found in request",
247                        std::any::type_name::<T>()
248                    ),
249                )
250            })
251    }
252}
253
254// ---------------------------------------------------------------------------
255// Cookie extractor
256// ---------------------------------------------------------------------------
257
258/// Trait for types that extract a specific named cookie.
259///
260/// # Example
261///
262/// ```
263/// use typeway_server::extract::{Cookie, NamedCookie};
264///
265/// struct SessionId(String);
266///
267/// impl NamedCookie for SessionId {
268///     const COOKIE_NAME: &'static str = "session_id";
269///     fn from_value(value: &str) -> Result<Self, String> {
270///         Ok(SessionId(value.to_string()))
271///     }
272/// }
273///
274/// async fn handler(Cookie(session): Cookie<SessionId>) -> String {
275///     format!("session: {}", session.0)
276/// }
277/// ```
278pub trait NamedCookie: Sized + Send {
279    /// The cookie name to extract.
280    const COOKIE_NAME: &'static str;
281    /// Parse the cookie value string into this type.
282    fn from_value(value: &str) -> Result<Self, String>;
283}
284
285/// Extracts a single cookie by name.
286pub struct Cookie<T>(pub T);
287
288impl<T: NamedCookie + 'static> FromRequestParts for Cookie<T> {
289    type Error = (StatusCode, String);
290
291    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
292        let cookies = parts
293            .headers
294            .get(http::header::COOKIE)
295            .and_then(|v| v.to_str().ok())
296            .unwrap_or("");
297
298        for pair in cookies.split(';') {
299            let pair = pair.trim();
300            if let Some(value) = pair
301                .strip_prefix(T::COOKIE_NAME)
302                .and_then(|s| s.strip_prefix('='))
303            {
304                return T::from_value(value)
305                    .map(Cookie)
306                    .map_err(|e| (StatusCode::BAD_REQUEST, e));
307            }
308        }
309
310        Err((
311            StatusCode::BAD_REQUEST,
312            format!("missing cookie: {}", T::COOKIE_NAME),
313        ))
314    }
315}
316
317/// Extracts all cookies as a key-value map.
318///
319/// ```
320/// use typeway_server::extract::CookieJar;
321///
322/// async fn handler(cookies: CookieJar) -> String {
323///     let session = cookies.get("session_id").unwrap_or("none");
324///     format!("session: {session}")
325/// }
326/// ```
327pub struct CookieJar(pub std::collections::HashMap<String, String>);
328
329impl CookieJar {
330    /// Get a cookie value by name.
331    pub fn get(&self, name: &str) -> Option<&str> {
332        self.0.get(name).map(|s| s.as_str())
333    }
334}
335
336impl FromRequestParts for CookieJar {
337    type Error = (StatusCode, String);
338
339    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
340        let cookies = parts
341            .headers
342            .get(http::header::COOKIE)
343            .and_then(|v| v.to_str().ok())
344            .unwrap_or("");
345
346        let map = cookies
347            .split(';')
348            .filter_map(|pair| {
349                let pair = pair.trim();
350                let (name, value) = pair.split_once('=')?;
351                Some((name.to_string(), value.to_string()))
352            })
353            .collect();
354
355        Ok(CookieJar(map))
356    }
357}
358
359// ---------------------------------------------------------------------------
360// Method extractor
361// ---------------------------------------------------------------------------
362
363/// Extracts the HTTP method from the request.
364impl FromRequestParts for http::Method {
365    type Error = (StatusCode, String);
366
367    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
368        Ok(parts.method.clone())
369    }
370}
371
372/// Extracts the request URI.
373impl FromRequestParts for http::Uri {
374    type Error = (StatusCode, String);
375
376    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
377        Ok(parts.uri.clone())
378    }
379}
380
381// ---------------------------------------------------------------------------
382// Header extractor
383// ---------------------------------------------------------------------------
384
385/// Extracts a single header value by name.
386///
387/// The header name is derived from `T::HEADER_NAME`. Implement [`NamedHeader`]
388/// on your type to use this extractor.
389///
390/// # Example
391///
392/// ```ignore
393/// struct ContentType(String);
394///
395/// impl NamedHeader for ContentType {
396///     const HEADER_NAME: &'static str = "content-type";
397///     fn from_value(value: &str) -> Result<Self, String> {
398///         Ok(ContentType(value.to_string()))
399///     }
400/// }
401///
402/// async fn handler(Header(ct): Header<ContentType>) -> String {
403///     format!("Content-Type: {}", ct.0)
404/// }
405/// ```
406pub struct Header<T>(pub T);
407
408/// Trait for types that can be extracted from a named HTTP header.
409pub trait NamedHeader: Sized + Send {
410    /// The header name (lowercase), e.g. `"content-type"`.
411    const HEADER_NAME: &'static str;
412
413    /// Parse the header value string into this type.
414    fn from_value(value: &str) -> Result<Self, String>;
415}
416
417impl<T: NamedHeader + 'static> FromRequestParts for Header<T> {
418    type Error = (StatusCode, String);
419
420    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
421        let value = parts
422            .headers
423            .get(T::HEADER_NAME)
424            .ok_or_else(|| {
425                (
426                    StatusCode::BAD_REQUEST,
427                    format!("missing required header: {}", T::HEADER_NAME),
428                )
429            })?
430            .to_str()
431            .map_err(|_| {
432                (
433                    StatusCode::BAD_REQUEST,
434                    format!("invalid header value for: {}", T::HEADER_NAME),
435                )
436            })?;
437
438        T::from_value(value)
439            .map(Header)
440            .map_err(|e| (StatusCode::BAD_REQUEST, e))
441    }
442}
443
444// ---------------------------------------------------------------------------
445// Body extractors (FromRequest)
446// ---------------------------------------------------------------------------
447
448/// JSON request body extractor.
449///
450/// Parses the request body as JSON. Requires `Content-Type: application/json`.
451///
452/// # Example
453///
454/// ```ignore
455/// async fn create_user(Json(body): Json<CreateUser>) -> Json<User> {
456///     // body: CreateUser, deserialized from JSON
457/// }
458/// ```
459impl<T: DeserializeOwned + Send> FromRequest for crate::response::Json<T> {
460    type Error = (StatusCode, String);
461
462    async fn from_request(_parts: &Parts, body: bytes::Bytes) -> Result<Self, Self::Error> {
463        serde_json::from_slice(&body)
464            .map(crate::response::Json)
465            .map_err(|e| (StatusCode::BAD_REQUEST, format!("invalid JSON: {e}")))
466    }
467}
468
469impl FromRequest for Bytes {
470    type Error = (StatusCode, String);
471
472    async fn from_request(_parts: &Parts, body: bytes::Bytes) -> Result<Self, Self::Error> {
473        Ok(body)
474    }
475}
476
477impl FromRequest for String {
478    type Error = (StatusCode, String);
479
480    async fn from_request(_parts: &Parts, body: bytes::Bytes) -> Result<Self, Self::Error> {
481        String::from_utf8(body.to_vec()).map_err(|e| {
482            (
483                StatusCode::BAD_REQUEST,
484                format!("request body is not valid UTF-8: {e}"),
485            )
486        })
487    }
488}
489
490/// Unit extractor — always succeeds, ignoring the body.
491impl FromRequest for () {
492    type Error = (StatusCode, String);
493
494    async fn from_request(_parts: &Parts, _body: bytes::Bytes) -> Result<Self, Self::Error> {
495        Ok(())
496    }
497}