1use std::net::SocketAddr;
11use std::sync::Mutex;
12
13use http::{Extensions, HeaderMap, Method, Uri};
14use hyper::upgrade::OnUpgrade;
15
16use crate::body::ReqBody;
17use crate::error::{Error, Result};
18use crate::state::{AppStateRef, StateMap};
19use crate::ws::Upgrade;
20
21pub mod body;
22pub mod header;
23pub mod path;
24pub mod valid;
25
26pub use header::{BearerToken, LastEventId, SseResume};
27pub use path::{__extract_path_param, FromPathParam};
28pub use valid::Valid;
29
/// Ordered list of named path parameters, stored as `(name, value)` pairs.
///
/// Entries keep their insertion order; lookups scan linearly and return the first entry
/// pushed under a given name.
#[derive(Clone, Debug, Default)]
pub struct PathParams {
    entries: Vec<(String, String)>,
}

impl PathParams {
    /// Creates an empty parameter list.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
        }
    }

    /// Appends a `name = value` entry. Existing entries with the same name are kept and
    /// still win in [`PathParams::get`].
    pub fn push(&mut self, name: String, value: String) {
        let entry = (name, value);
        self.entries.push(entry);
    }

    /// Returns the value of the first entry named `name`, or `None` if there is none.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.entries
            .iter()
            .find_map(|(key, value)| (key == name).then_some(value.as_str()))
    }

    /// Returns `true` when no parameters have been pushed.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Returns the number of stored entries, counting duplicates.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
}
65
/// Per-request data made available to extractors (see `FromRequest`).
///
/// Holds the request head, the path parameters, a reference to the application state, and
/// the body and upgrade handles. The body and upgrade are wrapped in `Mutex<Option<_>>` so
/// they can be moved out exactly once through a shared `&self` (`take_body` / `take_upgrade`).
pub struct RequestContext {
    /// Method, URI, headers, and extensions of the request.
    head: http::request::Parts,
    /// Named path parameters for this request.
    path_params: PathParams,
    /// Application state, exposed through `state()` and `resource()`.
    state: AppStateRef,
    /// Request body; becomes `None` once taken.
    body: Mutex<Option<ReqBody>>,
    /// Pending connection upgrade; `None` when there is none or it was already taken.
    upgrade: Mutex<Option<Upgrade>>,
}
78
/// Newtype wrapping the remote peer's socket address, stored in the request extensions.
#[derive(Copy, Clone)]
pub(crate) struct RequestPeerAddr(pub(crate) SocketAddr);
82
/// Scheme the request was received over, stored in the request extensions.
#[derive(Clone, Copy, Eq, PartialEq)]
pub(crate) enum RequestScheme {
    /// Plain `http`.
    Http,
    /// `https`.
    Https,
}

impl RequestScheme {
    /// Returns the lowercase scheme name: `"http"` or `"https"`.
    pub(crate) fn as_str(self) -> &'static str {
        match self {
            Self::Https => "https",
            Self::Http => "http",
        }
    }
}
98
99pub(crate) fn peer_addr_from_extensions(extensions: &Extensions) -> Option<SocketAddr> {
100 extensions.get::<RequestPeerAddr>().map(|peer| peer.0)
101}
102
103pub(crate) fn scheme_from_extensions(extensions: &Extensions) -> Option<RequestScheme> {
104 extensions.get::<RequestScheme>().copied()
105}
106
107impl RequestContext {
108 pub fn new(
114 mut head: http::request::Parts,
115 path_params: PathParams,
116 state: AppStateRef,
117 body: ReqBody,
118 ) -> Self {
119 let upgrade = head.extensions.remove::<OnUpgrade>().map(Upgrade::Hyper);
120 Self {
121 head,
122 path_params,
123 state,
124 body: Mutex::new(Some(body)),
125 upgrade: Mutex::new(upgrade),
126 }
127 }
128
129 #[allow(dead_code)]
135 pub(crate) fn with_duplex_upgrade(
136 head: http::request::Parts,
137 path_params: PathParams,
138 state: AppStateRef,
139 body: ReqBody,
140 duplex: tokio::io::DuplexStream,
141 ) -> Self {
142 Self {
143 head,
144 path_params,
145 state,
146 body: Mutex::new(Some(body)),
147 upgrade: Mutex::new(Some(Upgrade::Duplex(duplex))),
148 }
149 }
150
151 pub fn method(&self) -> &Method {
153 &self.head.method
154 }
155
156 pub fn uri(&self) -> &Uri {
158 &self.head.uri
159 }
160
161 pub fn headers(&self) -> &HeaderMap {
163 &self.head.headers
164 }
165
166 pub fn peer_addr(&self) -> Option<SocketAddr> {
169 peer_addr_from_extensions(&self.head.extensions)
170 }
171
172 pub fn scheme(&self) -> Option<&'static str> {
174 scheme_from_extensions(&self.head.extensions).map(RequestScheme::as_str)
175 }
176
177 pub fn head(&self) -> &http::request::Parts {
179 &self.head
180 }
181
182 pub fn state(&self) -> &StateMap {
184 self.state.as_ref()
185 }
186
187 pub fn resource<T: Clone + Send + Sync + 'static>(&self) -> Result<T> {
194 self.state()
195 .get::<T>()
196 .map(|value| (*value).clone())
197 .ok_or_else(|| {
198 Error::internal(format!(
199 "resource `{}` was not registered",
200 std::any::type_name::<T>()
201 ))
202 .with_code("MISSING_RESOURCE")
203 })
204 }
205
206 pub fn path_params(&self) -> &PathParams {
208 &self.path_params
209 }
210
211 pub fn path_param(&self, name: &str) -> Option<&str> {
213 self.path_params.get(name)
214 }
215
216 pub fn take_body(&self) -> Result<ReqBody> {
223 self.body
224 .lock()
225 .expect("request body mutex poisoned")
226 .take()
227 .ok_or_else(|| Error::bad_request("request body has already been consumed"))
228 }
229
230 pub(crate) fn take_upgrade(&self) -> Result<Upgrade> {
237 self.upgrade
238 .lock()
239 .expect("request upgrade mutex poisoned")
240 .take()
241 .ok_or_else(|| {
242 Error::bad_request("request is not a WebSocket upgrade").with_code("NOT_AN_UPGRADE")
243 })
244 }
245}
246
/// Types that can be built asynchronously from a [`RequestContext`].
///
/// Implementors only receive a shared reference to the context. One-shot data such as the
/// body is therefore obtained through methods like `RequestContext::take_body`, which
/// succeed only once per request.
pub trait FromRequest: Sized + Send {
    /// Builds `Self` from the request context, or fails with the crate's `Error`.
    ///
    /// The returned future must be `Send`.
    fn from_request(ctx: &RequestContext)
        -> impl std::future::Future<Output = Result<Self>> + Send;
}
261
262impl<T: Send + Sync + 'static> FromRequest for std::sync::Arc<T> {
270 fn from_request(
271 ctx: &RequestContext,
272 ) -> impl std::future::Future<Output = Result<Self>> + Send {
273 let resolved = ctx.resource::<std::sync::Arc<T>>();
274 async move { resolved }
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::box_body;
    use crate::error::ErrorKind;
    use bytes::Bytes;
    use http_body_util::Full;
    use std::sync::Arc;

    /// Builds a context over a default request head with the given params, body and state.
    fn context_with_state(
        path_params: PathParams,
        body: &'static str,
        state: StateMap,
    ) -> RequestContext {
        let (head, ()) = http::Request::new(()).into_parts();
        let body = box_body(Full::new(Bytes::from_static(body.as_bytes())));
        RequestContext::new(head, path_params, Arc::new(state), body)
    }

    /// Builds a context with an empty state registry.
    fn test_context(path_params: PathParams, body: &'static str) -> RequestContext {
        context_with_state(path_params, body, StateMap::new())
    }

    /// Builds a parameter list holding a single `name = value` entry.
    fn single_param(name: &str, value: &str) -> PathParams {
        let mut params = PathParams::new();
        params.push(name.to_owned(), value.to_owned());
        params
    }

    #[test]
    fn path_param_lookup_and_parse() {
        let ctx = test_context(single_param("user_id", "42"), "");

        let parsed: i64 = __extract_path_param(&ctx, "user_id").unwrap();
        assert_eq!(parsed, 42);
    }

    #[test]
    fn invalid_path_param_is_unprocessable() {
        let ctx = test_context(single_param("user_id", "not-a-number"), "");

        let error = __extract_path_param::<i64>(&ctx, "user_id").unwrap_err();
        assert_eq!(error.kind(), ErrorKind::Unprocessable);
    }

    #[test]
    fn take_upgrade_errors_without_an_upgrade() {
        let ctx = test_context(PathParams::new(), "");

        let Err(error) = ctx.take_upgrade() else {
            panic!("should error without an upgrade");
        };
        assert_eq!(error.code(), "NOT_AN_UPGRADE");
    }

    #[test]
    fn body_can_only_be_taken_once() {
        let ctx = test_context(PathParams::new(), "hello");

        assert!(ctx.take_body().is_ok());
        let error = ctx.take_body().unwrap_err();
        assert_eq!(error.kind(), ErrorKind::BadRequest);
    }

    #[test]
    fn resource_is_cloned_from_registry() {
        let mut registry = StateMap::new();
        registry.insert(42_i64);
        let ctx = context_with_state(PathParams::new(), "", registry);

        assert_eq!(ctx.resource::<i64>().unwrap(), 42);
        assert!(ctx.resource::<String>().is_err());
    }
}