1use base64::{prelude::BASE64_STANDARD, Engine};
2use regex::Regex;
3use serde_json::json;
4use std::{collections::HashMap, net::IpAddr, sync::Arc};
5use termcolor::Color;
6
7use crate::{logger::Logger, ApiError, Body, HttpResponse};
8
/// HTTP request methods the router can dispatch on.
///
/// Variant names are spelled exactly like the method tokens they represent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HttpMethod {
    GET,
    POST,
    PUT,
    DELETE,
    PATCH,
    OPTIONS,
    HEAD,
    TRACE,
    CONNECT,
}
21
/// Username/password pair checked by HTTP Basic authentication.
pub struct Credentials {
    username: String,
    password: String,
}

// `Debug` is written by hand (not derived) so the password is never printed
// in logs or panic messages.
impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
26
27impl HttpMethod {
28 fn as_str(&self) -> &str {
29 match self {
30 HttpMethod::GET => "GET",
31 HttpMethod::POST => "POST",
32 HttpMethod::PUT => "PUT",
33 HttpMethod::DELETE => "DELETE",
34 HttpMethod::PATCH => "PATCH",
35 HttpMethod::OPTIONS => "OPTIONS",
36 HttpMethod::HEAD => "HEAD",
37 HttpMethod::TRACE => "TRACE",
38 HttpMethod::CONNECT => "CONNECT",
39 }
40 }
41}
42
43fn get_status_code_color(status_code: u16) -> Color {
44 match status_code {
45 100..=199 => Color::Cyan,
46 200..=299 => Color::Green,
47 300..=399 => Color::Yellow,
48 400..=499 => Color::Red,
49 _ => Color::Magenta,
50 }
51}
52
53type Handler =
54 Box<dyn Fn(Option<&str>, HashMap<&str, &str>) -> Result<HttpResponse, ApiError> + Send + Sync>;
55
/// One registered route: a path pattern, the method it serves, its handler,
/// and whether Basic authentication is required.
pub struct Route {
    // Anchored regex (`^...$`) compiled from the registered path; `{name}`
    // placeholders become named capture groups.
    pattern: Regex,
    // Callback invoked with the request body and the collected parameters.
    handler: Handler,
    // The HTTP method this route responds to.
    method: HttpMethod,
    // When true, the request must pass the Basic-auth check before the handler runs.
    authorize: bool,
}
/// Request dispatcher: matches an incoming path and method against the
/// registered routes and invokes the matching handler.
pub struct Router {
    // Registered routes, tried in registration order.
    routes: Vec<Route>,
    // Optional logger; when `None`, responses are not logged.
    logger: Option<Arc<Logger>>,
    // Optional CORS policy whose headers are added to outgoing responses.
    pub(super) cors: Option<Cors>,
    // Basic-auth credentials required by routes registered with `authorize = true`.
    pub(super) credentials: Option<Credentials>,
}
68
69impl Router {
70 pub fn new() -> Self {
71 Router {
72 routes: Vec::new(),
73 logger: None,
74 cors: None,
75 credentials: None,
76 }
77 }
78 pub fn with_logger(mut self, logger: Option<Arc<Logger>>) -> Self {
79 self.logger = logger;
80 self
81 }
82
83 pub fn with_cors(mut self, cors: Cors) -> Self {
84 self.cors = Some(cors);
85 self
86 }
87
88 pub fn with_credentials(mut self, password: &str, username: &str) -> Self {
89 self.credentials = Some(Credentials {
90 username: username.to_string(),
91 password: password.to_string(),
92 });
93 self
94 }
95
96 pub fn add_route<F>(&mut self, path: &str, method: HttpMethod, handler: F, authorize: bool)
97 where
98 F: Fn(Option<&str>, HashMap<&str, &str>) -> Result<HttpResponse, ApiError>
99 + Send
100 + Sync
101 + 'static,
102 {
103 let pattern = if path == "/*" {
104 "^.*$".to_string()
105 } else {
106 format!("^{}$", path.replace('{', "(?P<").replace('}', ">[^/]+)"))
107 };
108 let regex = Regex::new(&pattern).unwrap();
109 self.routes.push(Route {
110 pattern: regex,
111 handler: Box::new(handler),
112 method,
113 authorize,
114 });
115 }
116
117 pub fn route(
118 &self,
119 path: &str,
120 method: &str,
121 data: Option<&str>,
122 peer_addr: IpAddr,
123 headers: &HashMap<&str, &str>,
124 ) -> Result<HttpResponse, ApiError> {
125 let stripped_path: Vec<&str> = path.splitn(2, '?').collect();
126 if method == HttpMethod::OPTIONS.as_str() {
127 let mut response = HttpResponse::new(None, None, 204);
128 if let Some(cors) = &self.cors {
129 for (key, value) in &cors.headers {
130 response = response.add_response_header(key, value);
131 }
132 }
133 Ok(response)
134 } else {
135 for route in &self.routes {
136 let pattern_match = route.pattern.captures(stripped_path[0]);
137
138 match pattern_match {
139 Some(pattern_match) => {
140 if route.method.as_str() != method {
141 return Err(ApiError::new_with_json(405, "Method Not Allowed"));
142 }
143 if route.authorize {
144 if let Some(credentials) = &self.credentials {
145 if let Some(auth_header) = headers.get("Authorization") {
146 challenge_basic_auth(
147 auth_header,
148 &credentials.password,
149 &credentials.username,
150 )?;
151 } else {
152 return Ok(HttpResponse::new(
153 Some(Body::Json(json!({"message": "Unauthorized"}))),
154 None,
155 401,
156 )
157 .add_response_header("WWW-Authenticate", "Basic"));
158 }
159 } else {
160 return Err(ApiError::new_with_json(
161 500,
162 "Missing credentials configuration",
163 ));
164 }
165 }
166 let mut param_dict: HashMap<&str, &str> = route
167 .pattern
168 .capture_names()
169 .flatten()
170 .filter_map(|n| Some((n, pattern_match.name(n)?.as_str())))
171 .collect();
172
173 if stripped_path.len() == 2 {
174 for param in stripped_path[1].split('&') {
175 let pair: Vec<&str> = param.split('=').collect();
176 if pair.len() == 2 {
177 param_dict.insert(pair[0], pair[1]);
178 }
179 }
180 }
181 let mut response =
182 (route.handler)(data, param_dict).map_err(|mut err| {
183 err.method = Some(method.to_string());
184 err.path = Some(stripped_path[0].to_string());
185 err
186 })?;
187
188 if let Some(cors) = &self.cors {
189 for (key, value) in &cors.headers {
190 response = response.add_response_header(key, value);
191 }
192 }
193
194 self.log_response(
195 response.status_code,
196 stripped_path[0],
197 method,
198 peer_addr,
199 )?;
200
201 return Ok(response);
202 }
203 None => continue,
204 }
205 }
206 let error_response = HttpResponse::new(
207 Some(Body::Json(
208 json!({"message": format!("No route found for path {}", path)}),
209 )),
210 None,
211 404,
212 );
213
214 self.log_response(
215 error_response.status_code,
216 stripped_path[0],
217 method,
218 peer_addr,
219 )?;
220
221 Ok(error_response)
222 }
223 }
224 pub fn log_response(
225 &self,
226 status_code: u16,
227 path: &str,
228 method: &str,
229 peer_addr: IpAddr,
230 ) -> Result<(), Box<dyn std::error::Error>> {
231 if let Some(logger) = &self.logger {
232 let time_string = chrono::offset::Local::now()
233 .format("%Y-%m-%d %H:%M:%S")
234 .to_string();
235 let status_code_color = get_status_code_color(status_code);
236
237 let args = vec![
238 (time_string, Some(Color::White)),
239 (peer_addr.to_string(), Some(Color::Rgb(255, 167, 7))),
240 (status_code.to_string(), Some(status_code_color)),
241 (method.to_string(), Some(Color::White)),
242 (path.to_string(), Some(Color::White)),
243 ];
244
245 logger.log_stdout("{} - {} - {} - {} {}", args)?;
246 }
247 Ok(())
248 }
249}
250
251impl Default for Router {
252 fn default() -> Self {
253 Self::new()
254 }
255}
256
/// CORS response headers that the router copies onto outgoing responses.
///
/// Headers are kept in the order they were added. Calling a builder twice adds
/// the header twice.
#[derive(Debug, Clone, Default)]
pub struct Cors {
    headers: Vec<(String, String)>,
}

impl Cors {
    /// Creates a policy with no headers.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends an arbitrary header `name: value`.
    ///
    /// Use this for CORS headers without a dedicated builder, e.g.
    /// `Access-Control-Max-Age` or `Access-Control-Expose-Headers`.
    #[must_use]
    pub fn with_header(mut self, name: &str, value: &str) -> Self {
        self.headers.push((name.to_string(), value.to_string()));
        self
    }

    /// Appends `Access-Control-Allow-Origin: value`.
    #[must_use]
    pub fn with_origins(self, value: &str) -> Self {
        self.with_header("Access-Control-Allow-Origin", value)
    }

    /// Appends `Access-Control-Allow-Methods: value`.
    #[must_use]
    pub fn with_methods(self, value: &str) -> Self {
        self.with_header("Access-Control-Allow-Methods", value)
    }

    /// Appends `Access-Control-Allow-Headers: value`.
    #[must_use]
    pub fn with_headers(self, value: &str) -> Self {
        self.with_header("Access-Control-Allow-Headers", value)
    }

    /// Appends `Access-Control-Allow-Credentials: value`.
    #[must_use]
    pub fn with_credentials(self, value: &str) -> Self {
        self.with_header("Access-Control-Allow-Credentials", value)
    }
}
304
305fn challenge_basic_auth(
306 auth_header: &str,
307 expectedd_passwd: &str,
308 expected_username: &str,
309) -> Result<(), ApiError> {
310 let auth_parts: Vec<&str> = auth_header.split_whitespace().collect();
311 let challenge_response = HttpResponse::new(
312 Some(Body::Json(json!({"message": "Unauthorized"}))),
313 None,
314 401,
315 )
316 .add_response_header("WWW-Authenticate", "Basic");
317 if auth_parts.len() != 2 {
318 let err = ApiError::new_with_custom(challenge_response);
319 return Err(err);
320 }
321 let auth_type = auth_parts[0];
322 let auth_value = auth_parts[1];
323 if auth_type != "Basic" {
324 return Err(ApiError::new_with_json(
325 401,
326 "Unauthorized - unsupported auth challenge",
327 ));
328 }
329 let decoded = BASE64_STANDARD.decode(auth_value).unwrap();
330 let decoded_str = String::from_utf8(decoded).unwrap();
331 let auth_parts: Vec<&str> = decoded_str.split(':').collect();
332 if auth_parts.len() != 2 {
333 return Err(ApiError::new_with_custom(challenge_response));
334 }
335 let username = auth_parts[0];
336 let password = auth_parts[1];
337
338 if (username != expected_username) || (password != expectedd_passwd) {
339 return Err(ApiError::new_with_custom(challenge_response));
340 }
341 Ok(())
342}