tinyhttp_internal/
config.rs

1use std::{
2    collections::HashMap,
3    net::TcpStream,
4    ops::Deref,
5    sync::{Arc, OnceLock},
6};
7
8use crate::request::Request;
9pub use dyn_clone::DynClone;
10use std::fmt::Debug;
11
12#[cfg(feature = "ssl")]
13use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};
14
15use crate::response::Response;
16
17use rusty_pool::{Builder, ThreadPool};
18
19#[cfg(not(feature = "async"))]
20use std::net::{Incoming, TcpListener};
21
22#[cfg(not(feature = "async"))]
23use crate::http::start_http;
24
25#[cfg(feature = "async")]
26use tokio::net::TcpListener;
27
28#[cfg(feature = "async")]
29use crate::async_http::start_http;
30
31use std::sync::Mutex;
32
33#[cfg(test)]
34use std::any::Any;
35
36type RouteVec = Vec<Box<dyn Route>>;
37
38pub static PRE_MIDDLEWARE_CONST: OnceLock<Box<dyn FnMut(&mut Request) + Send + Sync>> =
39    OnceLock::new();
40
41pub static POST_MIDDLEWARE_CONST: OnceLock<Box<dyn FnMut(&mut Request) + Send + Sync>> =
42    OnceLock::new();
43
/// HTTP request method a [`Route`] is registered for.
///
/// Derives `PartialEq`/`Eq`/`Hash` so methods can be compared and used as map keys.
// Variant names are part of the public API, so the acronym-casing lint is silenced rather than
// renaming them.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Method {
    /// The `GET` method.
    GET,
    /// The `POST` method.
    POST,
}
49
50pub trait ToResponse: DynClone + Sync + Send {
51    fn to_res(&self, res: Request, sock: &mut TcpStream) -> Response;
52}
53
54pub trait Route: DynClone + Sync + Send + ToResponse {
55    fn get_path(&self) -> &str;
56    fn get_method(&self) -> Method;
57    fn wildcard(&self) -> Option<String>;
58    fn clone_dyn(&self) -> Box<dyn Route>;
59
60    #[cfg(test)]
61    fn any(&self) -> &dyn Any;
62}
63
64impl Clone for Box<dyn Route> {
65    fn clone(&self) -> Self {
66        self.clone_dyn()
67    }
68}
69
70pub struct HttpListener {
71    pub(crate) socket: TcpListener,
72    pub config: Config,
73    pub pool: ThreadPool,
74    pub use_pool: bool,
75    #[cfg(feature = "ssl")]
76    pub ssl_acpt: Option<Arc<SslAcceptor>>,
77}
78
79impl HttpListener {
80    pub fn new<P: Into<TcpListener>>(socket: P, config: Config) -> HttpListener {
81        #[cfg(feature = "log")]
82        log::debug!("Using {} threads", num_cpus::get());
83
84        if config.ssl {
85            #[cfg(feature = "ssl")]
86            let ssl_acpt = Some(build_https(
87                config.ssl_chain.clone().unwrap(),
88                config.ssl_priv.clone().unwrap(),
89            ));
90            HttpListener {
91                socket: socket.into(),
92                config,
93                pool: ThreadPool::default(),
94                #[cfg(feature = "ssl")]
95                ssl_acpt,
96                use_pool: true,
97            }
98        } else {
99            HttpListener {
100                socket: socket.into(),
101                config,
102                pool: ThreadPool::default(),
103                #[cfg(feature = "ssl")]
104                ssl_acpt: None,
105                use_pool: true,
106            }
107        }
108    }
109
110    pub fn threads(mut self, threads: usize) -> Self {
111        let pool = Builder::new().core_size(threads).build();
112
113        self.pool = pool;
114        self
115    }
116
117    pub fn use_tp(mut self, r: bool) -> Self {
118        self.use_pool = r;
119        self
120    }
121
122    #[cfg(not(feature = "async"))]
123    pub fn start(self) {
124        let conf_clone = self.config.clone();
125        start_http(self, conf_clone);
126    }
127
128    #[cfg(feature = "async")]
129    pub async fn start(self) {
130        start_http(self).await;
131    }
132
133    #[cfg(not(feature = "async"))]
134    pub fn get_stream(&self) -> Incoming<'_> {
135        self.socket.incoming()
136    }
137}
138
139#[derive(Clone)]
140pub struct Routes {
141    routes: RouteVec,
142}
143
144impl Routes {
145    pub fn new<R: Into<RouteVec>>(routes: R) -> Routes {
146        let routes = routes.into();
147        Routes { routes }
148    }
149
150    pub fn get_stream(self) -> RouteVec {
151        self.routes
152    }
153}
154
155#[derive(Clone)]
156pub struct Config {
157    mount_point: Option<String>,
158    get_routes: Option<HashMap<String, Box<dyn Route>>>,
159    post_routes: Option<HashMap<String, Box<dyn Route>>>,
160    debug: bool,
161    pub ssl: bool,
162    ssl_chain: Option<String>,
163    ssl_priv: Option<String>,
164    headers: Option<HashMap<String, String>>,
165    br: bool,
166    gzip: bool,
167    spa: bool,
168    http2: bool,
169    response_middleware: Option<Arc<Mutex<dyn FnMut(&mut Response) + Send + Sync>>>,
170    request_middleware: Option<Arc<Mutex<dyn FnMut(&mut Request) + Send + Sync>>>,
171}
172
173impl Default for Config {
174    fn default() -> Self {
175        Config::new()
176    }
177}
178
179impl Config {
180    /// Generates default settings (which don't work by itself)
181    ///
182    /// Chain with mount_point or routes
183    ///
184    /// ### Example:
185    /// ```ignore
186    /// use tinyhttp::prelude::*;
187    ///
188    /// #[get("/test")]
189    /// fn get_test() -> String {
190    ///   String::from("Hello, there!\n")
191    /// }
192    ///
193    /// let routes = Routes::new(vec![get_test()]);
194    /// let routes_config = Config::new().routes(routes);
195    /// /// or
196    /// let mount_config = Config::new().mount_point(".");
197    /// ```
198
199    pub fn new() -> Config {
200        //assert!(routes.len() > 0);
201
202        #[cfg(feature = "log")]
203        log::info!("tinyhttp version: {}", env!("CARGO_PKG_VERSION"));
204
205        Config {
206            mount_point: None,
207            get_routes: None,
208            post_routes: None,
209            debug: false,
210            ssl: false,
211            ssl_chain: None,
212            ssl_priv: None,
213            headers: None,
214            gzip: false,
215            br: false,
216            spa: false,
217            http2: false,
218            request_middleware: None,
219            response_middleware: None,
220        }
221    }
222
223    /// A mount point that will be searched when a request isn't defined with a get or post route
224    ///
225    /// ### Example:
226    /// ```ignore
227    /// let config = Config::new().mount_point(".")
228    /// /// if index.html exists in current directory, it will be returned if "/" or "/index.html" is requested.
229    /// ```
230
231    pub fn mount_point<P: Into<String>>(mut self, path: P) -> Self {
232        self.mount_point = Some(path.into());
233        self
234    }
235
236    /// Add routes with a Route member
237    ///
238    /// ### Example:
239    /// ```ignore
240    /// use tinyhttp::prelude::*;
241    ///
242    ///
243    /// #[get("/test")]
244    /// fn get_test() -> &'static str {
245    ///   "Hello, World!"
246    /// }
247    ///
248    /// #[post("/test")]
249    /// fn post_test() -> Vec<u8> {
250    ///   b"Hello, Post!".to_vec()
251    /// }
252    ///
253    /// fn main() {
254    ///   let socket = TcpListener::new(":::80").unwrap();
255    ///   let routes = Routes::new(vec![get_test(), post_test()]);
256    ///   let config = Config::new().routes(routes);
257    ///   let http = HttpListener::new(socket, config);
258    ///
259    ///   http.start();
260    /// }
261    /// ```
262
263    pub fn routes(mut self, routes: Routes) -> Self {
264        let mut get_routes = HashMap::new();
265        let mut post_routes = HashMap::new();
266        let routes = routes.get_stream();
267
268        for route in routes {
269            match route.get_method() {
270                Method::GET => {
271                    #[cfg(feature = "log")]
272                    log::info!("GET Route init!: {}", &route.get_path());
273
274                    get_routes.insert(route.get_path().to_string(), route);
275                }
276                Method::POST => {
277                    #[cfg(feature = "log")]
278                    log::info!("POST Route init!: {}", &route.get_path());
279                    post_routes.insert(route.get_path().to_string(), route);
280                }
281            }
282        }
283        if !get_routes.is_empty() {
284            self.get_routes = Some(get_routes);
285        } else {
286            self.get_routes = None;
287        }
288
289        if !post_routes.is_empty() {
290            self.post_routes = Some(post_routes);
291        } else {
292            self.post_routes = None;
293        }
294
295        self
296    }
297
298    /// Enables SSL
299    ///
300    /// ### Example:
301    /// ```ignore
302    /// let config = Config::new().ssl("./fullchain.pem", "./privkey.pem");
303    /// ```
304    /// This will only accept HTTPS connections
305
306    pub fn ssl(mut self, ssl_chain: String, ssl_priv: String) -> Self {
307        self.ssl_chain = Some(ssl_chain);
308        self.ssl_priv = Some(ssl_priv);
309        self.ssl = true;
310        self
311    }
312    pub fn debug(mut self) -> Self {
313        self.debug = true;
314        self
315    }
316
317    /// Define custom headers
318    ///
319    /// ```ignore
320    /// let config = Config::new().headers(vec!["Access-Control-Allow-Origin: *".into()]);
321    /// ```
322    pub fn headers(mut self, headers: Vec<String>) -> Self {
323        let mut hash_map: HashMap<String, String> = HashMap::new();
324        for i in headers {
325            let mut split = i.split_inclusive(": ");
326            hash_map.insert(
327                split.next().unwrap().to_string(),
328                split.next().unwrap().to_string() + "\r\n",
329            );
330        }
331
332        self.headers = Some(hash_map);
333        self
334    }
335
336    /// DOES NOT WORK!
337    /// Enables brotli compression
338    pub fn br(mut self, res: bool) -> Self {
339        self.br = res;
340        self
341    }
342
343    pub fn spa(mut self, res: bool) -> Self {
344        self.spa = res;
345        self
346    }
347
348    /// Enables gzip compression
349    pub fn gzip(mut self, res: bool) -> Self {
350        self.gzip = res;
351        self
352    }
353
354    pub fn http2(mut self, res: bool) -> Self {
355        self.http2 = res;
356        self
357    }
358
359    pub fn request_middleware<F: FnMut(&mut Request) + Send + Sync + 'static>(
360        mut self,
361        middleware_fn: F,
362    ) -> Self {
363        self.request_middleware = Some(Arc::new(Mutex::new(middleware_fn)));
364        self
365    }
366
367    pub fn response_middleware<F: FnMut(&mut Response) + Send + Sync + 'static>(
368        mut self,
369        middleware_fn: F,
370    ) -> Self {
371        self.response_middleware = Some(Arc::new(Mutex::new(middleware_fn)));
372        self
373    }
374
375    pub fn get_headers(&self) -> Option<&HashMap<String, String>> {
376        self.headers.as_ref()
377    }
378    pub fn get_br(&self) -> bool {
379        self.br
380    }
381    pub fn get_gzip(&self) -> bool {
382        self.gzip
383    }
384    pub fn get_debug(&self) -> bool {
385        self.debug
386    }
387    pub fn get_mount(&self) -> Option<&String> {
388        self.mount_point.as_ref()
389    }
390    pub fn get_routes(&self, req_path: &str) -> Option<&dyn Route> {
391        let req_path = if req_path.ends_with('/') && req_path.matches('/').count() > 1 {
392            let mut chars = req_path.chars();
393            chars.next_back();
394            chars.as_str()
395        } else {
396            req_path
397        };
398
399        #[cfg(feature = "log")]
400        log::trace!("get_routes -> new_path: {}", &req_path);
401
402        let routes = self.get_routes.as_ref()?;
403
404        if let Some(route) = routes.get(req_path) {
405            return Some(route.deref());
406        }
407
408        if let Some((_, wildcard_route)) = routes
409            .iter()
410            .find(|(path, route)| req_path.starts_with(*path) && route.wildcard().is_some())
411        {
412            return Some(wildcard_route.deref());
413        }
414
415        None
416    }
417
418    pub fn post_routes(&self, req_path: &str) -> Option<&dyn Route> {
419        #[cfg(feature = "log")]
420        log::trace!("post_routes -> path: {}", req_path);
421
422        let req_path = if req_path.ends_with('/') && req_path.matches('/').count() > 1 {
423            let mut chars = req_path.chars();
424            chars.next_back();
425            chars.as_str()
426        } else {
427            req_path
428        };
429
430        #[cfg(feature = "log")]
431        log::trace!("get_routes -> new_path: {}", &req_path);
432
433        let routes = self.post_routes.as_ref()?;
434
435        if let Some(route) = routes.get(req_path) {
436            return Some(route.deref());
437        }
438
439        if let Some((_, wildcard_route)) = routes
440            .iter()
441            .find(|(path, route)| req_path.starts_with(*path) && route.wildcard().is_some())
442        {
443            return Some(wildcard_route.deref());
444        }
445
446        None
447    }
448
449    pub fn get_spa(&self) -> bool {
450        self.spa
451    }
452
453    #[allow(dead_code)]
454    pub(crate) fn get_request_middleware(
455        &self,
456    ) -> Option<Arc<Mutex<dyn FnMut(&mut Request) + Send + Sync>>> {
457        if let Some(s) = &self.request_middleware {
458            Some(Arc::clone(s))
459        } else {
460            None
461        }
462    }
463
464    #[allow(dead_code)]
465    pub(crate) fn get_response_middleware(
466        &self,
467    ) -> Option<Arc<Mutex<dyn FnMut(&mut Response) + Send + Sync>>> {
468        if let Some(s) = &self.response_middleware {
469            Some(Arc::clone(s))
470        } else {
471            None
472        }
473    }
474}
475
/// Builds a TLS acceptor using OpenSSL's Mozilla "modern" (v5) profile.
///
/// * `chain` - path to the certificate chain file.
/// * `private` - path to the PEM-encoded private key.
///
/// # Panics
/// Panics, naming the offending file and the OpenSSL error, if:
/// * the acceptor cannot be created,
/// * either file cannot be loaded, or
/// * the key does not match the certificate.
#[cfg(feature = "ssl")]
pub fn build_https(chain: String, private: String) -> Arc<SslAcceptor> {
    let mut acceptor = SslAcceptor::mozilla_modern_v5(SslMethod::tls())
        .unwrap_or_else(|e| panic!("failed to create TLS acceptor: {e}"));
    // Pass borrowed paths so they remain available for the error messages below.
    acceptor
        .set_certificate_chain_file(&chain)
        .unwrap_or_else(|e| panic!("failed to load certificate chain {chain:?}: {e}"));
    acceptor
        .set_private_key_file(&private, SslFiletype::PEM)
        .unwrap_or_else(|e| panic!("failed to load private key {private:?}: {e}"));
    acceptor.check_private_key().unwrap_or_else(|e| {
        panic!("private key {private:?} does not match certificate chain {chain:?}: {e}")
    });
    Arc::new(acceptor.build())
}