typeway_server/extract.rs
1//! Request extraction traits and built-in extractors.
2//!
3//! Extractors pull typed data from incoming HTTP requests. Each handler
4//! argument is an extractor that implements [`FromRequestParts`] (for
5//! metadata like path captures, headers, query strings) or [`FromRequest`]
6//! (for the request body).
7
8use bytes::Bytes;
9use http::request::Parts;
10use http::StatusCode;
11use serde::de::DeserializeOwned;
12
13use typeway_core::{ExtractPath, PathSpec};
14
15use crate::response::IntoResponse;
16
17/// Extract a value from request metadata (URI, headers, extensions).
18///
19/// Implementors can be used as handler arguments. Multiple `FromRequestParts`
20/// extractors can appear in a single handler since they don't consume the body.
21#[diagnostic::on_unimplemented(
22 message = "`{Self}` cannot be extracted from request metadata",
23 label = "does not implement `FromRequestParts`",
24 note = "valid extractors: `Path<P>`, `State<T>`, `Query<T>`, `HeaderMap`"
25)]
26pub trait FromRequestParts: Sized + Send {
27 /// The error type returned when extraction fails.
28 type Error: IntoResponse;
29
30 /// Extract this type from the request parts.
31 fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error>;
32}
33
34/// Extract a value by consuming the request body.
35///
36/// At most one `FromRequest` extractor can appear per handler (as the last
37/// argument), since it consumes the body.
38#[diagnostic::on_unimplemented(
39 message = "`{Self}` cannot be extracted from the request body",
40 label = "does not implement `FromRequest`",
41 note = "valid body extractors: `Json<T>`, `Bytes`, `String`, `()`"
42)]
43pub trait FromRequest: Sized + Send {
44 /// The error type returned when extraction fails.
45 type Error: IntoResponse;
46
47 /// Extract this type from the request parts and pre-collected body bytes.
48 ///
49 /// This is async for interface consistency, though body bytes are already
50 /// collected by the router before dispatch.
51 fn from_request(
52 parts: &Parts,
53 body: bytes::Bytes,
54 ) -> impl std::future::Future<Output = Result<Self, Self::Error>> + Send;
55}
56
57// ---------------------------------------------------------------------------
58// Path extractor
59// ---------------------------------------------------------------------------
60
61/// Extracts typed path captures from the URL.
62///
63/// The extractor reads `parts.uri.path()` directly and splits it on demand,
64/// so there's no per-request `Vec<String>` allocation. When the router has
65/// a configured prefix, the byte length of that prefix is stored in
66/// extensions as [`PathPrefixOffset`] so this extractor can skip past it.
67///
68/// # Example
69///
70/// ```ignore
71/// async fn get_user(Path((id,)): Path<path!("users" / u32)>) -> Json<User> {
72/// // id: u32, extracted from /users/42
73/// }
74/// ```
75pub struct Path<P: PathSpec>(pub P::Captures);
76
/// Number of leading bytes of `parts.uri.path()` that belong to the router
/// prefix; the path seen by [`Path`] starts at this offset.
///
/// The router stores this in the request extensions only when a prefix is
/// configured. When it is absent the offset is taken to be `0`, which keeps
/// the no-prefix hot path from touching the extensions map.
#[derive(Clone, Copy)]
pub struct PathPrefixOffset(pub usize);
84
85impl<P> FromRequestParts for Path<P>
86where
87 P: PathSpec + ExtractPath + Send,
88 P::Captures: Send,
89{
90 type Error = (StatusCode, String);
91
92 fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
93 let full_path = parts.uri.path();
94 let offset = parts
95 .extensions
96 .get::<PathPrefixOffset>()
97 .map_or(0, |o| o.0);
98 let path = if offset <= full_path.len() {
99 &full_path[offset..]
100 } else {
101 ""
102 };
103 let segs: smallvec::SmallVec<[&str; 8]> =
104 path.split('/').filter(|s| !s.is_empty()).collect();
105 P::extract(&segs).map(Path).ok_or_else(|| {
106 (
107 StatusCode::BAD_REQUEST,
108 format!(
109 "failed to parse path segments for pattern: {}",
110 P::pattern()
111 ),
112 )
113 })
114 }
115}
116
117// ---------------------------------------------------------------------------
118// State extractor
119// ---------------------------------------------------------------------------
120
/// Handler argument that yields the shared application state.
///
/// Register the state with
/// [`Server::with_state`](crate::server::Server::with_state); it is then
/// injected into the request extensions, which is where this extractor
/// looks it up.
///
/// # Example
///
/// ```
/// use typeway_server::State;
///
/// #[derive(Clone)]
/// struct DbPool;
///
/// async fn list_users(State(db): State<DbPool>) -> &'static str {
///     let _ = db;
///     "users"
/// }
/// ```
pub struct State<T>(pub T);
140
141impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
142 type Error = (StatusCode, String);
143
144 fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
145 parts
146 .extensions
147 .get::<T>()
148 .cloned()
149 .map(State)
150 .ok_or_else(|| {
151 (
152 StatusCode::INTERNAL_SERVER_ERROR,
153 format!(
154 "state of type `{}` not found — did you call .with_state()?",
155 std::any::type_name::<T>()
156 ),
157 )
158 })
159 }
160}
161
162// ---------------------------------------------------------------------------
163// Query extractor
164// ---------------------------------------------------------------------------
165
/// Handler argument that yields the query string deserialized into `T`.
///
/// # Example
///
/// ```
/// use typeway_server::Query;
///
/// #[derive(serde::Deserialize)]
/// struct Pagination { page: u32, per_page: u32 }
///
/// async fn list_users(Query(p): Query<Pagination>) -> String {
///     format!("page={}, per_page={}", p.page, p.per_page)
/// }
/// ```
pub struct Query<T>(pub T);
181
182impl<T: DeserializeOwned + Send> FromRequestParts for Query<T> {
183 type Error = (StatusCode, String);
184
185 fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
186 let query = parts.uri.query().unwrap_or("");
187 serde_urlencoded::from_str::<T>(query)
188 .map(Query)
189 .map_err(|e| {
190 (
191 StatusCode::BAD_REQUEST,
192 format!("failed to parse query string: {e}"),
193 )
194 })
195 }
196}
197
198// ---------------------------------------------------------------------------
199// HeaderMap extractor
200// ---------------------------------------------------------------------------
201
202impl FromRequestParts for http::HeaderMap {
203 type Error = (StatusCode, String);
204
205 fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
206 Ok(parts.headers.clone())
207 }
208}
209
210// ---------------------------------------------------------------------------
211// Extension extractor
212// ---------------------------------------------------------------------------
213
/// Handler argument that yields a value stored in the request extensions.
///
/// This gives access to arbitrary types injected by middleware or other
/// infrastructure. In contrast to [`State`], extensions are per-request.
///
/// # Example
///
/// ```
/// use typeway_server::Extension;
///
/// #[derive(Clone)]
/// struct RequestId(String);
///
/// async fn handler(Extension(id): Extension<RequestId>) -> String {
///     format!("Request: {}", id.0)
/// }
/// ```
pub struct Extension<T>(pub T);
232
233impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
234 type Error = (StatusCode, String);
235
236 fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
237 parts
238 .extensions
239 .get::<T>()
240 .cloned()
241 .map(Extension)
242 .ok_or_else(|| {
243 (
244 StatusCode::INTERNAL_SERVER_ERROR,
245 format!(
246 "extension of type `{}` not found in request",
247 std::any::type_name::<T>()
248 ),
249 )
250 })
251 }
252}
253
254// ---------------------------------------------------------------------------
255// Cookie extractor
256// ---------------------------------------------------------------------------
257
/// Describes a type that is parsed from one cookie with a fixed name.
///
/// Implement this on your own type and use it through [`Cookie`].
///
/// # Example
///
/// ```
/// use typeway_server::extract::{Cookie, NamedCookie};
///
/// struct SessionId(String);
///
/// impl NamedCookie for SessionId {
///     const COOKIE_NAME: &'static str = "session_id";
///     fn from_value(value: &str) -> Result<Self, String> {
///         Ok(SessionId(value.to_string()))
///     }
/// }
///
/// async fn handler(Cookie(session): Cookie<SessionId>) -> String {
///     format!("session: {}", session.0)
/// }
/// ```
pub trait NamedCookie: Sized + Send {
    /// Name of the cookie this type is read from.
    const COOKIE_NAME: &'static str;

    /// Converts the raw cookie value into `Self`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the value is not acceptable.
    fn from_value(value: &str) -> Result<Self, String>;
}
284
/// Handler argument that yields one cookie, selected by `T::COOKIE_NAME`
/// and parsed by `T::from_value` (see [`NamedCookie`]).
pub struct Cookie<T>(pub T);
287
288impl<T: NamedCookie + 'static> FromRequestParts for Cookie<T> {
289 type Error = (StatusCode, String);
290
291 fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
292 let cookies = parts
293 .headers
294 .get(http::header::COOKIE)
295 .and_then(|v| v.to_str().ok())
296 .unwrap_or("");
297
298 for pair in cookies.split(';') {
299 let pair = pair.trim();
300 if let Some(value) = pair
301 .strip_prefix(T::COOKIE_NAME)
302 .and_then(|s| s.strip_prefix('='))
303 {
304 return T::from_value(value)
305 .map(Cookie)
306 .map_err(|e| (StatusCode::BAD_REQUEST, e));
307 }
308 }
309
310 Err((
311 StatusCode::BAD_REQUEST,
312 format!("missing cookie: {}", T::COOKIE_NAME),
313 ))
314 }
315}
316
/// Handler argument that yields every request cookie as a name → value map.
///
/// ```
/// use typeway_server::extract::CookieJar;
///
/// async fn handler(cookies: CookieJar) -> String {
///     let session = cookies.get("session_id").unwrap_or("none");
///     format!("session: {session}")
/// }
/// ```
pub struct CookieJar(pub std::collections::HashMap<String, String>);

impl CookieJar {
    /// Looks up the value of the cookie called `name`.
    ///
    /// Returns `None` when the jar holds no cookie with that name.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.0.get(name).map(String::as_str)
    }
}
335
336impl FromRequestParts for CookieJar {
337 type Error = (StatusCode, String);
338
339 fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
340 let cookies = parts
341 .headers
342 .get(http::header::COOKIE)
343 .and_then(|v| v.to_str().ok())
344 .unwrap_or("");
345
346 let map = cookies
347 .split(';')
348 .filter_map(|pair| {
349 let pair = pair.trim();
350 let (name, value) = pair.split_once('=')?;
351 Some((name.to_string(), value.to_string()))
352 })
353 .collect();
354
355 Ok(CookieJar(map))
356 }
357}
358
359// ---------------------------------------------------------------------------
360// Method extractor
361// ---------------------------------------------------------------------------
362
363/// Extracts the HTTP method from the request.
364impl FromRequestParts for http::Method {
365 type Error = (StatusCode, String);
366
367 fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
368 Ok(parts.method.clone())
369 }
370}
371
372/// Extracts the request URI.
373impl FromRequestParts for http::Uri {
374 type Error = (StatusCode, String);
375
376 fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
377 Ok(parts.uri.clone())
378 }
379}
380
381// ---------------------------------------------------------------------------
382// Header extractor
383// ---------------------------------------------------------------------------
384
/// Handler argument that yields one header, parsed into `T`.
///
/// Which header is read is determined by `T::HEADER_NAME`; implement
/// [`NamedHeader`] on your type to make it usable here.
///
/// # Example
///
/// ```ignore
/// struct ContentType(String);
///
/// impl NamedHeader for ContentType {
///     const HEADER_NAME: &'static str = "content-type";
///     fn from_value(value: &str) -> Result<Self, String> {
///         Ok(ContentType(value.to_string()))
///     }
/// }
///
/// async fn handler(Header(ct): Header<ContentType>) -> String {
///     format!("Content-Type: {}", ct.0)
/// }
/// ```
pub struct Header<T>(pub T);
407
/// Describes a type that is parsed from one HTTP header with a fixed name.
pub trait NamedHeader: Sized + Send {
    /// Name of the header, written in lowercase (e.g. `"content-type"`).
    const HEADER_NAME: &'static str;

    /// Converts the header's string value into `Self`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the value is not acceptable.
    fn from_value(value: &str) -> Result<Self, String>;
}
416
417impl<T: NamedHeader + 'static> FromRequestParts for Header<T> {
418 type Error = (StatusCode, String);
419
420 fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
421 let value = parts
422 .headers
423 .get(T::HEADER_NAME)
424 .ok_or_else(|| {
425 (
426 StatusCode::BAD_REQUEST,
427 format!("missing required header: {}", T::HEADER_NAME),
428 )
429 })?
430 .to_str()
431 .map_err(|_| {
432 (
433 StatusCode::BAD_REQUEST,
434 format!("invalid header value for: {}", T::HEADER_NAME),
435 )
436 })?;
437
438 T::from_value(value)
439 .map(Header)
440 .map_err(|e| (StatusCode::BAD_REQUEST, e))
441 }
442}
443
444// ---------------------------------------------------------------------------
445// Body extractors (FromRequest)
446// ---------------------------------------------------------------------------
447
448/// JSON request body extractor.
449///
450/// Parses the request body as JSON. Requires `Content-Type: application/json`.
451///
452/// # Example
453///
454/// ```ignore
455/// async fn create_user(Json(body): Json<CreateUser>) -> Json<User> {
456/// // body: CreateUser, deserialized from JSON
457/// }
458/// ```
459impl<T: DeserializeOwned + Send> FromRequest for crate::response::Json<T> {
460 type Error = (StatusCode, String);
461
462 async fn from_request(_parts: &Parts, body: bytes::Bytes) -> Result<Self, Self::Error> {
463 serde_json::from_slice(&body)
464 .map(crate::response::Json)
465 .map_err(|e| (StatusCode::BAD_REQUEST, format!("invalid JSON: {e}")))
466 }
467}
468
469impl FromRequest for Bytes {
470 type Error = (StatusCode, String);
471
472 async fn from_request(_parts: &Parts, body: bytes::Bytes) -> Result<Self, Self::Error> {
473 Ok(body)
474 }
475}
476
477impl FromRequest for String {
478 type Error = (StatusCode, String);
479
480 async fn from_request(_parts: &Parts, body: bytes::Bytes) -> Result<Self, Self::Error> {
481 String::from_utf8(body.to_vec()).map_err(|e| {
482 (
483 StatusCode::BAD_REQUEST,
484 format!("request body is not valid UTF-8: {e}"),
485 )
486 })
487 }
488}
489
490/// Unit extractor — always succeeds, ignoring the body.
491impl FromRequest for () {
492 type Error = (StatusCode, String);
493
494 async fn from_request(_parts: &Parts, _body: bytes::Bytes) -> Result<Self, Self::Error> {
495 Ok(())
496 }
497}