1use std::{
2 collections::HashMap,
3 net::TcpStream,
4 ops::Deref,
5 sync::{Arc, OnceLock},
6};
7
8use crate::request::Request;
9pub use dyn_clone::DynClone;
10use std::fmt::Debug;
11
12#[cfg(feature = "ssl")]
13use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};
14
15use crate::response::Response;
16
17use rusty_pool::{Builder, ThreadPool};
18
19#[cfg(not(feature = "async"))]
20use std::net::{Incoming, TcpListener};
21
22#[cfg(not(feature = "async"))]
23use crate::http::start_http;
24
25#[cfg(feature = "async")]
26use tokio::net::TcpListener;
27
28#[cfg(feature = "async")]
29use crate::async_http::start_http;
30
31use std::sync::Mutex;
32
33#[cfg(test)]
34use std::any::Any;
35
36type RouteVec = Vec<Box<dyn Route>>;
37
38pub static PRE_MIDDLEWARE_CONST: OnceLock<Box<dyn FnMut(&mut Request) + Send + Sync>> =
39 OnceLock::new();
40
41pub static POST_MIDDLEWARE_CONST: OnceLock<Box<dyn FnMut(&mut Request) + Send + Sync>> =
42 OnceLock::new();
43
/// HTTP request methods a route can be registered for.
///
/// Equality and hashing are derived so callers can compare methods directly
/// or use them as map/set keys instead of matching on every variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Method {
    /// The `GET` method.
    GET,
    /// The `POST` method.
    POST,
}
49
50pub trait ToResponse: DynClone + Sync + Send {
51 fn to_res(&self, res: Request, sock: &mut TcpStream) -> Response;
52}
53
54pub trait Route: DynClone + Sync + Send + ToResponse {
55 fn get_path(&self) -> &str;
56 fn get_method(&self) -> Method;
57 fn wildcard(&self) -> Option<String>;
58 fn clone_dyn(&self) -> Box<dyn Route>;
59
60 #[cfg(test)]
61 fn any(&self) -> &dyn Any;
62}
63
64impl Clone for Box<dyn Route> {
65 fn clone(&self) -> Self {
66 self.clone_dyn()
67 }
68}
69
70pub struct HttpListener {
71 pub(crate) socket: TcpListener,
72 pub config: Config,
73 pub pool: ThreadPool,
74 pub use_pool: bool,
75 #[cfg(feature = "ssl")]
76 pub ssl_acpt: Option<Arc<SslAcceptor>>,
77}
78
79impl HttpListener {
80 pub fn new<P: Into<TcpListener>>(socket: P, config: Config) -> HttpListener {
81 #[cfg(feature = "log")]
82 log::debug!("Using {} threads", num_cpus::get());
83
84 if config.ssl {
85 #[cfg(feature = "ssl")]
86 let ssl_acpt = Some(build_https(
87 config.ssl_chain.clone().unwrap(),
88 config.ssl_priv.clone().unwrap(),
89 ));
90 HttpListener {
91 socket: socket.into(),
92 config,
93 pool: ThreadPool::default(),
94 #[cfg(feature = "ssl")]
95 ssl_acpt,
96 use_pool: true,
97 }
98 } else {
99 HttpListener {
100 socket: socket.into(),
101 config,
102 pool: ThreadPool::default(),
103 #[cfg(feature = "ssl")]
104 ssl_acpt: None,
105 use_pool: true,
106 }
107 }
108 }
109
110 pub fn threads(mut self, threads: usize) -> Self {
111 let pool = Builder::new().core_size(threads).build();
112
113 self.pool = pool;
114 self
115 }
116
117 pub fn use_tp(mut self, r: bool) -> Self {
118 self.use_pool = r;
119 self
120 }
121
122 #[cfg(not(feature = "async"))]
123 pub fn start(self) {
124 let conf_clone = self.config.clone();
125 start_http(self, conf_clone);
126 }
127
128 #[cfg(feature = "async")]
129 pub async fn start(self) {
130 start_http(self).await;
131 }
132
133 #[cfg(not(feature = "async"))]
134 pub fn get_stream(&self) -> Incoming<'_> {
135 self.socket.incoming()
136 }
137}
138
139#[derive(Clone)]
140pub struct Routes {
141 routes: RouteVec,
142}
143
144impl Routes {
145 pub fn new<R: Into<RouteVec>>(routes: R) -> Routes {
146 let routes = routes.into();
147 Routes { routes }
148 }
149
150 pub fn get_stream(self) -> RouteVec {
151 self.routes
152 }
153}
154
155#[derive(Clone)]
156pub struct Config {
157 mount_point: Option<String>,
158 get_routes: Option<HashMap<String, Box<dyn Route>>>,
159 post_routes: Option<HashMap<String, Box<dyn Route>>>,
160 debug: bool,
161 pub ssl: bool,
162 ssl_chain: Option<String>,
163 ssl_priv: Option<String>,
164 headers: Option<HashMap<String, String>>,
165 br: bool,
166 gzip: bool,
167 spa: bool,
168 http2: bool,
169 response_middleware: Option<Arc<Mutex<dyn FnMut(&mut Response) + Send + Sync>>>,
170 request_middleware: Option<Arc<Mutex<dyn FnMut(&mut Request) + Send + Sync>>>,
171}
172
173impl Default for Config {
174 fn default() -> Self {
175 Config::new()
176 }
177}
178
179impl Config {
180 pub fn new() -> Config {
200 #[cfg(feature = "log")]
203 log::info!("tinyhttp version: {}", env!("CARGO_PKG_VERSION"));
204
205 Config {
206 mount_point: None,
207 get_routes: None,
208 post_routes: None,
209 debug: false,
210 ssl: false,
211 ssl_chain: None,
212 ssl_priv: None,
213 headers: None,
214 gzip: false,
215 br: false,
216 spa: false,
217 http2: false,
218 request_middleware: None,
219 response_middleware: None,
220 }
221 }
222
223 pub fn mount_point<P: Into<String>>(mut self, path: P) -> Self {
232 self.mount_point = Some(path.into());
233 self
234 }
235
236 pub fn routes(mut self, routes: Routes) -> Self {
264 let mut get_routes = HashMap::new();
265 let mut post_routes = HashMap::new();
266 let routes = routes.get_stream();
267
268 for route in routes {
269 match route.get_method() {
270 Method::GET => {
271 #[cfg(feature = "log")]
272 log::info!("GET Route init!: {}", &route.get_path());
273
274 get_routes.insert(route.get_path().to_string(), route);
275 }
276 Method::POST => {
277 #[cfg(feature = "log")]
278 log::info!("POST Route init!: {}", &route.get_path());
279 post_routes.insert(route.get_path().to_string(), route);
280 }
281 }
282 }
283 if !get_routes.is_empty() {
284 self.get_routes = Some(get_routes);
285 } else {
286 self.get_routes = None;
287 }
288
289 if !post_routes.is_empty() {
290 self.post_routes = Some(post_routes);
291 } else {
292 self.post_routes = None;
293 }
294
295 self
296 }
297
298 pub fn ssl(mut self, ssl_chain: String, ssl_priv: String) -> Self {
307 self.ssl_chain = Some(ssl_chain);
308 self.ssl_priv = Some(ssl_priv);
309 self.ssl = true;
310 self
311 }
312 pub fn debug(mut self) -> Self {
313 self.debug = true;
314 self
315 }
316
317 pub fn headers(mut self, headers: Vec<String>) -> Self {
323 let mut hash_map: HashMap<String, String> = HashMap::new();
324 for i in headers {
325 let mut split = i.split_inclusive(": ");
326 hash_map.insert(
327 split.next().unwrap().to_string(),
328 split.next().unwrap().to_string() + "\r\n",
329 );
330 }
331
332 self.headers = Some(hash_map);
333 self
334 }
335
336 pub fn br(mut self, res: bool) -> Self {
339 self.br = res;
340 self
341 }
342
343 pub fn spa(mut self, res: bool) -> Self {
344 self.spa = res;
345 self
346 }
347
348 pub fn gzip(mut self, res: bool) -> Self {
350 self.gzip = res;
351 self
352 }
353
354 pub fn http2(mut self, res: bool) -> Self {
355 self.http2 = res;
356 self
357 }
358
359 pub fn request_middleware<F: FnMut(&mut Request) + Send + Sync + 'static>(
360 mut self,
361 middleware_fn: F,
362 ) -> Self {
363 self.request_middleware = Some(Arc::new(Mutex::new(middleware_fn)));
364 self
365 }
366
367 pub fn response_middleware<F: FnMut(&mut Response) + Send + Sync + 'static>(
368 mut self,
369 middleware_fn: F,
370 ) -> Self {
371 self.response_middleware = Some(Arc::new(Mutex::new(middleware_fn)));
372 self
373 }
374
375 pub fn get_headers(&self) -> Option<&HashMap<String, String>> {
376 self.headers.as_ref()
377 }
378 pub fn get_br(&self) -> bool {
379 self.br
380 }
381 pub fn get_gzip(&self) -> bool {
382 self.gzip
383 }
384 pub fn get_debug(&self) -> bool {
385 self.debug
386 }
387 pub fn get_mount(&self) -> Option<&String> {
388 self.mount_point.as_ref()
389 }
390 pub fn get_routes(&self, req_path: &str) -> Option<&dyn Route> {
391 let req_path = if req_path.ends_with('/') && req_path.matches('/').count() > 1 {
392 let mut chars = req_path.chars();
393 chars.next_back();
394 chars.as_str()
395 } else {
396 req_path
397 };
398
399 #[cfg(feature = "log")]
400 log::trace!("get_routes -> new_path: {}", &req_path);
401
402 let routes = self.get_routes.as_ref()?;
403
404 if let Some(route) = routes.get(req_path) {
405 return Some(route.deref());
406 }
407
408 if let Some((_, wildcard_route)) = routes
409 .iter()
410 .find(|(path, route)| req_path.starts_with(*path) && route.wildcard().is_some())
411 {
412 return Some(wildcard_route.deref());
413 }
414
415 None
416 }
417
418 pub fn post_routes(&self, req_path: &str) -> Option<&dyn Route> {
419 #[cfg(feature = "log")]
420 log::trace!("post_routes -> path: {}", req_path);
421
422 let req_path = if req_path.ends_with('/') && req_path.matches('/').count() > 1 {
423 let mut chars = req_path.chars();
424 chars.next_back();
425 chars.as_str()
426 } else {
427 req_path
428 };
429
430 #[cfg(feature = "log")]
431 log::trace!("get_routes -> new_path: {}", &req_path);
432
433 let routes = self.post_routes.as_ref()?;
434
435 if let Some(route) = routes.get(req_path) {
436 return Some(route.deref());
437 }
438
439 if let Some((_, wildcard_route)) = routes
440 .iter()
441 .find(|(path, route)| req_path.starts_with(*path) && route.wildcard().is_some())
442 {
443 return Some(wildcard_route.deref());
444 }
445
446 None
447 }
448
449 pub fn get_spa(&self) -> bool {
450 self.spa
451 }
452
453 #[allow(dead_code)]
454 pub(crate) fn get_request_middleware(
455 &self,
456 ) -> Option<Arc<Mutex<dyn FnMut(&mut Request) + Send + Sync>>> {
457 if let Some(s) = &self.request_middleware {
458 Some(Arc::clone(s))
459 } else {
460 None
461 }
462 }
463
464 #[allow(dead_code)]
465 pub(crate) fn get_response_middleware(
466 &self,
467 ) -> Option<Arc<Mutex<dyn FnMut(&mut Response) + Send + Sync>>> {
468 if let Some(s) = &self.response_middleware {
469 Some(Arc::clone(s))
470 } else {
471 None
472 }
473 }
474}
475
/// Builds a shared TLS acceptor from a PEM certificate chain file and a PEM
/// private key file, using OpenSSL's Mozilla "modern" (v5) settings.
///
/// # Panics
///
/// The signature is infallible, so this panics with a message naming the
/// offending file and the OpenSSL error in any of these cases:
/// - the acceptor cannot be created;
/// - either file cannot be loaded;
/// - the key does not match the certificate chain.
#[cfg(feature = "ssl")]
pub fn build_https(chain: String, private: String) -> Arc<SslAcceptor> {
    let mut acceptor = SslAcceptor::mozilla_modern_v5(SslMethod::tls())
        .unwrap_or_else(|err| panic!("failed to create TLS acceptor: {err}"));
    acceptor
        .set_certificate_chain_file(&chain)
        .unwrap_or_else(|err| panic!("failed to load TLS certificate chain {chain:?}: {err}"));
    acceptor
        .set_private_key_file(&private, SslFiletype::PEM)
        .unwrap_or_else(|err| panic!("failed to load TLS private key {private:?}: {err}"));
    acceptor.check_private_key().unwrap_or_else(|err| {
        panic!("TLS private key {private:?} does not match certificate chain {chain:?}: {err}")
    });
    Arc::new(acceptor.build())
}