1#![doc = include_str!("../README.md")]
8#![forbid(unsafe_code)]
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use small_http::{Method, Request, Response, Status};
14
15type HandlerFn<T> = fn(&Request, &T) -> Response;
19type PreLayerFn<T> = fn(&Request, &mut T) -> Option<Response>;
20type PostLayerFn<T> = fn(&Request, &mut T, Response) -> Response;
21
22struct Handler<T> {
23 handler: HandlerFn<T>,
24 pre_layers: Vec<PreLayerFn<T>>,
25 post_layers: Vec<PostLayerFn<T>>,
26}
27
28impl<T> Handler<T> {
29 fn new(
30 handler: HandlerFn<T>,
31 pre_layers: Vec<PreLayerFn<T>>,
32 post_layers: Vec<PostLayerFn<T>>,
33 ) -> Self {
34 Self {
35 handler,
36 pre_layers,
37 post_layers,
38 }
39 }
40
41 fn call(&self, req: &Request, ctx: &mut T) -> Response {
42 for pre_layer in &self.pre_layers {
43 if let Some(mut res) = pre_layer(req, ctx) {
44 for post_layer in &self.post_layers {
45 res = post_layer(req, ctx, res);
46 }
47 return res;
48 }
49 }
50 let mut res = (self.handler)(req, ctx);
51 for post_layer in &self.post_layers {
52 res = post_layer(req, ctx, res);
53 }
54 res
55 }
56}
57
/// One `/`-separated segment of a route pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
enum RoutePart {
    /// A literal segment that must equal the corresponding path segment exactly.
    Static(String),
    /// A `:name` segment: matches any single path segment and captures it under `name`.
    Param(String),
}
63
64struct Route<T> {
65 methods: Vec<Method>,
66 route: String,
67 parts: Vec<RoutePart>,
68 handler: Handler<T>,
69}
70
71impl<T> Route<T> {
72 fn new(methods: Vec<Method>, route: String, handler: Handler<T>) -> Self {
73 let parts = Self::route_parse_parts(&route);
74 Self {
75 methods,
76 route,
77 parts,
78 handler,
79 }
80 }
81
82 fn route_parse_parts(route: &str) -> Vec<RoutePart> {
83 route
84 .split('/')
85 .filter(|part| !part.is_empty())
86 .map(|part| {
87 if let Some(stripped) = part.strip_prefix(':') {
88 RoutePart::Param(stripped.to_string())
89 } else {
90 RoutePart::Static(part.to_string())
91 }
92 })
93 .collect()
94 }
95
96 fn is_match(&self, path: &str) -> bool {
97 let mut path_parts = path.split('/').filter(|part| !part.is_empty());
98 for part in &self.parts {
99 match part {
100 RoutePart::Static(expected) => {
101 if let Some(actual) = path_parts.next() {
102 if actual != *expected {
103 return false;
104 }
105 } else {
106 return false;
107 }
108 }
109 RoutePart::Param(_) => {
110 if path_parts.next().is_none() {
111 return false;
112 }
113 }
114 }
115 }
116 path_parts.next().is_none()
117 }
118
119 fn match_path(&self, path: &str) -> HashMap<String, String> {
120 let mut path_parts = path.split('/').filter(|part| !part.is_empty());
121 let mut params = HashMap::new();
122 for part in &self.parts {
123 match part {
124 RoutePart::Static(_) => {
125 path_parts.next();
126 }
127 RoutePart::Param(name) => {
128 if let Some(value) = path_parts.next() {
129 params.insert(name.clone(), value.to_string());
130 }
131 }
132 }
133 }
134 params
135 }
136}
137
138pub struct RouterBuilder<T: Clone> {
141 ctx: T,
142 pre_layers: Vec<PreLayerFn<T>>,
143 post_layers: Vec<PostLayerFn<T>>,
144 routes: Vec<Route<T>>,
145 not_allowed_method_handler: Option<Handler<T>>,
146 fallback_handler: Option<Handler<T>>,
147}
148
149impl<T: Clone> RouterBuilder<T> {
150 pub fn with(ctx: T) -> Self {
152 Self {
153 ctx,
154 pre_layers: Vec::new(),
155 post_layers: Vec::new(),
156 routes: Vec::new(),
157 not_allowed_method_handler: None,
158 fallback_handler: None,
159 }
160 }
161
162 pub fn pre_layer(mut self, layer: PreLayerFn<T>) -> Self {
164 self.pre_layers.push(layer);
165 self
166 }
167
168 pub fn post_layer(mut self, layer: PostLayerFn<T>) -> Self {
170 self.post_layers.push(layer);
171 self
172 }
173
174 pub fn route(mut self, methods: &[Method], route: String, handler: HandlerFn<T>) -> Self {
176 self.routes.push(Route::new(
177 methods.to_vec(),
178 route,
179 Handler::new(handler, self.pre_layers.clone(), self.post_layers.clone()),
180 ));
181 self
182 }
183
184 pub fn any(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
186 self.route(
187 &[
188 Method::Get,
189 Method::Head,
190 Method::Post,
191 Method::Put,
192 Method::Delete,
193 Method::Connect,
194 Method::Options,
195 Method::Trace,
196 Method::Patch,
197 ],
198 route.into(),
199 handler,
200 )
201 }
202 pub fn get(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
204 self.route(&[Method::Get], route.into(), handler)
205 }
206
207 pub fn head(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
209 self.route(&[Method::Head], route.into(), handler)
210 }
211
212 pub fn post(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
214 self.route(&[Method::Post], route.into(), handler)
215 }
216
217 pub fn put(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
219 self.route(&[Method::Put], route.into(), handler)
220 }
221
222 pub fn delete(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
224 self.route(&[Method::Delete], route.into(), handler)
225 }
226
227 pub fn connect(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
229 self.route(&[Method::Connect], route.into(), handler)
230 }
231
232 pub fn options(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
234 self.route(&[Method::Options], route.into(), handler)
235 }
236
237 pub fn trace(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
239 self.route(&[Method::Trace], route.into(), handler)
240 }
241
242 pub fn patch(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
244 self.route(&[Method::Patch], route.into(), handler)
245 }
246
247 pub fn fallback(mut self, handler: HandlerFn<T>) -> Self {
249 self.fallback_handler = Some(Handler::new(
250 handler,
251 self.pre_layers.clone(),
252 self.post_layers.clone(),
253 ));
254 self
255 }
256
257 pub fn build(self) -> Router<T> {
259 Router(Arc::new(InnerRouter {
260 ctx: self.ctx,
261 routes: self.routes,
262 not_allowed_method_handler: self.not_allowed_method_handler.unwrap_or_else(|| {
263 Handler::new(
264 |_, _| {
265 Response::with_status(Status::MethodNotAllowed)
266 .body("405 Method Not Allowed")
267 },
268 self.pre_layers.clone(),
269 self.post_layers.clone(),
270 )
271 }),
272 fallback_handler: self.fallback_handler.unwrap_or_else(|| {
273 Handler::new(
274 |_, _| Response::with_status(Status::NotFound).body("404 Not Found"),
275 self.pre_layers.clone(),
276 self.post_layers.clone(),
277 )
278 }),
279 }))
280 }
281}
282
283struct InnerRouter<T: Clone> {
285 ctx: T,
286 routes: Vec<Route<T>>,
287 not_allowed_method_handler: Handler<T>,
288 fallback_handler: Handler<T>,
289}
290
291impl<T: Clone> InnerRouter<T> {
292 fn handle(&self, req: &Request) -> Response {
293 let mut ctx = self.ctx.clone();
294
295 let path = req.url.path();
297 for route in self.routes.iter().rev() {
298 if route.is_match(path) {
299 let mut req = req.clone();
300 req.params = route.match_path(path);
301
302 for route in self.routes.iter().filter(|r| r.route == route.route) {
304 if route.methods.contains(&req.method) {
305 return route.handler.call(&req, &mut ctx);
306 }
307 }
308
309 return self.not_allowed_method_handler.call(&req, &mut ctx);
311 }
312 }
313
314 self.fallback_handler.call(req, &mut ctx)
316 }
317}
318
319#[derive(Clone)]
322pub struct Router<T: Clone>(Arc<InnerRouter<T>>);
323
324impl<T: Clone> Router<T> {
325 pub fn handle(&self, req: &Request) -> Response {
327 self.0.handle(req)
328 }
329}
330
#[cfg(test)]
mod test {
    use small_http::Status;

    use super::*;

    /// Serves the fixed greeting for `/`.
    fn home(_req: &Request, _ctx: &()) -> Response {
        Response::with_status(Status::Ok).body("Hello, World!")
    }

    /// Greets whoever the route's `:name` parameter captured.
    fn hello(req: &Request, _ctx: &()) -> Response {
        let name = &req.params["name"];
        Response::with_status(Status::Ok).body(format!("Hello, {}!", name))
    }

    #[test]
    fn test_routing() {
        let router = RouterBuilder::with(())
            .get("/", home)
            .get("/hello/:name", hello)
            .get("/hello/:name/i/:am/so/:deep", hello)
            .build();

        // The root route serves the static greeting.
        {
            let response = router.handle(&Request::with_url("http://localhost/"));
            assert_eq!(response.status, Status::Ok);
            assert_eq!(response.body, b"Hello, World!");
        }

        // An unregistered path falls through to the default 404 handler.
        {
            let response = router.handle(&Request::with_url("http://localhost/unknown"));
            assert_eq!(response.status, Status::NotFound);
            assert_eq!(response.body, b"404 Not Found");
        }

        // A `:name` segment is captured into the request params.
        {
            let response = router.handle(&Request::with_url("http://localhost/hello/Bassie"));
            assert_eq!(response.status, Status::Ok);
            assert_eq!(response.body, b"Hello, Bassie!");
        }

        // Deep patterns with several parameters still match.
        {
            let response = router.handle(&Request::with_url(
                "http://localhost/hello/Bassie/i/handle/so/much",
            ));
            assert_eq!(response.status, Status::Ok);
        }

        // A known path with an unregistered method gets the default 405 handler.
        {
            let response =
                router.handle(&Request::with_url("http://localhost/").method(Method::Options));
            assert_eq!(response.status, Status::MethodNotAllowed);
            assert_eq!(response.body, b"405 Method Not Allowed");
        }
    }
}