1use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use http::{Method, StatusCode};
12
13use crate::error::Result;
14use crate::extract::RequestContext;
15use crate::hooks::{ErrorEvent, RequestEvent, ResponseEvent, ValidationErrorEvent};
16use crate::response::Response;
17
18pub mod matcher;
19
/// Pinned, heap-allocated, type-erased future that is `Send`, yields `T`, and may borrow
/// data for the lifetime `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
22
/// Reference-counted, type-erased hook: a `Send + Sync` callable that takes a
/// [`RequestEvent`] and returns a boxed `'static` future resolving to `()`.
pub(crate) type SharedRequestHook =
    Arc<dyn Fn(RequestEvent) -> BoxFuture<'static, ()> + Send + Sync>;
29
/// Reference-counted, type-erased hook: a `Send + Sync` callable that takes a
/// [`ResponseEvent`] and returns a boxed `'static` future resolving to `()`.
pub(crate) type SharedResponseHook =
    Arc<dyn Fn(ResponseEvent) -> BoxFuture<'static, ()> + Send + Sync>;
33
/// Reference-counted, type-erased hook: a `Send + Sync` callable that takes an
/// [`ErrorEvent`] and returns a boxed `'static` future resolving to `()`.
pub(crate) type SharedErrorHook = Arc<dyn Fn(ErrorEvent) -> BoxFuture<'static, ()> + Send + Sync>;
36
/// Reference-counted, type-erased hook: a `Send + Sync` callable that takes a
/// [`ValidationErrorEvent`] and returns a boxed `'static` future resolving to `()`.
pub(crate) type SharedValidationErrorHook =
    Arc<dyn Fn(ValidationErrorEvent) -> BoxFuture<'static, ()> + Send + Sync>;
40
/// Expands to the four consuming builder methods `on_request`, `on_response`, `on_error`
/// and `on_validation_error`.
///
/// The invoking `impl` block's type must own the fields `request_hooks`, `response_hooks`,
/// `error_hooks` and `validation_hooks`, each a `Vec` of the matching `Shared*Hook` alias.
/// Every generated method boxes the caller's hook behind an `Arc` and appends it to the end
/// of the corresponding list.
macro_rules! scoped_hook_builders {
    () => {
        /// Appends `hook` to the request hook list and returns the updated value.
        pub fn on_request<F, Fut>(mut self, hook: F) -> Self
        where
            F: Fn(RequestEvent) -> Fut + Send + Sync + 'static,
            Fut: Future<Output = ()> + Send + 'static,
        {
            // Erase the concrete future type so hooks of different types share one list.
            let shared: SharedRequestHook =
                Arc::new(move |event: RequestEvent| -> BoxFuture<'static, ()> {
                    Box::pin(hook(event))
                });
            self.request_hooks.push(shared);
            self
        }

        /// Appends `hook` to the response hook list and returns the updated value.
        pub fn on_response<F, Fut>(mut self, hook: F) -> Self
        where
            F: Fn(ResponseEvent) -> Fut + Send + Sync + 'static,
            Fut: Future<Output = ()> + Send + 'static,
        {
            let shared: SharedResponseHook =
                Arc::new(move |event: ResponseEvent| -> BoxFuture<'static, ()> {
                    Box::pin(hook(event))
                });
            self.response_hooks.push(shared);
            self
        }

        /// Appends `hook` to the error hook list and returns the updated value.
        pub fn on_error<F, Fut>(mut self, hook: F) -> Self
        where
            F: Fn(ErrorEvent) -> Fut + Send + Sync + 'static,
            Fut: Future<Output = ()> + Send + 'static,
        {
            let shared: SharedErrorHook =
                Arc::new(move |event: ErrorEvent| -> BoxFuture<'static, ()> {
                    Box::pin(hook(event))
                });
            self.error_hooks.push(shared);
            self
        }

        /// Appends `hook` to the validation-error hook list and returns the updated value.
        pub fn on_validation_error<F, Fut>(mut self, hook: F) -> Self
        where
            F: Fn(ValidationErrorEvent) -> Fut + Send + Sync + 'static,
            Fut: Future<Output = ()> + Send + 'static,
        {
            let shared: SharedValidationErrorHook =
                Arc::new(move |event: ValidationErrorEvent| -> BoxFuture<'static, ()> {
                    Box::pin(hook(event))
                });
            self.validation_hooks.push(shared);
            self
        }
    };
}
92
/// Reference-counted, type-erased route handler: a `Send + Sync` callable that takes a
/// [`RequestContext`] by value and returns a boxed `'static` future resolving to
/// `Result<Response>`.
pub type HandlerFn =
    Arc<dyn Fn(RequestContext) -> BoxFuture<'static, Result<Response>> + Send + Sync>;
106
/// Plain function pointer that produces a `schemars::Schema` from a mutable
/// `schemars::SchemaGenerator`; stored in [`RouteMeta`] so the schema is only built when
/// the thunk is called.
pub type SchemaThunk = fn(&mut schemars::SchemaGenerator) -> schemars::Schema;
111
/// Kind of request body recorded in [`RouteMeta::request_kind`].
///
/// NOTE(review): the variant names suggest JSON, URL-encoded form and multipart bodies;
/// how each kind is consumed is not visible in this file — confirm against the callers.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum RequestBodyKind {
    /// The default kind.
    #[default]
    Json,
    Form,
    Multipart,
}
123
/// Descriptive metadata attached to a [`Route`] and filled in through its builder methods.
#[derive(Clone, Debug)]
pub struct RouteMeta {
    /// Optional summary text (set by `Route::summary`).
    pub summary: Option<String>,
    /// Optional description text (set by `Route::description`).
    pub description: Option<String>,
    /// Tags for the route; `Route::tag` and router tag inheritance avoid adding duplicates.
    pub tags: Vec<String>,
    /// Status code recorded for the route; defaults to `200 OK`.
    pub status_code: StatusCode,
    /// Type name captured with `std::any::type_name` by `Route::response_model`.
    pub response_model: Option<&'static str>,
    /// Thunk producing the request schema, if one was registered.
    pub request_schema: Option<SchemaThunk>,
    /// Kind of request body; defaults to `RequestBodyKind::Json`.
    pub request_kind: RequestBodyKind,
    /// Thunk producing the response schema, if one was registered.
    pub response_schema: Option<SchemaThunk>,
    /// Flag set by `Route::streaming`; defaults to `false`.
    pub streaming: bool,
    /// Flag set by `Route::websocket`; defaults to `false`.
    pub websocket: bool,
    /// Thunk producing the schema registered through `Route::ws_incoming`, if any.
    pub ws_incoming: Option<SchemaThunk>,
    /// Thunk producing the schema registered through `Route::ws_outgoing`, if any.
    pub ws_outgoing: Option<SchemaThunk>,
}
154
155impl Default for RouteMeta {
156 fn default() -> Self {
157 Self {
158 summary: None,
159 description: None,
160 tags: Vec::new(),
161 status_code: StatusCode::OK,
162 response_model: None,
163 request_schema: None,
164 request_kind: RequestBodyKind::Json,
165 response_schema: None,
166 streaming: false,
167 websocket: false,
168 ws_incoming: None,
169 ws_outgoing: None,
170 }
171 }
172}
173
/// A single route: an HTTP method and path bound to a handler, plus metadata and the hook
/// lists attached to it. Cloning is shallow for the handler and hooks, which are `Arc`s.
#[derive(Clone)]
pub struct Route {
    /// HTTP method of the route.
    method: Method,
    /// Route path; rewritten by `prepend_prefix` when the route is emitted by a `Router`.
    path: String,
    /// Shared handler for the route.
    handler: HandlerFn,
    /// Descriptive metadata edited through the builder methods.
    meta: RouteMeta,
    /// Request hooks, in stored order (router hooks are placed in front of these).
    request_hooks: Vec<SharedRequestHook>,
    /// Response hooks, in stored order.
    response_hooks: Vec<SharedResponseHook>,
    /// Error hooks, in stored order.
    error_hooks: Vec<SharedErrorHook>,
    /// Validation-error hooks, in stored order.
    validation_hooks: Vec<SharedValidationErrorHook>,
}
189
190impl Route {
191 pub fn new(method: Method, path: impl Into<String>, handler: HandlerFn) -> Self {
193 Self {
194 method,
195 path: path.into(),
196 handler,
197 meta: RouteMeta::default(),
198 request_hooks: Vec::new(),
199 response_hooks: Vec::new(),
200 error_hooks: Vec::new(),
201 validation_hooks: Vec::new(),
202 }
203 }
204
205 scoped_hook_builders!();
206
207 pub fn summary(mut self, summary: impl Into<String>) -> Self {
209 self.meta.summary = Some(summary.into());
210 self
211 }
212
213 pub fn description(mut self, description: impl Into<String>) -> Self {
215 self.meta.description = Some(description.into());
216 self
217 }
218
219 pub fn tag(mut self, tag: impl Into<String>) -> Self {
221 let tag = tag.into();
222 if !self.meta.tags.contains(&tag) {
223 self.meta.tags.push(tag);
224 }
225 self
226 }
227
228 pub fn status_code(mut self, status_code: StatusCode) -> Self {
230 self.meta.status_code = status_code;
231 self
232 }
233
234 pub fn response_model<T: ?Sized>(mut self) -> Self {
236 self.meta.response_model = Some(std::any::type_name::<T>());
237 self
238 }
239
240 pub fn request_schema<T: schemars::JsonSchema>(mut self) -> Self {
242 self.meta.request_schema = Some(|generator| generator.subschema_for::<T>());
243 self
244 }
245
246 pub fn request_schema_fn(mut self, thunk: SchemaThunk) -> Self {
248 self.meta.request_schema = Some(thunk);
249 self
250 }
251
252 pub fn request_kind(mut self, kind: RequestBodyKind) -> Self {
254 self.meta.request_kind = kind;
255 self
256 }
257
258 pub fn response_schema<T: schemars::JsonSchema>(mut self) -> Self {
260 self.meta.response_schema = Some(|generator| generator.subschema_for::<T>());
261 self
262 }
263
264 pub fn streaming(mut self) -> Self {
266 self.meta.streaming = true;
267 self
268 }
269
270 pub fn websocket(mut self) -> Self {
272 self.meta.websocket = true;
273 self
274 }
275
276 pub fn ws_incoming<T: schemars::JsonSchema>(mut self) -> Self {
278 self.meta.ws_incoming = Some(|generator| generator.subschema_for::<T>());
279 self
280 }
281
282 pub fn ws_outgoing<T: schemars::JsonSchema>(mut self) -> Self {
284 self.meta.ws_outgoing = Some(|generator| generator.subschema_for::<T>());
285 self
286 }
287
288 pub fn method(&self) -> &Method {
290 &self.method
291 }
292
293 pub fn path(&self) -> &str {
295 &self.path
296 }
297
298 pub fn meta(&self) -> &RouteMeta {
300 &self.meta
301 }
302
303 pub fn handler(&self) -> &HandlerFn {
305 &self.handler
306 }
307
308 pub(crate) fn request_hooks(&self) -> &[SharedRequestHook] {
310 &self.request_hooks
311 }
312
313 pub(crate) fn response_hooks(&self) -> &[SharedResponseHook] {
315 &self.response_hooks
316 }
317
318 pub(crate) fn error_hooks(&self) -> &[SharedErrorHook] {
320 &self.error_hooks
321 }
322
323 pub(crate) fn validation_hooks(&self) -> &[SharedValidationErrorHook] {
325 &self.validation_hooks
326 }
327
328 pub(crate) fn has_hooks(&self) -> bool {
330 !self.request_hooks.is_empty()
331 || !self.response_hooks.is_empty()
332 || !self.error_hooks.is_empty()
333 || !self.validation_hooks.is_empty()
334 }
335
336 fn prepend_prefix(mut self, prefix: &str) -> Self {
338 self.path = join_paths(prefix, &self.path);
339 self
340 }
341
342 fn inherit_tags(mut self, tags: &[String]) -> Self {
344 for tag in tags {
345 if !self.meta.tags.contains(tag) {
346 self.meta.tags.push(tag.clone());
347 }
348 }
349 self
350 }
351
352 fn prepend_hooks(
358 mut self,
359 request: &[SharedRequestHook],
360 response: &[SharedResponseHook],
361 error: &[SharedErrorHook],
362 validation: &[SharedValidationErrorHook],
363 ) -> Self {
364 self.request_hooks.splice(0..0, request.iter().cloned());
365 self.response_hooks.splice(0..0, response.iter().cloned());
366 self.error_hooks.splice(0..0, error.iter().cloned());
367 self.validation_hooks
368 .splice(0..0, validation.iter().cloned());
369 self
370 }
371}
372
/// Builder that groups routes under a shared path prefix, tag list and hook lists.
/// The shared settings are applied to the routes when `into_routes` is called.
#[derive(Default)]
pub struct Router {
    /// Path prefix joined in front of every route path by `into_routes`.
    prefix: String,
    /// Tags added to every route by `into_routes` (skipping tags a route already has).
    tags: Vec<String>,
    /// Routes collected so far, stored without this router's prefix, tags or hooks applied.
    routes: Vec<Route>,
    /// Request hooks placed in front of each route's own request hooks.
    request_hooks: Vec<SharedRequestHook>,
    /// Response hooks placed in front of each route's own response hooks.
    response_hooks: Vec<SharedResponseHook>,
    /// Error hooks placed in front of each route's own error hooks.
    error_hooks: Vec<SharedErrorHook>,
    /// Validation-error hooks placed in front of each route's own validation-error hooks.
    validation_hooks: Vec<SharedValidationErrorHook>,
}
384
385impl Router {
386 pub fn new() -> Self {
388 Self::default()
389 }
390
391 scoped_hook_builders!();
392
393 pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
395 self.prefix = prefix.into();
396 self
397 }
398
399 pub fn tags(mut self, tags: &[&str]) -> Self {
401 self.tags = tags.iter().map(|tag| (*tag).to_owned()).collect();
402 self
403 }
404
405 pub fn route(mut self, route: Route) -> Self {
407 self.routes.push(route);
408 self
409 }
410
411 pub fn include(mut self, child: Router) -> Self {
413 self.routes.extend(child.into_routes());
417 self
418 }
419
420 pub fn into_routes(self) -> Vec<Route> {
425 let Router {
426 prefix,
427 tags,
428 routes,
429 request_hooks,
430 response_hooks,
431 error_hooks,
432 validation_hooks,
433 } = self;
434 routes
435 .into_iter()
436 .map(|route| {
437 route
438 .prepend_prefix(&prefix)
439 .inherit_tags(&tags)
440 .prepend_hooks(
441 &request_hooks,
442 &response_hooks,
443 &error_hooks,
444 &validation_hooks,
445 )
446 })
447 .collect()
448 }
449
450 pub fn routes(&self) -> &[Route] {
452 &self.routes
453 }
454}
455
/// Joins `prefix` and `path` into a single normalized path.
///
/// Trailing slashes of `prefix` and leading slashes of `path` are dropped, the two parts
/// are joined with exactly one `/`, a leading `/` is added when missing, and trailing
/// slashes are removed from the result. When nothing is left, the result is `"/"`.
fn join_paths(prefix: &str, path: &str) -> String {
    let left = prefix.trim_end_matches('/');
    let right = path.trim_start_matches('/');

    // A separator is only emitted when there is something to put after it.
    let body = if right.is_empty() {
        left.to_owned()
    } else {
        [left, "/", right].concat()
    };

    let rooted = if body.starts_with('/') {
        body
    } else {
        ["/", body.as_str()].concat()
    };

    match rooted.trim_end_matches('/') {
        "" => String::from("/"),
        trimmed => trimmed.to_owned(),
    }
}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::empty;

    /// Handler that ignores its context and resolves to `Ok(empty(StatusCode::OK))`.
    fn dummy_handler() -> HandlerFn {
        fn respond(_ctx: RequestContext) -> BoxFuture<'static, Result<Response>> {
            Box::pin(async { Ok(empty(StatusCode::OK)) })
        }
        Arc::new(respond)
    }

    /// `GET` route on `path` backed by [`dummy_handler`].
    fn get(path: &str) -> Route {
        let handler = dummy_handler();
        Route::new(Method::GET, path, handler)
    }

    #[test]
    fn prefix_is_prepended_to_routes() {
        let router = Router::new()
            .prefix("/users")
            .tags(&["users"])
            .route(get("/{user_id}"));
        let routes = router.into_routes();

        assert_eq!(routes.len(), 1);
        let only = &routes[0];
        assert_eq!(only.path(), "/users/{user_id}");
        assert_eq!(only.meta().tags, vec!["users".to_owned()]);
    }

    #[test]
    fn root_route_drops_trailing_slash() {
        let router = Router::new().prefix("/users").route(get("/"));
        let routes = router.into_routes();

        assert_eq!(routes[0].path(), "/users");
    }

    #[test]
    fn nested_include_composes_prefixes_and_tags() {
        let orders = Router::new()
            .prefix("/{user_id}/orders")
            .tags(&["orders"])
            .route(get("/"));
        let users = Router::new()
            .prefix("/users")
            .tags(&["users"])
            .include(orders);

        let routes = users.into_routes();
        let nested = &routes[0];

        assert_eq!(nested.path(), "/users/{user_id}/orders");
        // The child's tags are applied first, so they come before the parent's.
        let expected_tags = vec!["orders".to_owned(), "users".to_owned()];
        assert_eq!(nested.meta().tags, expected_tags);
    }

    #[test]
    fn route_tag_deduplicates_repeated_tags() {
        let route = get("/x").tag("a").tag("a").tag("b");
        let expected_tags = vec!["a".to_owned(), "b".to_owned()];
        assert_eq!(route.meta().tags, expected_tags);
    }

    #[test]
    fn route_meta_default_has_empty_collections() {
        let RouteMeta {
            summary,
            description,
            tags,
            request_schema,
            response_schema,
            ..
        } = RouteMeta::default();

        assert!(summary.is_none());
        assert!(description.is_none());
        assert!(tags.is_empty());
        assert!(request_schema.is_none());
        assert!(response_schema.is_none());
    }

    #[tokio::test]
    async fn router_hooks_propagate_to_routes_outer_to_inner() {
        use crate::hooks::{RequestEvent, RequestInfo};
        use std::sync::Mutex;

        // Shared record of the order in which the hooks run.
        let log: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(Vec::new()));
        let for_outer = Arc::clone(&log);
        let for_inner = Arc::clone(&log);

        let inner = Router::new().route(get("/x")).on_request(move |_event| {
            let sink = Arc::clone(&for_inner);
            async move { sink.lock().unwrap().push("inner") }
        });
        let outer = Router::new()
            .on_request(move |_event| {
                let sink = Arc::clone(&for_outer);
                async move { sink.lock().unwrap().push("outer") }
            })
            .include(inner);

        let routes = outer.into_routes();
        assert_eq!(routes.len(), 1);
        let hooks = routes[0].request_hooks();
        assert_eq!(hooks.len(), 2, "both router hooks attach to the route");

        let info = RequestInfo::new(Method::GET, "/x".into(), Some("/x".into()), None);
        for hook in hooks.iter() {
            let event = RequestEvent::new(info.clone());
            hook(event).await;
        }
        assert_eq!(*log.lock().unwrap(), ["outer", "inner"]);
    }
}