1use core::fmt;
4
5use viz_core::{
6 BoxHandler, Handler, HandlerExt, IntoResponse, Method, Next, Request, Response, Result,
7 Transform,
8};
9
// Generates an inherent `Route` builder method named `$name` (e.g. `get`) that registers a
// handler for the single HTTP verb `Method::$verb` by delegating to `Route::on`.
//
// `$name` is the method name to generate; `$verb` is the associated constant on `Method`
// (e.g. `GET`) and is also stringified into the generated doc comment.
macro_rules! export_internal_verb {
    ($name:ident $verb:tt) => {
        #[doc = concat!(" Appends a handler for the HTTP `", stringify!($verb), "` verb to the route.")]
        #[must_use]
        pub fn $name<H, O>(self, handler: H) -> Self
        where
            H: Handler<Request, Output = Result<O>> + Clone,
            O: IntoResponse,
        {
            self.on(Method::$verb, handler)
        }
    };
}
23
// Generates a free function named `$name` (e.g. `get`) that creates a fresh `Route` and
// registers a handler for one HTTP verb, delegating to the same-named `Route` builder method
// (produced by `export_internal_verb`).
//
// `$verb` is only stringified into the doc comment. It is matched as `tt` (like
// `export_internal_verb`) rather than `ty`: it names an associated constant, not a type.
macro_rules! export_verb {
    ($name:ident $verb:tt) => {
        #[doc = concat!(" Creates a route with a handler and HTTP `", stringify!($verb), "` verb pair.")]
        #[must_use]
        pub fn $name<H, O>(handler: H) -> Route
        where
            H: Handler<Request, Output = Result<O>> + Clone,
            O: IntoResponse,
        {
            Route::new().$name(handler)
        }
    };
}
37
38#[derive(Clone, Default)]
40pub struct Route {
41 pub(crate) methods: Vec<(Method, BoxHandler)>,
42}
43
44impl Route {
45 #[must_use]
47 pub const fn new() -> Self {
48 Self {
49 methods: Vec::new(),
50 }
51 }
52
53 #[must_use]
55 pub fn push(mut self, method: Method, handler: BoxHandler) -> Self {
56 match self
57 .methods
58 .iter_mut()
59 .find(|(m, _)| m == method)
60 .map(|(_, e)| e)
61 {
62 Some(h) => *h = handler,
63 None => self.methods.push((method, handler)),
64 }
65
66 self
67 }
68
69 #[must_use]
71 pub fn on<H, O>(self, method: Method, handler: H) -> Self
72 where
73 H: Handler<Request, Output = Result<O>> + Clone,
74 O: IntoResponse,
75 {
76 self.push(method, handler.map_into_response().boxed())
77 }
78
79 #[must_use]
81 pub fn any<H, O>(self, handler: H) -> Self
82 where
83 H: Handler<Request, Output = Result<O>> + Clone,
84 O: IntoResponse,
85 {
86 [
87 Method::GET,
88 Method::POST,
89 Method::PUT,
90 Method::DELETE,
91 Method::HEAD,
92 Method::OPTIONS,
93 Method::CONNECT,
94 Method::PATCH,
95 Method::TRACE,
96 ]
97 .into_iter()
98 .fold(self, |route, method| route.on(method, handler.clone()))
99 }
100
101 repeat!(
102 export_internal_verb
103 get GET
104 post POST
105 put PUT
106 delete DELETE
107 head HEAD
108 options OPTIONS
109 connect CONNECT
110 patch PATCH
111 trace TRACE
112 );
113
114 #[must_use]
116 pub fn map_handler<F>(self, f: F) -> Self
117 where
118 F: Fn(BoxHandler) -> BoxHandler,
119 {
120 self.into_iter()
121 .map(|(method, handler)| (method, f(handler)))
122 .collect()
123 }
124
125 #[must_use]
127 pub fn with<T>(self, t: T) -> Self
128 where
129 T: Transform<BoxHandler>,
130 T::Output: Handler<Request, Output = Result<Response>> + Clone,
131 {
132 self.map_handler(|handler| t.transform(handler).boxed())
133 }
134
135 #[must_use]
137 pub fn with_handler<H>(self, f: H) -> Self
138 where
139 H: Handler<Next<Request, BoxHandler>, Output = Result<Response>> + Clone,
140 {
141 self.map_handler(|handler| handler.around(f.clone()).boxed())
142 }
143}
144
145impl IntoIterator for Route {
146 type Item = (Method, BoxHandler);
147
148 type IntoIter = std::vec::IntoIter<(Method, BoxHandler)>;
149
150 fn into_iter(self) -> Self::IntoIter {
151 self.methods.into_iter()
152 }
153}
154
155impl FromIterator<(Method, BoxHandler)> for Route {
156 fn from_iter<T>(iter: T) -> Self
157 where
158 T: IntoIterator<Item = (Method, BoxHandler)>,
159 {
160 Self {
161 methods: iter.into_iter().collect(),
162 }
163 }
164}
165
166pub fn on<H, O>(method: Method, handler: H) -> Route
168where
169 H: Handler<Request, Output = Result<O>> + Clone,
170 O: IntoResponse,
171{
172 Route::new().on(method, handler)
173}
174
// Free-function shortcuts (`get`, `post`, `put`, …). Each creates a new `Route` with a
// handler bound to a single HTTP verb. They are generated by `export_verb`, which delegates
// to the same-named `Route` builder method.
repeat!(
    export_verb
    get GET
    post POST
    put PUT
    delete DELETE
    head HEAD
    options OPTIONS
    connect CONNECT
    patch PATCH
    trace TRACE
);
187
188pub fn any<H, O>(handler: H) -> Route
190where
191 H: Handler<Request, Output = Result<O>> + Clone,
192 O: IntoResponse,
193{
194 Route::new().any(handler)
195}
196
197impl fmt::Debug for Route {
198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 f.debug_struct("Route")
200 .field(
201 "methods",
202 &self
203 .methods
204 .iter()
205 .map(|(m, _)| m)
206 .collect::<Vec<&Method>>(),
207 )
208 .finish()
209 }
210}
211
#[cfg(test)]
#[allow(dead_code)]
#[allow(clippy::unused_async)]
mod tests {
    use super::Route;
    use http_body_util::BodyExt;
    use serde::Deserialize;
    use std::sync::Arc;
    use viz_core::{
        Handler, HandlerExt, IntoHandler, IntoResponse, Method, Next, Request, RequestExt,
        Response, Result, async_trait,
        handler::Transform,
        types::{Query, State},
    };

    /// Builds a route that exercises every composition API:
    /// - `any`, `on` and the `put` verb helper;
    /// - `with`, `map_handler` and `with_handler`;
    /// - an `IntoIterator` / `FromIterator` round trip.
    ///
    /// It then calls the `GET` and `DELETE` handlers and checks their response bodies.
    #[tokio::test]
    async fn route() -> anyhow::Result<()> {
        // Plain endpoint whose `()` output produces an empty body.
        async fn handler(_: Request) -> Result<impl IntoResponse> {
            Ok(())
        }

        // Minimal `Transform` used with `Route::with`; it wraps a handler in `LoggerHandler`.
        struct Logger;

        impl Logger {
            const fn new() -> Self {
                Self
            }
        }

        impl<H: Clone> Transform<H> for Logger {
            type Output = LoggerHandler<H>;

            fn transform(&self, h: H) -> Self::Output {
                LoggerHandler(h)
            }
        }

        // Pass-through wrapper: forwards the request to the inner handler unchanged.
        #[derive(Clone)]
        struct LoggerHandler<H>(H);

        #[async_trait]
        impl<H> Handler<Request> for LoggerHandler<H>
        where
            H: Handler<Request>,
        {
            type Output = H::Output;

            async fn call(&self, req: Request) -> Self::Output {
                self.0.call(req).await
            }
        }

        // Identity `before` hook: returns the request untouched.
        async fn before(req: Request) -> Result<Request> {
            Ok(req)
        }

        // Identity `after` hook: returns the result untouched.
        async fn after(res: Result<Response>) -> Result<Response> {
            res
        }

        // Function-style "around" handlers with different generic bounds. Each just calls the
        // next handler. They check that every bound shape is accepted by `around` and by
        // `Route::with_handler`.
        async fn around<H, O>((req, handler): Next<Request, H>) -> Result<Response>
        where
            H: Handler<Request, Output = Result<O>>,
            O: IntoResponse,
        {
            handler.call(req).await.map(IntoResponse::into_response)
        }

        async fn around_1<H, O>((req, handler): Next<Request, H>) -> Result<Response>
        where
            H: Handler<Request, Output = Result<O>>,
            O: IntoResponse,
        {
            handler.call(req).await.map(IntoResponse::into_response)
        }

        // Variant whose inner handler must already produce a `Response`.
        async fn around_2<H>((req, handler): Next<Request, H>) -> Result<Response>
        where
            H: Handler<Request, Output = Result<Response>>,
        {
            handler.call(req).await
        }

        // Struct-style "around" handlers. `name` is never read; it only makes each instance
        // distinguishable. Around2 is generic over the input type `I`.
        #[derive(Clone)]
        struct Around2 {
            name: String,
        }

        #[async_trait]
        impl<H, I, O> Handler<Next<I, H>> for Around2
        where
            I: Send + 'static,
            H: Handler<I, Output = Result<O>>,
        {
            type Output = H::Output;

            async fn call(&self, (i, h): Next<I, H>) -> Self::Output {
                h.call(i).await
            }
        }

        // Converts any `IntoResponse` output of the inner handler into a `Response`.
        #[derive(Clone)]
        struct Around3 {
            name: String,
        }

        #[async_trait]
        impl<H, O> Handler<Next<Request, H>> for Around3
        where
            H: Handler<Request, Output = Result<O>> + Clone,
            O: IntoResponse,
        {
            type Output = Result<Response>;

            async fn call(&self, (i, h): Next<Request, H>) -> Self::Output {
                h.call(i).await.map(IntoResponse::into_response)
            }
        }

        // Requires the inner handler to already yield `Result<Response>`.
        #[derive(Clone)]
        struct Around4 {
            name: String,
        }

        #[async_trait]
        impl<H> Handler<Next<Request, H>> for Around4
        where
            H: Handler<Request, Output = Result<Response>> + Clone,
        {
            type Output = Result<Response>;

            async fn call(&self, (i, h): Next<Request, H>) -> Self::Output {
                h.call(i).await
            }
        }

        // Query-string payload for `ext`; `c` comes from `?c=...`.
        #[derive(Deserialize)]
        struct Counter {
            c: u8,
        }

        // Extractor-based handler. Its body is the bytes of the `Arc<String>` state (via
        // `to_string()`) followed by the single byte `c` from the query.
        async fn ext(q: Query<Counter>, s: State<Arc<String>>) -> Result<impl IntoResponse> {
            let mut a = s.to_string().as_bytes().to_vec();
            a.push(q.c);
            Ok(a)
        }

        // `any` registers `ext` for every method. The following `on(GET)`, `on(POST)` and
        // `put` calls overwrite those entries in place, because `Route::push` replaces
        // existing methods. DELETE therefore still runs `ext`.
        let route = Route::new()
            .any(ext.into_handler())
            .on(Method::GET, handler.map_into_response().before(before))
            .on(Method::POST, handler.map_into_response().after(after))
            .put(handler.around(Around2 {
                name: "handler around".to_string(),
            }))
            .with(Logger::new())
            .map_handler(|handler| {
                handler
                    // Stores the `Arc<String>` state that `ext` later reads back.
                    .before(|mut req: Request| async {
                        req.set_state(Arc::new("before".to_string()));
                        Ok(req)
                    })
                    .before(before)
                    .around(around_2)
                    .after(after)
                    .around(Around4 {
                        name: "4".to_string(),
                    })
                    .around(Around2 {
                        name: "2".to_string(),
                    })
                    .around(around)
                    .around(around_1)
                    .around(Around3 {
                        name: "3".to_string(),
                    })
                    .with(Logger::new())
                    .boxed()
            })
            .with_handler(around)
            .with_handler(around_1)
            .with_handler(around_2)
            .with_handler(Around2 {
                name: "2 with handler".to_string(),
            })
            .with_handler(Around3 {
                name: "3 with handler".to_string(),
            })
            .with_handler(Around4 {
                name: "4 with handler".to_string(),
            })
            // Round-trip through `IntoIterator` / `FromIterator`.
            .into_iter()
            .collect::<Route>();

        // GET now runs `handler`, whose `()` output produces an empty body.
        let (_, h) = route
            .methods
            .iter()
            .find(|(m, _)| m == Method::GET)
            .unwrap();

        let resp = match h.call(Request::default()).await {
            Ok(r) => r,
            Err(e) => e.into_response(),
        };
        assert_eq!(resp.into_body().collect().await?.to_bytes(), "");

        // DELETE still runs `ext`. It sees the "before" state set in `map_handler` and the
        // query value `c=1`.
        let (_, h) = route
            .methods
            .iter()
            .find(|(m, _)| m == Method::DELETE)
            .unwrap();

        let mut req = Request::default();
        *req.uri_mut() = "/?c=1".parse().unwrap();

        let resp = match h.call(req).await {
            Ok(r) => r,
            Err(e) => e.into_response(),
        };
        // Expected body: b"before" followed by the byte 1.
        assert_eq!(
            resp.into_body().collect().await?.to_bytes().to_vec(),
            vec![98, 101, 102, 111, 114, 101, 1]
        );

        Ok(())
    }
}