tela/
router.rs

1use std::{collections::HashMap, convert::Infallible, ffi::OsStr, fs, path::Path, sync::Arc};
2
3use bytes::Bytes;
4use http_body_util::{BodyExt, Full};
5use hyper::{Method, Uri};
6use tokio::sync::{
7    mpsc::{self, Sender},
8    oneshot,
9};
10
11use crate::{
12    errors::{default_error_page, StatusCode},
13    request::{Catch, Endpoint},
14    uri::index,
15};
16
17/// Commands sent through channel to router
18#[derive(Debug)]
19pub enum Command {
20    Get {
21        method: Method,
22        path: String,
23        response: oneshot::Sender<Option<Route>>,
24    },
25    Error {
26        code: u16,
27        response: oneshot::Sender<Option<ErrorHandler>>,
28    },
29}
30
31#[derive(Debug, Clone)]
32pub struct Route(pub Arc<dyn Endpoint>);
33
34#[derive(Debug, Clone)]
35pub struct ErrorHandler(pub Arc<dyn Catch>);
36
37#[derive(Clone)]
38pub struct Router {
39    channel: Option<Sender<Command>>,
40    router: HashMap<Method, Vec<Route>>,
41    catch: HashMap<u16, ErrorHandler>,
42    assets: String,
43}
44impl Router {
45    pub fn new() -> Self {
46        Router {
47            channel: None,
48            router: HashMap::new(),
49            catch: HashMap::new(),
50            assets: "assets/".to_string(),
51        }
52    }
53
54    pub fn assets(&mut self, path: String) {
55        self.assets = path;
56    }
57
58    pub fn catch(&mut self, catch: Arc<dyn Catch>) {
59        if !self.catch.contains_key(&catch.code()) {
60            self.catch.insert(catch.code(), ErrorHandler(catch));
61        }
62    }
63
64    pub fn route(&mut self, route: Arc<dyn Endpoint>) {
65        for method in route.methods() {
66            if !self.router.contains_key(&method) {
67                self.router.insert(method.clone(), Vec::new());
68            }
69            self.router
70                .get_mut(&method)
71                .unwrap()
72                .push(Route(route.clone()));
73        }
74    }
75
76    /// Start listener thread for handling access to router
77    ///
78    /// Creates mpsc channel and returns Sender handle. The thread that this method
79    /// creates is the only instance of the router that should exists.
80    pub fn serve_routes(&mut self) {
81        let (tx, mut rx) = mpsc::channel::<Command>(32);
82        let router = self.router.clone();
83        let catch = self.catch.clone();
84
85        tokio::spawn(async move {
86            'watcher: while let Some(cmd) = rx.recv().await {
87                use Command::*;
88
89                match cmd {
90                    Get {
91                        method,
92                        path,
93                        response,
94                    } => {
95                        match router.get(&method) {
96                            Some(data) => {
97                                match index(
98                                    &path,
99                                    &data.iter().map(|r| r.0.path()).collect::<Vec<String>>(),
100                                ) {
101                                    Some(index) => {
102                                        response.send(Some(data[index].clone())).unwrap();
103                                        continue 'watcher;
104                                    }
105                                    _ => {}
106                                }
107                            }
108                            _ => {}
109                        };
110                        response.send(None).unwrap();
111                    }
112                    Error { code, response } => {
113                        if catch.contains_key(&code) {
114                            response
115                                .send(catch.get(&code).map(|eh| eh.clone()))
116                                .unwrap()
117                        } else if catch.contains_key(&0) {
118                            response.send(catch.get(&0).map(|eh| eh.clone())).unwrap()
119                        } else {
120                            response.send(None).unwrap()
121                        }
122                    }
123                }
124            }
125        });
126
127        self.channel = Some(tx);
128    }
129
130    async fn error(
131        &self,
132        uri: &Uri,
133        method: &Method,
134        body: &Vec<u8>,
135        code: u16,
136        reason: String,
137        channel: Sender<Command>,
138    ) -> std::result::Result<hyper::Response<Full<Bytes>>, Infallible> {
139        let (error_tx, error_rx) = oneshot::channel();
140        match channel
141            .send(Command::Error {
142                code: code.clone(),
143                response: error_tx,
144            })
145            .await
146        {
147            Ok(_) => {}
148            Err(error) => eprintln!("{:?}", error),
149        };
150
151        match error_rx.await.unwrap() {
152            Some(ErrorHandler(handler)) => {
153                match handler.execute(
154                    code.clone(),
155                    StatusCode::from(code.clone()).message(),
156                    reason.clone(),
157                ) {
158                    Ok(response) => {
159                        Router::log_request(
160                            &uri.path().to_string(),
161                            &method.clone(),
162                            &response.status().into(),
163                        );
164                        Ok(response)
165                    }
166                    Err((code, reason)) => {
167                        Router::log_request(&uri.path().to_string(), method, &code);
168                        Ok(default_error_page(
169                            &code,
170                            &reason,
171                            method,
172                            uri,
173                            std::str::from_utf8(body).unwrap_or("").to_string(),
174                        ))
175                    }
176                }
177            }
178            None => {
179                Router::log_request(&uri.path().to_string(), method, &code);
180                Ok(default_error_page(
181                    &code,
182                    &reason,
183                    method,
184                    uri,
185                    std::str::from_utf8(body).unwrap_or("").to_string(),
186                ))
187            }
188        }
189    }
190
191    fn log_request(path: &String, method: &Method, status: &u16) {
192        #[cfg(debug_assertions)]
193        eprintln!(
194            "  {}(\x1b[3{}m{}\x1b[39m) \x1b[32m{:?}\x1b[0m",
195            method,
196            match status {
197                100..=199 => 6,
198                200..=299 => 2,
199                300..=399 => 5,
200                400..=499 => 1,
201                500..=599 => 3,
202                _ => 7,
203            },
204            status,
205            path
206        );
207    }
208
209    pub async fn parse(
210        &self,
211        request: hyper::Request<hyper::body::Incoming>,
212    ) -> Result<hyper::Response<Full<Bytes>>, Infallible> {
213        // Get all needed information from request
214        let mut uri = request.uri().clone();
215        let method = request.method().clone();
216        // Can be used for validation, authentication, and other features
217        let _headers = request.headers().clone();
218        let mut body = request.collect().await.unwrap().to_bytes().to_vec();
219
220        let (endpoint_tx, endpoint_rx) = oneshot::channel();
221        match &self.channel {
222            Some(channel) => {
223                let path = format!("{}{}", self.assets, uri.path());
224                let path = Path::new(&path);
225                if let Some(extension) = path.extension().and_then(OsStr::to_str) {
226                    match fs::read_to_string(path) {
227                        Ok(text) => {
228                            Router::log_request(&uri.path().to_string(), &method, &200);
229                            let mut builder = hyper::Response::builder().status(200);
230
231                            match mime_guess::from_ext(extension).first() {
232                                Some(mime) => {
233                                    builder = builder.header("Content-Type", mime.to_string())
234                                }
235                                _ => {}
236                            };
237
238                            return Ok(builder.body(Full::new(Bytes::from(text))).unwrap());
239                        }
240                        _ => {
241                            Router::log_request(&uri.path().to_string(), &method, &404);
242                            return Ok(default_error_page(
243                                &404,
244                                &"File not found".to_string(),
245                                &method,
246                                &uri,
247                                std::str::from_utf8(body.as_slice())
248                                    .unwrap_or("")
249                                    .to_string(),
250                            ));
251                        }
252                    }
253                }
254
255                match channel
256                    .send(Command::Get {
257                        method: method.clone(),
258                        path: uri.path().to_string(),
259                        response: endpoint_tx,
260                    })
261                    .await
262                {
263                    Ok(_) => {}
264                    Err(error) => eprintln!("{}", error),
265                };
266
267                match endpoint_rx.await.unwrap() {
268                    Some(Route(endpoint)) => match endpoint.execute(&method, &mut uri, &mut body) {
269                        Ok(response) => {
270                            Router::log_request(
271                                &uri.path().to_string(),
272                                &method,
273                                &response.status().into(),
274                            );
275                            Ok(response)
276                        }
277                        Err((code, reason)) => {
278                            self.error(&uri, &method, &body, code, reason, channel.clone())
279                                .await
280                        }
281                    },
282                    None => {
283                        self.error(
284                            &uri,
285                            &method,
286                            &body,
287                            404,
288                            "Page not found in router".to_string(),
289                            channel.clone(),
290                        )
291                        .await
292                    }
293                }
294            }
295            _ => panic!("Unable to communicate with router"),
296        }
297    }
298}