1use std::convert::Infallible;
5use std::sync::Arc;
6use tokio::sync::Mutex;
7use hyper::{Body, Request, Response, Method, StatusCode};
8use futures::future::BoxFuture;
9use mime_guess::from_path;
10use tokio::fs::File;
11use tokio::io::AsyncReadExt;
12use serde::Deserialize;
13
14pub type Middleware = Arc<dyn Fn(Request<Body>, Arc<dyn Fn(Request<Body>) -> BoxFuture<'static, Response<Body>> + Send + Sync>) -> BoxFuture<'static, Response<Body>> + Send + Sync>;
15
16pub struct Router {
17 pub routes: Arc<Mutex<Vec<(Method, String, Arc<dyn Fn(Request<Body>) -> BoxFuture<'static, Response<Body>> + Send + Sync>)>>>,
18 pub middlewares: Arc<Mutex<Vec<Middleware>>>,
19}
20
21impl Router {
22 pub async fn new() -> Self {
23 Self {
24 routes: Arc::new(Mutex::new(Vec::new())),
25 middlewares: Arc::new(Mutex::new(Vec::new())),
26 }
27 }
28 pub async fn add_route<F>(&self, method: Method, path: &str, handler: F)
29 where
30 F: Fn(Request<Body>) -> BoxFuture<'static, Response<Body>> + Send + Sync + 'static,
31 {
32 let mut routes = self.routes.lock().await;
33 routes.push((method, path.to_string(), Arc::new(handler)));
34 }
35 pub async fn add_middleware(&self, mw: Middleware) {
36 let mut middlewares = self.middlewares.lock().await;
37 middlewares.push(mw);
38 }
39 pub async fn handle(&self, req: Request<Body>) -> Result<Response<Body>, Infallible> {
40 let method = req.method().clone();
41 let path = req.uri().path().to_string();
42 let routes = self.routes.lock().await;
43 for (m, p, handler) in routes.iter() {
44 if *m == method && *p == path {
45 let mut next = handler.clone();
46 let middlewares = self.middlewares.lock().await;
47 for mw in middlewares.iter().rev() {
48 let prev = next.clone();
49 let mw = mw.clone();
50 next = Arc::new(move |req| mw(req, prev.clone()));
51 }
52 return Ok(next(req).await);
53 }
54 }
55 Ok(Response::builder().status(StatusCode::NOT_FOUND).body(Body::from("Not Found")).unwrap())
56 }
57}
58
59pub async fn serve_static_file(path: &str) -> Response<Body> {
60 match File::open(path).await {
61 Ok(mut file) => {
62 let mut buf = Vec::new();
63 if let Err(_) = file.read_to_end(&mut buf).await {
64 return Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Body::from("Error reading file")).unwrap();
65 }
66 let mime = from_path(path).first_or_octet_stream();
67 Response::builder().header("Content-Type", mime.as_ref()).body(Body::from(buf)).unwrap()
68 }
69 Err(_) => Response::builder().status(StatusCode::NOT_FOUND).body(Body::from("File not found")).unwrap(),
70 }
71}
72
73pub fn parse_query(req: &Request<Body>) -> std::collections::HashMap<String, String> {
74 req.uri().query().map(|q| {
75 url::form_urlencoded::parse(q.as_bytes()).into_owned().collect()
76 }).unwrap_or_default()
77}
78
79pub async fn parse_json<T: for<'de> Deserialize<'de>>(req: Request<Body>) -> Result<T, Response<Body>> {
80 let whole_body = hyper::body::to_bytes(req.into_body()).await.map_err(|_| Response::builder().status(StatusCode::BAD_REQUEST).body(Body::from("Bad Request")).unwrap())?;
81 serde_json::from_slice(&whole_body).map_err(|_| Response::builder().status(StatusCode::BAD_REQUEST).body(Body::from("Invalid JSON")).unwrap())
82}