1use std::convert::Infallible;
17
18use http::{method::Method, status::StatusCode};
19use motore::{ServiceExt, layer::Layer, service::Service};
20use paste::paste;
21
22use super::{Fallback, Route};
23use crate::{
24 body::Body,
25 context::ServerContext,
26 request::Request,
27 response::Response,
28 server::{IntoResponse, handler::Handler},
29};
30
/// Router that dispatches a request according to its HTTP method.
///
/// One `MethodEndpoint` slot is kept for each of the nine methods listed
/// below. When the slot for the request's method is empty, or the method is
/// not one of the nine, the request is handed to `fallback` instead.
///
/// `B` is the request body type accepted by the routes and `E` is the error
/// type they return.
pub struct MethodRouter<B = Body, E = Infallible> {
    options: MethodEndpoint<B, E>,
    get: MethodEndpoint<B, E>,
    post: MethodEndpoint<B, E>,
    put: MethodEndpoint<B, E>,
    delete: MethodEndpoint<B, E>,
    head: MethodEndpoint<B, E>,
    trace: MethodEndpoint<B, E>,
    connect: MethodEndpoint<B, E>,
    patch: MethodEndpoint<B, E>,
    // Called whenever no per-method route applies to the request.
    fallback: Fallback<B, E>,
}
77
78impl<B, E> Service<ServerContext, Request<B>> for MethodRouter<B, E>
79where
80 B: Send,
81{
82 type Response = Response;
83 type Error = E;
84
85 async fn call(
86 &self,
87 cx: &mut ServerContext,
88 req: Request<B>,
89 ) -> Result<Self::Response, Self::Error> {
90 let handler = match *req.method() {
91 Method::OPTIONS => Some(&self.options),
92 Method::GET => Some(&self.get),
93 Method::POST => Some(&self.post),
94 Method::PUT => Some(&self.put),
95 Method::DELETE => Some(&self.delete),
96 Method::HEAD => Some(&self.head),
97 Method::TRACE => Some(&self.trace),
98 Method::CONNECT => Some(&self.connect),
99 Method::PATCH => Some(&self.patch),
100 _ => None,
101 };
102
103 match handler {
104 Some(MethodEndpoint::Route(route)) => route.call(cx, req).await,
105 _ => self.fallback.call(cx, req).await,
106 }
107 }
108}
109
110impl<B, E> Default for MethodRouter<B, E>
111where
112 B: Send + 'static,
113 E: 'static,
114{
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
120impl<B, E> MethodRouter<B, E>
121where
122 B: Send + 'static,
123 E: 'static,
124{
125 fn new() -> Self {
126 Self {
127 options: MethodEndpoint::None,
128 get: MethodEndpoint::None,
129 post: MethodEndpoint::None,
130 put: MethodEndpoint::None,
131 delete: MethodEndpoint::None,
132 head: MethodEndpoint::None,
133 trace: MethodEndpoint::None,
134 connect: MethodEndpoint::None,
135 patch: MethodEndpoint::None,
136 fallback: Fallback::from_status_code(StatusCode::METHOD_NOT_ALLOWED),
137 }
138 }
139
140 pub fn layer<L, B2, E2>(self, l: L) -> MethodRouter<B2, E2>
144 where
145 L: Layer<Route<B, E>> + Clone + Send + Sync + 'static,
146 L::Service: Service<ServerContext, Request<B2>, Error = E2> + Send + Sync + 'static,
147 <L::Service as Service<ServerContext, Request<B2>>>::Response: IntoResponse,
148 B2: 'static,
149 {
150 let Self {
151 options,
152 get,
153 post,
154 put,
155 delete,
156 head,
157 trace,
158 connect,
159 patch,
160 fallback,
161 } = self;
162
163 let layer_fn = move |route: Route<B, E>| {
164 Route::new(
165 l.clone()
166 .layer(route)
167 .map_response(IntoResponse::into_response),
168 )
169 };
170
171 let options = options.map(layer_fn.clone());
172 let get = get.map(layer_fn.clone());
173 let post = post.map(layer_fn.clone());
174 let put = put.map(layer_fn.clone());
175 let delete = delete.map(layer_fn.clone());
176 let head = head.map(layer_fn.clone());
177 let trace = trace.map(layer_fn.clone());
178 let connect = connect.map(layer_fn.clone());
179 let patch = patch.map(layer_fn.clone());
180
181 let fallback = fallback.map(layer_fn);
182
183 MethodRouter {
184 options,
185 get,
186 post,
187 put,
188 delete,
189 head,
190 trace,
191 connect,
192 patch,
193 fallback,
194 }
195 }
196}
197
// Invokes the macro named `$name` once, passing it the identifiers of all
// supported methods: options, get, post, put, delete, head, trace, connect
// and patch.
macro_rules! for_all_methods {
    ($name:ident) => {
        $name!(options, get, post, put, delete, head, trace, connect, patch);
    };
}
203
// Generates the chainable registration methods of `MethodRouter`.
//
// For every identifier `$method` two methods are emitted:
//   * `$method(handler)`          - stores a route built from a handler,
//   * `$method_service(service)`  - stores a route built from a service.
// Both overwrite the slot named `$method` and return the router, so calls can
// be chained. The macro must be expanded inside an `impl<B, E> MethodRouter<B, E>`
// block because the generated code refers to `B`, `E` and `self.$method`.
macro_rules! impl_method_register_for_builder {
    ($( $method:ident ),*) => {
        $(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")]
            pub fn $method<H, T>(mut self, handler: H) -> Self
            where
                for<'a> H: Handler<T, B, E> + Clone + Send + Sync + 'a,
                B: Send,
                T: 'static,
            {
                self.$method = MethodEndpoint::from_handler(handler);
                self
            }

            paste! {
                #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
                pub fn [<$method _service>]<S>(mut self, service: S) -> Self
                where
                    for<'a> S: Service<ServerContext, Request<B>, Error = E>
                        + Send
                        + Sync
                        + 'a,
                    S::Response: IntoResponse,
                {
                    self.$method = MethodEndpoint::from_service(service);
                    self
                }
            }
        // The repetition operator must match the matcher's (`*`); a mismatch
        // is only tolerated by the compiler, not permitted by the language
        // reference.
        )*
    };
}
235
236impl<B, E> MethodRouter<B, E>
237where
238 B: Send + 'static,
239 E: IntoResponse + 'static,
240{
241 for_all_methods!(impl_method_register_for_builder);
242
243 pub fn fallback<H, T>(mut self, handler: H) -> Self
250 where
251 for<'a> H: Handler<T, B, E> + Clone + Send + Sync + 'a,
252 T: 'static,
253 {
254 self.fallback = Fallback::from_handler(handler);
255 self
256 }
257
258 pub fn fallback_service<S>(mut self, service: S) -> Self
265 where
266 for<'a> S: Service<ServerContext, Request<B>, Error = E> + Send + Sync + 'a,
267 S::Response: IntoResponse,
268 {
269 self.fallback = Fallback::from_service(service);
270 self
271 }
272}
273
// Generates the free-standing constructor functions for `MethodRouter`.
//
// For every identifier `$method` two functions are emitted:
//   * `$method(handler)`          - a router whose `$method` slot holds a
//                                   route built from a handler,
//   * `$method_service(service)`  - the same, built from a service.
// Every other field of the returned router comes from `Default::default()`.
macro_rules! impl_method_register {
    ($( $method:ident ),*) => {
        $(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")]
            pub fn $method<H, T, B, E>(handler: H) -> MethodRouter<B, E>
            where
                for<'a> H: Handler<T, B, E> + Clone + Send + Sync + 'a,
                T: 'static,
                B: Send + 'static,
                E: IntoResponse + 'static,
            {
                MethodRouter {
                    $method: MethodEndpoint::from_handler(handler),
                    ..Default::default()
                }
            }

            paste! {
                #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
                pub fn [<$method _service>]<S, B, E>(service: S) -> MethodRouter<B, E>
                where
                    for<'a> S: Service<ServerContext, Request<B>, Error = E>
                        + Send
                        + Sync
                        + 'a,
                    S::Response: IntoResponse,
                    B: Send + 'static,
                    E: IntoResponse + 'static,
                {
                    MethodRouter {
                        $method: MethodEndpoint::from_service(service),
                        ..Default::default()
                    }
                }
            }
        // The repetition operator must match the matcher's (`*`); a mismatch
        // is only tolerated by the compiler, not permitted by the language
        // reference.
        )*
    };
}
312
// Expands `impl_method_register` for every supported method, defining the
// free functions `options`, `options_service`, `get`, `get_service`, and so
// on for post, put, delete, head, trace, connect and patch.
for_all_methods!(impl_method_register);
314
315pub fn any<H, T, B, E>(handler: H) -> MethodRouter<B, E>
317where
318 for<'a> H: Handler<T, B, E> + Clone + Send + Sync + 'a,
319 T: 'static,
320 B: Send + 'static,
321 E: IntoResponse + 'static,
322{
323 MethodRouter {
324 fallback: Fallback::from_handler(handler),
325 ..Default::default()
326 }
327}
328
329pub fn any_service<S, B, E>(service: S) -> MethodRouter<B, E>
331where
332 for<'a> S: Service<ServerContext, Request<B>, Error = E> + Send + Sync + 'a,
333 S::Response: IntoResponse,
334 B: Send + 'static,
335 E: IntoResponse + 'static,
336{
337 MethodRouter {
338 fallback: Fallback::from_service(service),
339 ..Default::default()
340 }
341}
342
/// One method slot of a `MethodRouter`: either empty or holding a route.
#[derive(Default)]
enum MethodEndpoint<B = Body, E = Infallible> {
    /// Nothing is registered for the method; this is the default value.
    #[default]
    None,
    /// The route that serves requests of the method.
    Route(Route<B, E>),
}
349
350impl<B, E> MethodEndpoint<B, E>
351where
352 B: Send + 'static,
353{
354 fn from_handler<H, T>(handler: H) -> Self
355 where
356 for<'a> H: Handler<T, B, E> + Clone + Send + Sync + 'a,
357 T: 'static,
358 E: 'static,
359 {
360 Self::from_service(handler.into_service())
361 }
362
363 fn from_service<S>(service: S) -> Self
364 where
365 for<'a> S: Service<ServerContext, Request<B>, Error = E> + Send + Sync + 'a,
366 S::Response: IntoResponse,
367 {
368 Self::Route(Route::new(
369 service.map_response(IntoResponse::into_response),
370 ))
371 }
372
373 fn map<F, B2, E2>(self, f: F) -> MethodEndpoint<B2, E2>
374 where
375 F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + 'static,
376 {
377 match self {
378 Self::None => MethodEndpoint::None,
379 Self::Route(route) => MethodEndpoint::Route(f(route)),
380 }
381 }
382}
383
#[cfg(test)]
mod method_router_tests {
    use http::{method::Method, status::StatusCode};

    use super::{MethodRouter, any, get, head, options};
    use crate::body::Body;

    // Handler returning `()`; the tests below expect a success status from it.
    async fn always_ok() {}

    // Handler returning a status that no other route in these tests produces.
    async fn teapot() -> StatusCode {
        StatusCode::IM_A_TEAPOT
    }

    // The nine methods exercised by both tests.
    fn all_methods() -> [Method; 9] {
        [
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::HEAD,
            Method::OPTIONS,
            Method::CONNECT,
            Method::PATCH,
            Method::TRACE,
        ]
    }

    #[tokio::test]
    async fn method_router() {
        // Asserts that the router answers with a success status exactly for
        // the methods accepted by `should_succeed`.
        async fn expect_success_iff<P>(router: MethodRouter<Option<Body>>, should_succeed: P)
        where
            P: Fn(Method) -> bool,
        {
            for method in all_methods() {
                let response = router.call_route(method.clone(), None).await;
                assert_eq!(response.status().is_success(), should_succeed(method));
            }
        }

        expect_success_iff(get(always_ok), |m| m == Method::GET).await;
        expect_success_iff(head(always_ok), |m| m == Method::HEAD).await;
        expect_success_iff(any(always_ok), |_| true).await;
    }

    #[tokio::test]
    async fn method_fallback() {
        // Asserts that the router answers with `IM_A_TEAPOT` (the fallback's
        // status) exactly for the methods accepted by `hits_fallback`.
        async fn expect_teapot_iff<P>(router: MethodRouter<Option<Body>>, hits_fallback: P)
        where
            P: Fn(Method) -> bool,
        {
            for method in all_methods() {
                let response = router.call_route(method.clone(), None).await;
                assert_eq!(
                    response.status() == StatusCode::IM_A_TEAPOT,
                    hits_fallback(method)
                );
            }
        }

        expect_teapot_iff(get(always_ok).fallback(teapot), |m| m != Method::GET).await;
        expect_teapot_iff(options(always_ok).fallback(teapot), |m| {
            m != Method::OPTIONS
        })
        .await;
        expect_teapot_iff(any(teapot), |_| true).await;
    }
}