small_router/
lib.rs

1/*
2 * Copyright (c) 2024-2025 Bastiaan van der Plaat
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7#![doc = include_str!("../README.md")]
8#![forbid(unsafe_code)]
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use small_http::{Method, Request, Response, Status};
14
15// MARK: Handler
16
/// Handler function: takes the request and a shared reference to the context, returns a response
18type HandlerFn<T> = fn(&Request, &T) -> Response;
19type PreLayerFn<T> = fn(&Request, &mut T) -> Option<Response>;
20type PostLayerFn<T> = fn(&Request, &mut T, Response) -> Response;
21
22struct Handler<T> {
23    handler: HandlerFn<T>,
24    pre_layers: Vec<PreLayerFn<T>>,
25    post_layers: Vec<PostLayerFn<T>>,
26}
27
28impl<T> Handler<T> {
29    fn new(
30        handler: HandlerFn<T>,
31        pre_layers: Vec<PreLayerFn<T>>,
32        post_layers: Vec<PostLayerFn<T>>,
33    ) -> Self {
34        Self {
35            handler,
36            pre_layers,
37            post_layers,
38        }
39    }
40
41    fn call(&self, req: &Request, ctx: &mut T) -> Response {
42        for pre_layer in &self.pre_layers {
43            if let Some(mut res) = pre_layer(req, ctx) {
44                for post_layer in &self.post_layers {
45                    res = post_layer(req, ctx, res);
46                }
47                return res;
48            }
49        }
50        let mut res = (self.handler)(req, ctx);
51        for post_layer in &self.post_layers {
52            res = post_layer(req, ctx, res);
53        }
54        res
55    }
56}
57
58// MARK: Route
59enum RoutePart {
60    Static(String),
61    Param(String),
62}
63
64struct Route<T> {
65    methods: Vec<Method>,
66    route: String,
67    parts: Vec<RoutePart>,
68    handler: Handler<T>,
69}
70
71impl<T> Route<T> {
72    fn new(methods: Vec<Method>, route: String, handler: Handler<T>) -> Self {
73        let parts = Self::route_parse_parts(&route);
74        Self {
75            methods,
76            route,
77            parts,
78            handler,
79        }
80    }
81
82    fn route_parse_parts(route: &str) -> Vec<RoutePart> {
83        route
84            .split('/')
85            .filter(|part| !part.is_empty())
86            .map(|part| {
87                if let Some(stripped) = part.strip_prefix(':') {
88                    RoutePart::Param(stripped.to_string())
89                } else {
90                    RoutePart::Static(part.to_string())
91                }
92            })
93            .collect()
94    }
95
96    fn is_match(&self, path: &str) -> bool {
97        let mut path_parts = path.split('/').filter(|part| !part.is_empty());
98        for part in &self.parts {
99            match part {
100                RoutePart::Static(expected) => {
101                    if let Some(actual) = path_parts.next() {
102                        if actual != *expected {
103                            return false;
104                        }
105                    } else {
106                        return false;
107                    }
108                }
109                RoutePart::Param(_) => {
110                    if path_parts.next().is_none() {
111                        return false;
112                    }
113                }
114            }
115        }
116        path_parts.next().is_none()
117    }
118
119    fn match_path(&self, path: &str) -> HashMap<String, String> {
120        let mut path_parts = path.split('/').filter(|part| !part.is_empty());
121        let mut params = HashMap::new();
122        for part in &self.parts {
123            match part {
124                RoutePart::Static(_) => {
125                    path_parts.next();
126                }
127                RoutePart::Param(name) => {
128                    if let Some(value) = path_parts.next() {
129                        params.insert(name.clone(), value.to_string());
130                    }
131                }
132            }
133        }
134        params
135    }
136}
137
138// MARK: RouterBuilder
139/// Router builder
140pub struct RouterBuilder<T: Clone> {
141    ctx: T,
142    pre_layers: Vec<PreLayerFn<T>>,
143    post_layers: Vec<PostLayerFn<T>>,
144    routes: Vec<Route<T>>,
145    not_allowed_method_handler: Option<Handler<T>>,
146    fallback_handler: Option<Handler<T>>,
147}
148
149impl<T: Clone> RouterBuilder<T> {
150    /// Create new router with context
151    pub fn with(ctx: T) -> Self {
152        Self {
153            ctx,
154            pre_layers: Vec::new(),
155            post_layers: Vec::new(),
156            routes: Vec::new(),
157            not_allowed_method_handler: None,
158            fallback_handler: None,
159        }
160    }
161
162    /// Add pre layer
163    pub fn pre_layer(mut self, layer: PreLayerFn<T>) -> Self {
164        self.pre_layers.push(layer);
165        self
166    }
167
168    /// Add post layer
169    pub fn post_layer(mut self, layer: PostLayerFn<T>) -> Self {
170        self.post_layers.push(layer);
171        self
172    }
173
174    /// Add route
175    pub fn route(mut self, methods: &[Method], route: String, handler: HandlerFn<T>) -> Self {
176        self.routes.push(Route::new(
177            methods.to_vec(),
178            route,
179            Handler::new(handler, self.pre_layers.clone(), self.post_layers.clone()),
180        ));
181        self
182    }
183
184    /// Add route for any method
185    pub fn any(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
186        self.route(
187            &[
188                Method::Get,
189                Method::Head,
190                Method::Post,
191                Method::Put,
192                Method::Delete,
193                Method::Connect,
194                Method::Options,
195                Method::Trace,
196                Method::Patch,
197            ],
198            route.into(),
199            handler,
200        )
201    }
202    /// Add route for GET method
203    pub fn get(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
204        self.route(&[Method::Get], route.into(), handler)
205    }
206
207    /// Add route for HEAD method
208    pub fn head(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
209        self.route(&[Method::Head], route.into(), handler)
210    }
211
212    /// Add route for POST method
213    pub fn post(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
214        self.route(&[Method::Post], route.into(), handler)
215    }
216
217    /// Add route for PUT method
218    pub fn put(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
219        self.route(&[Method::Put], route.into(), handler)
220    }
221
222    /// Add route for DELETE method
223    pub fn delete(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
224        self.route(&[Method::Delete], route.into(), handler)
225    }
226
227    /// Add route for CONNECT method
228    pub fn connect(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
229        self.route(&[Method::Connect], route.into(), handler)
230    }
231
232    /// Add route for OPTIONS method
233    pub fn options(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
234        self.route(&[Method::Options], route.into(), handler)
235    }
236
237    /// Add route for TRACE method
238    pub fn trace(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
239        self.route(&[Method::Trace], route.into(), handler)
240    }
241
242    /// Add route for PATCH method
243    pub fn patch(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
244        self.route(&[Method::Patch], route.into(), handler)
245    }
246
247    /// Set fallback handler
248    pub fn fallback(mut self, handler: HandlerFn<T>) -> Self {
249        self.fallback_handler = Some(Handler::new(
250            handler,
251            self.pre_layers.clone(),
252            self.post_layers.clone(),
253        ));
254        self
255    }
256
257    /// Build router
258    pub fn build(self) -> Router<T> {
259        Router(Arc::new(InnerRouter {
260            ctx: self.ctx,
261            routes: self.routes,
262            not_allowed_method_handler: self.not_allowed_method_handler.unwrap_or_else(|| {
263                Handler::new(
264                    |_, _| {
265                        Response::with_status(Status::MethodNotAllowed)
266                            .body("405 Method Not Allowed")
267                    },
268                    self.pre_layers.clone(),
269                    self.post_layers.clone(),
270                )
271            }),
272            fallback_handler: self.fallback_handler.unwrap_or_else(|| {
273                Handler::new(
274                    |_, _| Response::with_status(Status::NotFound).body("404 Not Found"),
275                    self.pre_layers.clone(),
276                    self.post_layers.clone(),
277                )
278            }),
279        }))
280    }
281}
282
283// MARK: InnerRouter
284struct InnerRouter<T: Clone> {
285    ctx: T,
286    routes: Vec<Route<T>>,
287    not_allowed_method_handler: Handler<T>,
288    fallback_handler: Handler<T>,
289}
290
291impl<T: Clone> InnerRouter<T> {
292    fn handle(&self, req: &Request) -> Response {
293        let mut ctx = self.ctx.clone();
294
295        // Match routes
296        let path = req.url.path();
297        for route in self.routes.iter().rev() {
298            if route.is_match(path) {
299                let mut req = req.clone();
300                req.params = route.match_path(path);
301
302                // Find matching route by method
303                for route in self.routes.iter().filter(|r| r.route == route.route) {
304                    if route.methods.contains(&req.method) {
305                        return route.handler.call(&req, &mut ctx);
306                    }
307                }
308
309                // Or run not allowed method handler
310                return self.not_allowed_method_handler.call(&req, &mut ctx);
311            }
312        }
313
314        // Or run fallback handler
315        self.fallback_handler.call(req, &mut ctx)
316    }
317}
318
319// MARK: Router
320/// Router
321#[derive(Clone)]
322pub struct Router<T: Clone>(Arc<InnerRouter<T>>);
323
324impl<T: Clone> Router<T> {
325    /// Handle request
326    pub fn handle(&self, req: &Request) -> Response {
327        self.0.handle(req)
328    }
329}
330
331// MARK: Tests
#[cfg(test)]
mod test {
    use small_http::Status;

    use super::*;

    /// Handler for `/` that always greets the world
    fn home(_req: &Request, _ctx: &()) -> Response {
        Response::with_status(Status::Ok).body("Hello, World!")
    }

    /// Handler that greets the `name` path parameter
    fn hello(req: &Request, _ctx: &()) -> Response {
        let who = req.params.get("name").unwrap();
        Response::with_status(Status::Ok).body(format!("Hello, {}!", who))
    }

    #[test]
    fn test_routing() {
        let router = RouterBuilder::with(())
            .get("/", home)
            .get("/hello/:name", hello)
            .get("/hello/:name/i/:am/so/:deep", hello)
            .build();

        // Send a request with the default method to the given URL
        let request = |url: &str| router.handle(&Request::with_url(url));
        // Send a request with an explicit method to the given URL
        let request_with = |method: Method, url: &str| {
            router.handle(&Request::with_url(url).method(method))
        };

        // Home route
        let home_res = request("http://localhost/");
        assert_eq!(home_res.status, Status::Ok);
        assert_eq!(home_res.body, b"Hello, World!");

        // Unknown path goes to the default fallback handler
        let unknown_res = request("http://localhost/unknown");
        assert_eq!(unknown_res.status, Status::NotFound);
        assert_eq!(unknown_res.body, b"404 Not Found");

        // Route with a single param
        let hello_res = request("http://localhost/hello/Bassie");
        assert_eq!(hello_res.status, Status::Ok);
        assert_eq!(hello_res.body, b"Hello, Bassie!");

        // Route with multiple params
        let deep_res = request("http://localhost/hello/Bassie/i/handle/so/much");
        assert_eq!(deep_res.status, Status::Ok);

        // Known path with a method no route accepts
        let options_res = request_with(Method::Options, "http://localhost/");
        assert_eq!(options_res.status, Status::MethodNotAllowed);
        assert_eq!(options_res.body, b"405 Method Not Allowed");
    }
}
381}