Skip to main content

tork_core/extract/
mod.rs

//! The request context and the [`FromRequest`] dependency-injection trait.
//!
//! Every handler parameter that is not a path parameter is resolved through
//! [`FromRequest`]. Built-in extractors (such as [`State`](crate::State),
//! [`BearerToken`], and [`Json`](crate::Json)) implement it directly, and the
//! `#[tork::dependency]` macro generates an implementation for user-defined
//! dependencies. There is no blanket implementation over arbitrary types (the
//! generic `Arc<T>` impl in this module covers only `Arc`), which keeps the trait
//! free of coherence conflicts.
9
use std::net::SocketAddr;
use std::sync::{Mutex, PoisonError};

use http::{Extensions, HeaderMap, Method, Uri};
use hyper::upgrade::OnUpgrade;

use crate::body::ReqBody;
use crate::error::{Error, Result};
use crate::state::{AppStateRef, StateMap};
use crate::ws::Upgrade;

pub mod body;
pub mod header;
pub mod path;
pub mod valid;

pub use header::{BearerToken, LastEventId, SseResume};
pub use path::{__extract_path_param, FromPathParam};
pub use valid::Valid;
29
/// Path parameters captured by the router, kept in the order they matched.
#[derive(Debug, Default, Clone)]
pub struct PathParams {
    /// `(name, raw value)` pairs in capture order.
    entries: Vec<(String, String)>,
}

impl PathParams {
    /// Returns an empty parameter set (the same as [`PathParams::default`]).
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
        }
    }

    /// Records a captured parameter `name` together with its unparsed `value`.
    pub fn push(&mut self, name: String, value: String) {
        let entry = (name, value);
        self.entries.push(entry);
    }

    /// Looks up the raw value captured for `name`.
    ///
    /// If `name` was captured more than once, the earliest capture wins.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.entries
            .iter()
            .find_map(|(key, value)| (key == name).then_some(value.as_str()))
    }

    /// Reports whether the router captured no parameters at all.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Counts the captured parameters, duplicates included.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
}
65
66/// Everything an extractor needs to resolve a value from the current request.
67///
68/// Holds the request head (method, URI, headers, extensions), the captured path
69/// parameters, a handle to the application state, and the request body. The body
70/// can be taken at most once; see [`RequestContext::take_body`].
71pub struct RequestContext {
72    head: http::request::Parts,
73    path_params: PathParams,
74    state: AppStateRef,
75    body: Mutex<Option<ReqBody>>,
76    upgrade: Mutex<Option<Upgrade>>,
77}
78
/// The remote TCP peer address, propagated from the accept loop when present.
///
/// Stored in the request extensions and read back by
/// `peer_addr_from_extensions`. `Debug`/`PartialEq` are derived so the value
/// can be logged and compared in assertions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct RequestPeerAddr(pub(crate) SocketAddr);
82
/// The effective request scheme after trusted proxy normalization.
///
/// `Debug` is derived alongside `PartialEq`/`Eq` so schemes can be logged and
/// used with `assert_eq!`/`assert_ne!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum RequestScheme {
    /// The `http` scheme.
    Http,
    /// The `https` scheme.
    Https,
}

impl RequestScheme {
    /// Returns the lowercase scheme name: `"http"` or `"https"`.
    pub(crate) fn as_str(self) -> &'static str {
        match self {
            RequestScheme::Http => "http",
            RequestScheme::Https => "https",
        }
    }
}
98
99pub(crate) fn peer_addr_from_extensions(extensions: &Extensions) -> Option<SocketAddr> {
100    extensions.get::<RequestPeerAddr>().map(|peer| peer.0)
101}
102
103pub(crate) fn scheme_from_extensions(extensions: &Extensions) -> Option<RequestScheme> {
104    extensions.get::<RequestScheme>().copied()
105}
106
107impl RequestContext {
108    /// Builds a new request context.
109    ///
110    /// A pending WebSocket upgrade (hyper's `OnUpgrade`, present on an upgrade
111    /// request) is taken out of the head's extensions so it can be claimed once
112    /// by [`take_upgrade`](RequestContext::take_upgrade).
113    pub fn new(
114        mut head: http::request::Parts,
115        path_params: PathParams,
116        state: AppStateRef,
117        body: ReqBody,
118    ) -> Self {
119        let upgrade = head.extensions.remove::<OnUpgrade>().map(Upgrade::Hyper);
120        Self {
121            head,
122            path_params,
123            state,
124            body: Mutex::new(Some(body)),
125            upgrade: Mutex::new(upgrade),
126        }
127    }
128
129    /// Builds a request context with an in-process WebSocket upgrade.
130    ///
131    /// Used by the test client to drive a WebSocket handler over an in-memory
132    /// duplex stream instead of a real upgraded connection. (The caller lands in a
133    /// later commit of this phase.)
134    #[allow(dead_code)]
135    pub(crate) fn with_duplex_upgrade(
136        head: http::request::Parts,
137        path_params: PathParams,
138        state: AppStateRef,
139        body: ReqBody,
140        duplex: tokio::io::DuplexStream,
141    ) -> Self {
142        Self {
143            head,
144            path_params,
145            state,
146            body: Mutex::new(Some(body)),
147            upgrade: Mutex::new(Some(Upgrade::Duplex(duplex))),
148        }
149    }
150
151    /// Returns the request method.
152    pub fn method(&self) -> &Method {
153        &self.head.method
154    }
155
156    /// Returns the request URI.
157    pub fn uri(&self) -> &Uri {
158        &self.head.uri
159    }
160
161    /// Returns the request headers.
162    pub fn headers(&self) -> &HeaderMap {
163        &self.head.headers
164    }
165
166    /// Returns the remote TCP peer address, when the request came through the
167    /// real server rather than an in-process test transport.
168    pub fn peer_addr(&self) -> Option<SocketAddr> {
169        peer_addr_from_extensions(&self.head.extensions)
170    }
171
172    /// Returns the effective request scheme after trusted proxy normalization.
173    pub fn scheme(&self) -> Option<&'static str> {
174        scheme_from_extensions(&self.head.extensions).map(RequestScheme::as_str)
175    }
176
177    /// Returns the full request head.
178    pub fn head(&self) -> &http::request::Parts {
179        &self.head
180    }
181
182    /// Returns the application state map.
183    pub fn state(&self) -> &StateMap {
184        self.state.as_ref()
185    }
186
187    /// Clones a registered resource of type `T` out of the registry.
188    ///
189    /// # Errors
190    ///
191    /// Returns an error (code `MISSING_RESOURCE`) if no resource of type `T` was
192    /// registered (for example, by a lifespan).
193    pub fn resource<T: Clone + Send + Sync + 'static>(&self) -> Result<T> {
194        self.state()
195            .get::<T>()
196            .map(|value| (*value).clone())
197            .ok_or_else(|| {
198                Error::internal(format!(
199                    "resource `{}` was not registered",
200                    std::any::type_name::<T>()
201                ))
202                .with_code("MISSING_RESOURCE")
203            })
204    }
205
206    /// Returns the captured path parameters.
207    pub fn path_params(&self) -> &PathParams {
208        &self.path_params
209    }
210
211    /// Returns the raw value of the path parameter named `name`, if captured.
212    pub fn path_param(&self, name: &str) -> Option<&str> {
213        self.path_params.get(name)
214    }
215
216    /// Takes ownership of the request body.
217    ///
218    /// # Errors
219    ///
220    /// Returns a `400 Bad Request` error if the body was already taken, since a
221    /// body can only be consumed by a single extractor.
222    pub fn take_body(&self) -> Result<ReqBody> {
223        self.body
224            .lock()
225            .expect("request body mutex poisoned")
226            .take()
227            .ok_or_else(|| Error::bad_request("request body has already been consumed"))
228    }
229
230    /// Takes the pending WebSocket upgrade.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error (code `NOT_AN_UPGRADE`) if the request is not a WebSocket
235    /// upgrade, or if the upgrade was already taken.
236    pub(crate) fn take_upgrade(&self) -> Result<Upgrade> {
237        self.upgrade
238            .lock()
239            .expect("request upgrade mutex poisoned")
240            .take()
241            .ok_or_else(|| {
242                Error::bad_request("request is not a WebSocket upgrade").with_code("NOT_AN_UPGRADE")
243            })
244    }
245}
246
247/// Produces a value from the current request to satisfy a handler parameter.
248///
249/// Implemented directly by built-in extractors and generated by
250/// `#[tork::dependency]` for user dependencies. Resolution is always statically
251/// dispatched. The returned future is `Send` so the enclosing handler future is
252/// `Send`, as required by the server.
253pub trait FromRequest: Sized + Send {
254    /// Resolves `Self` from the request context.
255    ///
256    /// An `Err` short-circuits request handling and is rendered as an HTTP error
257    /// response.
258    fn from_request(ctx: &RequestContext)
259        -> impl std::future::Future<Output = Result<Self>> + Send;
260}
261
262/// Injects any resource registered as `Arc<T>`.
263///
264/// Registering a shared value as `Arc<T>` (for example a loaded configuration)
265/// lets a handler or service take it by `Arc<T>`, cloning only the pointer per
266/// request. This is the idiomatic way to share immutable state cheaply, since the
267/// orphan rules prevent a downstream crate from implementing `FromRequest` for
268/// `Arc<T>` itself.
269impl<T: Send + Sync + 'static> FromRequest for std::sync::Arc<T> {
270    fn from_request(
271        ctx: &RequestContext,
272    ) -> impl std::future::Future<Output = Result<Self>> + Send {
273        let resolved = ctx.resource::<std::sync::Arc<T>>();
274        async move { resolved }
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::box_body;
    use crate::error::ErrorKind;
    use bytes::Bytes;
    use http_body_util::Full;
    use std::sync::Arc;

    /// Builds a context from a blank request head, the given state map and path
    /// parameters, and a static in-memory body.
    fn context_with_state(
        state: StateMap,
        path_params: PathParams,
        body: &'static str,
    ) -> RequestContext {
        let (head, ()) = http::Request::new(()).into_parts();
        let body = box_body(Full::new(Bytes::from_static(body.as_bytes())));
        RequestContext::new(head, path_params, Arc::new(state), body)
    }

    /// Builds a context backed by an empty state map.
    fn test_context(path_params: PathParams, body: &'static str) -> RequestContext {
        context_with_state(StateMap::new(), path_params, body)
    }

    /// Path parameters holding a single `user_id` capture with value `raw`.
    fn user_id_param(raw: &str) -> PathParams {
        let mut params = PathParams::new();
        params.push("user_id".to_owned(), raw.to_owned());
        params
    }

    #[test]
    fn path_param_lookup_and_parse() {
        let ctx = test_context(user_id_param("42"), "");
        let parsed = __extract_path_param::<i64>(&ctx, "user_id").unwrap();
        assert_eq!(parsed, 42);
    }

    #[test]
    fn invalid_path_param_is_unprocessable() {
        let ctx = test_context(user_id_param("not-a-number"), "");
        let error = __extract_path_param::<i64>(&ctx, "user_id").unwrap_err();
        assert_eq!(error.kind(), ErrorKind::Unprocessable);
    }

    #[test]
    fn take_upgrade_errors_without_an_upgrade() {
        let ctx = test_context(PathParams::new(), "");
        match ctx.take_upgrade() {
            Ok(_) => panic!("should error without an upgrade"),
            Err(error) => assert_eq!(error.code(), "NOT_AN_UPGRADE"),
        }
    }

    #[test]
    fn body_can_only_be_taken_once() {
        let ctx = test_context(PathParams::new(), "hello");
        assert!(ctx.take_body().is_ok());
        let second = ctx.take_body();
        assert_eq!(second.unwrap_err().kind(), ErrorKind::BadRequest);
    }

    #[test]
    fn resource_is_cloned_from_registry() {
        let mut state = StateMap::new();
        state.insert(42_i64);
        let ctx = context_with_state(state, PathParams::new(), "");

        assert_eq!(ctx.resource::<i64>().unwrap(), 42);
        assert!(ctx.resource::<String>().is_err());
    }
}