1use std::collections::HashMap;
2use std::io::prelude::*;
3use std::net::{TcpListener, TcpStream};
4use std::panic;
5use std::sync::{Arc, Mutex};
6
7use crate::http::{http_err, Request, Response};
8use crate::threadpool::ThreadPool;
9
10use concat_idents::concat_idents;
11
/// Generates a public setter `set_<field>(&mut self, val: <ty>)` that
/// overwrites `self.<field>` with `val`.
///
/// Expand it inside an `impl` block, since the generated method takes
/// `&mut self`.
macro_rules! setter {
    ($field:ident, $ty:ty) => {
        concat_idents!(setter_name = set_, $field {
            pub fn setter_name(&mut self, val: $ty) {
                self.$field = val;
            }
        });
    };
}
21
22type RouteCallback = fn(Request) -> Response; type Routes = Arc<Mutex<HashMap<String, RouteCallback>>>;
24
25pub struct Server {
26 num_threads: usize,
27 routes: Routes,
28}
29
30impl Server {
31 pub fn listen(&self, port: usize) -> std::io::Result<()> {
32 let listener: TcpListener = TcpListener::bind(format!("0.0.0.0:{}", port))?;
33
34 println!("listening on port {}", port);
35 let pool = ThreadPool::build(4)?;
36
37 for stream in listener.incoming() {
38 let stream = stream.unwrap();
39
40 let routes = self.routes.clone();
41
42 pool.execute(|| {
43 Self::handle_connection(stream, routes);
44 });
45 }
46
47 return Ok(());
48 }
49
50 pub fn default() -> Self {
51 Self {
52 num_threads: 10,
53 routes: Arc::new(Mutex::new(HashMap::new())),
54 }
55 }
56
57 pub fn register(&mut self, path: &str, cb: RouteCallback) {
60 self.routes.lock().unwrap().insert(path.to_string(), cb);
61 }
62
63 fn handle_connection(mut stream: TcpStream, routes: Routes) {
64 println!("Connection established!");
65 if let Err(_) = stream.set_read_timeout(Some(std::time::Duration::from_secs(3))) {
66 eprint!("failed to set request read timeout");
67 }
68
69 let request = Request::try_from_stream(&mut stream);
70 match request {
71 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
72 eprintln!("request timed out");
73 return;
74 }
75 Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => {
76 eprintln!("request time out");
77 return;
78 }
79 Err(_) => {
80 eprintln!("unhandled request error");
81 return;
82 }
83 Ok(_) => {}
84 }
85
86 let response = Self::process_request(request.unwrap(), routes);
87 if let Err(e) = response {
88 eprintln!("{}", e.to_string());
89 return;
90 }
91
92 stream
94 .write_all(response.unwrap().as_string().as_bytes())
95 .unwrap();
96 }
97
98 fn process_request(request: Request, routes: Routes) -> std::io::Result<Response> {
99 dbg!();
100 dbg!();
110
111 dbg!(request.uri.path());
118 let path = request.uri.path();
119 if let Some(cb) = routes.lock().unwrap().get(path) {
120 let result = panic::catch_unwind(|| cb(Request::new()));
121 return match result {
122 Ok(response) => Ok(response),
123 Err(_) => Err(http_err(&format!("route {} panicked", path))),
124 };
125 }
126
127 Response::not_found()
128 }
129
130 setter!(num_threads, usize);
131}