scratch_server/
lib.rs

1use api_error::ApiError;
2use http_parse_error::HttpParseError;
3use include_dir::{include_dir, Dir};
4use logger::Logger;
5use native_tls::{Identity, TlsAcceptor};
6use std::borrow::Cow;
7use std::collections::HashMap;
8use std::fs::{self, File};
9use std::io::{self, BufRead, BufReader, Read, Write};
10use std::net::{IpAddr, SocketAddr, TcpListener, TcpStream};
11use std::path::PathBuf;
12use std::sync::Arc;
13use termcolor::Color;
14use utils::get_option;
15
16mod errors;
17mod http_response;
18mod logger;
19mod router;
20mod thread_pool;
21mod utils;
22
23pub use errors::*;
24pub use http_response::*;
25pub use router::*;
26
27pub static STATIC_FILES: Dir<'_> = include_dir!("src/dist");
28
/// A readable + writable byte stream that can be boxed (`Box<dyn ReadWrite>`) and moved to
/// another thread, letting plain TCP and TLS connections share one type.
pub trait ReadWrite: Read + Write + Send + 'static {}

/// Every owned, sendable type that can both read and write is automatically a `ReadWrite`.
impl<S> ReadWrite for S where S: Read + Write + Send + 'static {}
32
33struct NetworkStream {
34    delegate: Option<Box<dyn ReadWrite>>,
35    tls_acceptor: Option<TlsAcceptor>,
36}
37
38impl NetworkStream {
39    pub fn new(
40        cert_path: Option<&PathBuf>,
41        cert_pass: Option<&String>,
42    ) -> Result<NetworkStream, Box<dyn std::error::Error>> {
43        match &cert_path {
44            Some(path) => {
45                let identity_bytes = fs::read(path)?;
46
47                let identity = Identity::from_pkcs12(&identity_bytes, cert_pass.unwrap())?;
48
49                let tls_acceptor = TlsAcceptor::new(identity)?;
50
51                Ok(NetworkStream {
52                    tls_acceptor: Some(tls_acceptor),
53                    delegate: None,
54                })
55            }
56            None => Ok(NetworkStream {
57                tls_acceptor: None,
58                delegate: None,
59            }),
60        }
61    }
62    pub fn get_stream(
63        &mut self,
64        stream: TcpStream,
65    ) -> Result<&mut NetworkStream, Box<dyn std::error::Error>> {
66        match &self.tls_acceptor {
67            Some(acceptor) => {
68                let tls_stream = acceptor.accept(stream)?;
69                self.delegate = Some(Box::new(tls_stream));
70                Ok(self)
71            }
72            None => {
73                self.delegate = Some(Box::new(stream));
74                Ok(self)
75            }
76        }
77    }
78}
79
80pub struct HttpServer {
81    port: u16,
82    threads: usize,
83    cert_path: Option<PathBuf>,
84    cert_pass: Option<String>,
85    router: Router,
86    logger: Option<Arc<Logger>>,
87    bind_address: IpAddr,
88    compression: bool,
89}
90
91impl HttpServer {
92    pub fn build(
93        port: u16,
94        threads: usize,
95        cert_path: Option<PathBuf>,
96        cert_pass: Option<String>,
97        bind_address: IpAddr,
98        compression: bool,
99    ) -> HttpServer {
100        HttpServer {
101            port,
102            threads,
103            cert_path,
104            cert_pass,
105            router: Router::new(),
106            logger: None,
107            bind_address,
108            compression,
109        }
110    }
111    pub fn with_logger(mut self) -> Self {
112        self.logger = Some(Arc::new(Logger::new()));
113        self.router = self
114            .router
115            .with_logger(Some(Arc::clone(self.logger.as_ref().unwrap())));
116        self
117    }
118
119    pub fn with_credentials(mut self, password: &str, username: &str) -> Self {
120        self.router = self.router.with_credentials(username, password);
121        self
122    }
123
124    pub fn add_routes<F>(mut self, routes: F) -> Self
125    where
126        F: Fn(&mut Router) + Send + Sync + 'static,
127    {
128        routes(&mut self.router);
129        self
130    }
131
132    pub fn with_cors_policy(mut self, policy: Cors) -> Self {
133        self.router = self.router.with_cors(policy);
134        self
135    }
136    pub fn run(self) -> Result<(), Box<dyn std::error::Error>> {
137        self.print_server_info();
138        let listener = TcpListener::bind(SocketAddr::from((self.bind_address, self.port)))?;
139        let pool = thread_pool::ThreadPool::build(self.threads)?;
140
141        let arc_router = Arc::new(self.router);
142        let mut network_stream =
143            NetworkStream::new(self.cert_path.as_ref(), self.cert_pass.as_ref())?;
144        for stream in listener.incoming() {
145            let stream = stream?;
146            let peer_addr = stream.peer_addr()?;
147            let Ok(stream) = network_stream.get_stream(stream) else {
148                continue;
149            };
150            let mut stream = stream.delegate.take().unwrap();
151
152            let router_clone = Arc::clone(&arc_router);
153            let logger_clone = self.logger.clone();
154
155            pool.execute(move || {
156                handle_connection(&mut stream, &router_clone, peer_addr.ip())
157                    .unwrap_or_else(|err| {
158                        if let (Some(method), Some(path)) = (&err.method, &err.path) {
159                            router_clone
160                                .log_response(
161                                    err.error_response.status_code,
162                                    path,
163                                    method,
164                                    peer_addr.ip(),
165                                )
166                                .unwrap();
167                        }
168
169                        err.error_response
170                    })
171                    .write_response(&mut stream, self.compression)
172                    .unwrap_or_else(|err| {
173                        if let Some(logger) = logger_clone {
174                            logger
175                                .log_stderr("Error: {}", vec![(err.to_string(), Some(Color::Red))])
176                                .unwrap();
177                        }
178                    });
179            })?;
180        }
181        Ok(())
182    }
183    fn print_server_info(&self) {
184        if let Some(logger) = &self.logger {
185            logger.log_stdout(
186                r#"
187
188 ========================================================================================================
189 
190   _____ _                 _        _    _ _______ _______ _____     _____                          
191  / ____(_)               | |      | |  | |__   __|__   __|  __ \   / ____|                         
192 | (___  _ _ __ ___  _ __ | | ___  | |__| |  | |     | |  | |__) | | (___   ___ _ ____   _____ _ __ 
193  \___ \| | '_ ` _ \| '_ \| |/ _ \ |  __  |  | |     | |  |  ___/   \___ \ / _ \ '__\ \ / / _ \ '__|
194  ____) | | | | | | | |_) | |  __/ | |  | |  | |     | |  | |       ____) |  __/ |   \ V /  __/ |   
195 |_____/|_|_| |_| |_| .__/|_|\___| |_|  |_|  |_|     |_|  |_|      |_____/ \___|_|    \_/ \___|_|   
196                    | |                                                                             
197                    |_|                                                                             
198
199=========================================================================================================
200
201Port: {}
202Threads: {}
203HTTPS: {}
204CORS: {}
205Auth: {}
206Compression: {}
207
208====================
209Logs:"#,
210                vec![
211                    (self.port.to_string(), Some(Color::Blue)),
212                    (self.threads.to_string(), Some(Color::Blue)),
213                    get_option(&self.cert_path),
214                    get_option(&self.router.cors),
215                    get_option(&self.router.credentials),
216                    if self.compression { ("Enabled".to_string(), Some(Color::Green)) } else { ("Disabled".to_string(), Some(Color::Yellow)) },
217                ],
218            )
219            .unwrap();
220        }
221    }
222}
223
224fn parse_http<'a>(
225    reader: &mut BufReader<&mut Box<dyn ReadWrite>>,
226    request_string: &'a mut String,
227) -> Result<(&'a str, &'a str, HashMap<&'a str, &'a str>), HttpParseError> {
228    loop {
229        let mut line = String::new();
230        reader.read_line(&mut line)?;
231        request_string.push_str(&line);
232        if line == "\r\n" {
233            break;
234        }
235    }
236    let http_parts: Vec<&str> = request_string.split("\r\n\r\n").collect();
237    let request_lines: Vec<&str> = http_parts
238        .first()
239        .ok_or(HttpParseError::default())?
240        .lines()
241        .collect();
242
243    let http_method: Vec<&str> = request_lines
244        .first()
245        .ok_or(HttpParseError::default())?
246        .split_whitespace()
247        .collect();
248
249    if http_method.len() < 3 {
250        return Err(HttpParseError::default());
251    }
252
253    let (method, path, _version) = (http_method[0], http_method[1], http_method[2]);
254
255    let mut headers = std::collections::HashMap::new();
256    for line in &request_lines[1..] {
257        let parts: Vec<&str> = line.splitn(2, ':').collect();
258        if parts.len() == 2 {
259            headers.insert(
260                *parts.first().ok_or(HttpParseError::default())?,
261                parts.get(1).ok_or(HttpParseError::default())?.trim(),
262            );
263        }
264    }
265
266    Ok((method, path, headers))
267}
268
269fn handle_connection(
270    stream: &mut Box<dyn ReadWrite>,
271    router: &Arc<Router>,
272    peer_addr: IpAddr,
273) -> Result<HttpResponse, ApiError> {
274    let mut reader = BufReader::new(&mut *stream);
275
276    let mut request = String::new();
277    let (method, path, headers) = parse_http(&mut reader, &mut request)?;
278
279    let mut buffer = Vec::new();
280    let body;
281
282    match headers.get("Content-Type") {
283        Some(content_type) => {
284            if content_type.contains("multipart/form-data") {
285                let path = headers.get("Path").unwrap();
286                let response =
287                    handle_multipart_file_upload(content_type, &headers, &mut reader, path)
288                        .map_err(|err| {
289                            ApiError::new_with_html(400, &format!("File upload error: {}", err))
290                        })?;
291                return Ok(response);
292            } else {
293                body = parse_body(&headers, reader, &mut buffer)?;
294            }
295        }
296        None => {
297            body = parse_body(&headers, reader, &mut buffer)?;
298        }
299    }
300
301    let response = router.route(path, method, body.as_deref(), peer_addr, &headers)?;
302    Ok(response)
303}
304
305fn parse_body<'a>(
306    headers: &HashMap<&str, &str>,
307    reader: BufReader<&mut Box<dyn ReadWrite>>,
308    buffer: &'a mut Vec<u8>,
309) -> Result<Option<Cow<'a, str>>, Box<dyn std::error::Error>> {
310    match headers.get("Content-Length") {
311        Some(content_length) => {
312            let content_length = content_length.parse::<usize>()?;
313            let mut body_reader = reader.take(content_length.try_into()?);
314            body_reader.read_to_end(buffer)?;
315            let body = String::from_utf8_lossy(&buffer[..]);
316            Ok(Some(body))
317        }
318        None => Ok(None),
319    }
320}
321
322fn handle_multipart_file_upload(
323    content_type: &str,
324    headers: &HashMap<&str, &str>,
325    reader: &mut BufReader<&mut Box<dyn ReadWrite>>,
326    path: &str,
327) -> Result<HttpResponse, Box<dyn std::error::Error>> {
328    let idx = content_type
329        .find("boundary=")
330        .ok_or("Missing multipart boundary")?;
331    let boundary = format!("--{}", &content_type[(idx + "boundary=".len())..]);
332    let mut multipart_headers = HashMap::new();
333    let mut header_size = 0;
334
335    //read headers
336    loop {
337        let mut line = String::new();
338        header_size += reader.read_line(&mut line)?;
339        if line.trim() == boundary {
340            continue;
341        }
342        if line == "\r\n" {
343            break;
344        }
345
346        let parts: Vec<&str> = line.trim().split(':').map(|s| s.trim()).collect();
347        if parts.len() < 2 {
348            return Err("Error parsing multipart request".into());
349        }
350        multipart_headers.insert(parts[0].to_owned(), parts[1].to_owned());
351    }
352
353    //get file name from content disposition and form target path
354    let content_disposition = multipart_headers
355        .get("Content-Disposition")
356        .ok_or("Missing content disposition")?;
357    let filename = content_disposition
358        .split("filename=\"")
359        .nth(1)
360        .and_then(|s| s.split('\"').next())
361        .ok_or("Error parsing file name")?;
362    let mut target_path = PathBuf::from("./").canonicalize()?.join(path);
363    target_path.push(filename);
364
365    let current_dir = std::env::current_dir()?;
366    if !target_path.starts_with(current_dir) {
367        return Err("Only paths relative to the current directory are allowed".into());
368    }
369
370    //calculate file size based on whole content length so that reading the stream can be stopped
371    let mut file = File::create(target_path)?;
372    let content_length = headers
373        .get("Content-Length")
374        .ok_or("Missing content length")?
375        .parse::<usize>()?;
376    let file_bytes = content_length - boundary.len() - header_size - 6;
377
378    //take only the file length from the main buf reader
379    let mut limited_reader = reader.take(file_bytes.try_into()?);
380
381    //copy streams
382    io::copy(&mut limited_reader, &mut file)?;
383
384    let response = HttpResponse::new(
385        Some(Body::Text(format!(
386            "File {} uploaded successfully.",
387            filename
388        ))),
389        Some(String::from("text/plain")),
390        200,
391    );
392    Ok(response)
393}