scratch_server/
router.rs

1use base64::{prelude::BASE64_STANDARD, Engine};
2use regex::Regex;
3use serde_json::json;
4use std::{collections::HashMap, net::IpAddr, sync::Arc};
5use termcolor::Color;
6
7use crate::{logger::Logger, ApiError, Body, HttpResponse};
8
/// An HTTP request method that a route can be registered for.
///
/// Variant names mirror the upper-case method tokens used on the wire
/// (see [`HttpMethod::as_str`]).
// Renaming the variants to `Get`/`Post`/... would break every caller, so the lint is
// silenced instead; the all-caps names intentionally match the wire tokens.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HttpMethod {
    GET,
    POST,
    PUT,
    DELETE,
    PATCH,
    OPTIONS,
    HEAD,
    TRACE,
    CONNECT,
}

/// Username/password pair checked for routes registered with `authorize = true`.
///
/// NOTE(review): this type has no `Debug` derive; keep it that way (or redact the
/// password in a manual impl) so credentials are never formatted into logs.
pub struct Credentials {
    username: String,
    password: String,
}

impl HttpMethod {
    /// Returns the method's upper-case token, e.g. `"GET"`, as compared against the
    /// request line's method.
    ///
    /// The tokens are string literals, so the result is `'static` and does not keep
    /// `self` borrowed.
    fn as_str(&self) -> &'static str {
        match self {
            HttpMethod::GET => "GET",
            HttpMethod::POST => "POST",
            HttpMethod::PUT => "PUT",
            HttpMethod::DELETE => "DELETE",
            HttpMethod::PATCH => "PATCH",
            HttpMethod::OPTIONS => "OPTIONS",
            HttpMethod::HEAD => "HEAD",
            HttpMethod::TRACE => "TRACE",
            HttpMethod::CONNECT => "CONNECT",
        }
    }
}
42
43fn get_status_code_color(status_code: u16) -> Color {
44    match status_code {
45        100..=199 => Color::Cyan,
46        200..=299 => Color::Green,
47        300..=399 => Color::Yellow,
48        400..=499 => Color::Red,
49        _ => Color::Magenta,
50    }
51}
52
53type Handler =
54    Box<dyn Fn(Option<&str>, HashMap<&str, &str>) -> Result<HttpResponse, ApiError> + Send + Sync>;
55
56pub struct Route {
57    pattern: Regex,
58    handler: Handler,
59    method: HttpMethod,
60    authorize: bool,
61}
62pub struct Router {
63    routes: Vec<Route>,
64    logger: Option<Arc<Logger>>,
65    pub(super) cors: Option<Cors>,
66    pub(super) credentials: Option<Credentials>,
67}
68
69impl Router {
70    pub fn new() -> Self {
71        Router {
72            routes: Vec::new(),
73            logger: None,
74            cors: None,
75            credentials: None,
76        }
77    }
78    pub fn with_logger(mut self, logger: Option<Arc<Logger>>) -> Self {
79        self.logger = logger;
80        self
81    }
82
83    pub fn with_cors(mut self, cors: Cors) -> Self {
84        self.cors = Some(cors);
85        self
86    }
87
88    pub fn with_credentials(mut self, password: &str, username: &str) -> Self {
89        self.credentials = Some(Credentials {
90            username: username.to_string(),
91            password: password.to_string(),
92        });
93        self
94    }
95
96    pub fn add_route<F>(&mut self, path: &str, method: HttpMethod, handler: F, authorize: bool)
97    where
98        F: Fn(Option<&str>, HashMap<&str, &str>) -> Result<HttpResponse, ApiError>
99            + Send
100            + Sync
101            + 'static,
102    {
103        let pattern = if path == "/*" {
104            "^.*$".to_string()
105        } else {
106            format!("^{}$", path.replace('{', "(?P<").replace('}', ">[^/]+)"))
107        };
108        let regex = Regex::new(&pattern).unwrap();
109        self.routes.push(Route {
110            pattern: regex,
111            handler: Box::new(handler),
112            method,
113            authorize,
114        });
115    }
116
117    pub fn route(
118        &self,
119        path: &str,
120        method: &str,
121        data: Option<&str>,
122        peer_addr: IpAddr,
123        headers: &HashMap<&str, &str>,
124    ) -> Result<HttpResponse, ApiError> {
125        let stripped_path: Vec<&str> = path.splitn(2, '?').collect();
126        if method == HttpMethod::OPTIONS.as_str() {
127            let mut response = HttpResponse::new(None, None, 204);
128            if let Some(cors) = &self.cors {
129                for (key, value) in &cors.headers {
130                    response = response.add_response_header(key, value);
131                }
132            }
133            Ok(response)
134        } else {
135            for route in &self.routes {
136                let pattern_match = route.pattern.captures(stripped_path[0]);
137
138                match pattern_match {
139                    Some(pattern_match) => {
140                        if route.method.as_str() != method {
141                            return Err(ApiError::new_with_json(405, "Method Not Allowed"));
142                        }
143                        if route.authorize {
144                            if let Some(credentials) = &self.credentials {
145                                if let Some(auth_header) = headers.get("Authorization") {
146                                    challenge_basic_auth(
147                                        auth_header,
148                                        &credentials.password,
149                                        &credentials.username,
150                                    )?;
151                                } else {
152                                    return Ok(HttpResponse::new(
153                                        Some(Body::Json(json!({"message": "Unauthorized"}))),
154                                        None,
155                                        401,
156                                    )
157                                    .add_response_header("WWW-Authenticate", "Basic"));
158                                }
159                            } else {
160                                return Err(ApiError::new_with_json(
161                                    500,
162                                    "Missing credentials configuration",
163                                ));
164                            }
165                        }
166                        let mut param_dict: HashMap<&str, &str> = route
167                            .pattern
168                            .capture_names()
169                            .flatten()
170                            .filter_map(|n| Some((n, pattern_match.name(n)?.as_str())))
171                            .collect();
172
173                        if stripped_path.len() == 2 {
174                            for param in stripped_path[1].split('&') {
175                                let pair: Vec<&str> = param.split('=').collect();
176                                if pair.len() == 2 {
177                                    param_dict.insert(pair[0], pair[1]);
178                                }
179                            }
180                        }
181                        let mut response =
182                            (route.handler)(data, param_dict).map_err(|mut err| {
183                                err.method = Some(method.to_string());
184                                err.path = Some(stripped_path[0].to_string());
185                                err
186                            })?;
187
188                        if let Some(cors) = &self.cors {
189                            for (key, value) in &cors.headers {
190                                response = response.add_response_header(key, value);
191                            }
192                        }
193
194                        self.log_response(
195                            response.status_code,
196                            stripped_path[0],
197                            method,
198                            peer_addr,
199                        )?;
200
201                        return Ok(response);
202                    }
203                    None => continue,
204                }
205            }
206            let error_response = HttpResponse::new(
207                Some(Body::Json(
208                    json!({"message": format!("No route found for path {}", path)}),
209                )),
210                None,
211                404,
212            );
213
214            self.log_response(
215                error_response.status_code,
216                stripped_path[0],
217                method,
218                peer_addr,
219            )?;
220
221            Ok(error_response)
222        }
223    }
224    pub fn log_response(
225        &self,
226        status_code: u16,
227        path: &str,
228        method: &str,
229        peer_addr: IpAddr,
230    ) -> Result<(), Box<dyn std::error::Error>> {
231        if let Some(logger) = &self.logger {
232            let time_string = chrono::offset::Local::now()
233                .format("%Y-%m-%d %H:%M:%S")
234                .to_string();
235            let status_code_color = get_status_code_color(status_code);
236
237            let args = vec![
238                (time_string, Some(Color::White)),
239                (peer_addr.to_string(), Some(Color::Rgb(255, 167, 7))),
240                (status_code.to_string(), Some(status_code_color)),
241                (method.to_string(), Some(Color::White)),
242                (path.to_string(), Some(Color::White)),
243            ];
244
245            logger.log_stdout("{} - {} - {} - {} {}", args)?;
246        }
247        Ok(())
248    }
249}
250
251impl Default for Router {
252    fn default() -> Self {
253        Self::new()
254    }
255}
256
/// A CORS policy: an ordered list of `Access-Control-*` response headers.
///
/// Built with the `with_*` methods; `Router` appends these pairs to outgoing responses.
pub struct Cors {
    headers: Vec<(String, String)>,
}

impl Cors {
    /// Returns an empty policy that adds no headers.
    pub fn new() -> Self {
        Self { headers: vec![] }
    }

    /// Appends the pair `name: value`; every public builder below delegates here.
    ///
    /// Calling a builder twice appends a second header with the same name.
    fn push_header(mut self, name: &str, value: &str) -> Self {
        let entry = (String::from(name), String::from(value));
        self.headers.push(entry);
        self
    }

    /// Adds `Access-Control-Allow-Origin: value`.
    pub fn with_origins(self, value: &str) -> Self {
        self.push_header("Access-Control-Allow-Origin", value)
    }

    /// Adds `Access-Control-Allow-Methods: value`.
    pub fn with_methods(self, value: &str) -> Self {
        self.push_header("Access-Control-Allow-Methods", value)
    }

    /// Adds `Access-Control-Allow-Headers: value`.
    pub fn with_headers(self, value: &str) -> Self {
        self.push_header("Access-Control-Allow-Headers", value)
    }

    /// Adds `Access-Control-Allow-Credentials: value`.
    pub fn with_credentials(self, value: &str) -> Self {
        self.push_header("Access-Control-Allow-Credentials", value)
    }
}

impl Default for Cors {
    /// Same as [`Cors::new`]: a policy with no headers.
    fn default() -> Self {
        Cors::new()
    }
}
304
305fn challenge_basic_auth(
306    auth_header: &str,
307    expectedd_passwd: &str,
308    expected_username: &str,
309) -> Result<(), ApiError> {
310    let auth_parts: Vec<&str> = auth_header.split_whitespace().collect();
311    let challenge_response = HttpResponse::new(
312        Some(Body::Json(json!({"message": "Unauthorized"}))),
313        None,
314        401,
315    )
316    .add_response_header("WWW-Authenticate", "Basic");
317    if auth_parts.len() != 2 {
318        let err = ApiError::new_with_custom(challenge_response);
319        return Err(err);
320    }
321    let auth_type = auth_parts[0];
322    let auth_value = auth_parts[1];
323    if auth_type != "Basic" {
324        return Err(ApiError::new_with_json(
325            401,
326            "Unauthorized - unsupported auth challenge",
327        ));
328    }
329    let decoded = BASE64_STANDARD.decode(auth_value).unwrap();
330    let decoded_str = String::from_utf8(decoded).unwrap();
331    let auth_parts: Vec<&str> = decoded_str.split(':').collect();
332    if auth_parts.len() != 2 {
333        return Err(ApiError::new_with_custom(challenge_response));
334    }
335    let username = auth_parts[0];
336    let password = auth_parts[1];
337
338    if (username != expected_username) || (password != expectedd_passwd) {
339        return Err(ApiError::new_with_custom(challenge_response));
340    }
341    Ok(())
342}