Skip to main content

rust_web_server/server/
mod.rs

1#[cfg(test)]
2pub mod tests;
3#[cfg(test)]
4mod example;
5
6use std::io::prelude::*;
7use std::borrow::Borrow;
8use std::net::{IpAddr, SocketAddr, TcpListener};
9use std::str::FromStr;
10use std::time::Duration;
11
12use crate::request::{METHOD, Request};
13use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
14use crate::app::App;
15use crate::application::Application;
16use crate::core::{New};
17use crate::entry_point::{bootstrap, get_ip_port_thread_count, get_request_allocation_size, set_default_values};
18use crate::header::Header;
19use crate::log::Log;
20use crate::mime_type::MimeType;
21use crate::range::{ContentRange, Range};
22use crate::symbol::SYMBOL;
23use crate::thread_pool::ThreadPool;
24
25pub struct Server {}
26impl Server {
27    pub fn process_request(mut stream: impl Read + Write + Unpin, peer_addr: SocketAddr) -> Vec<u8> {
28        let request_allocation_size = get_request_allocation_size();
29        let mut buffer = vec![0; request_allocation_size as usize];
30        let boxed_read = stream.read(&mut buffer);
31        if boxed_read.is_err() {
32            let message = boxed_read.err().unwrap().to_string();
33            eprintln!("unable to read TCP stream {}", &message);
34
35            let raw_response = Server::bad_request_response(message);
36            let boxed_stream = stream.write(raw_response.borrow());
37            if boxed_stream.is_ok() {
38                stream.flush().unwrap();
39            };
40            return raw_response;
41        }
42
43        boxed_read.unwrap();
44        let request : &[u8] = &buffer;
45
46        // let raw_request = String::from_utf8(Vec::from(request)).unwrap();
47        // println!("\n\n______{}______\n\n", raw_request);
48
49
50        let boxed_request = Request::parse_request(request);
51        if boxed_request.is_err() {
52            let message = boxed_request.err().unwrap();
53            eprintln!("unable to parse request: {}", &message);
54
55            let raw_response = Server::bad_request_response(message);
56            let boxed_stream = stream.write(raw_response.borrow());
57            if boxed_stream.is_ok() {
58                stream.flush().unwrap();
59            };
60            return raw_response;
61        }
62
63
64        let request: Request = boxed_request.unwrap();
65        let (response, request) = App::handle_request(request);
66
67
68        let log_request_response = Log::combined(&request, &response, &peer_addr);
69        println!("{}", log_request_response);
70        let raw_response = Response::generate_response(response, request);
71
72        let boxed_stream = stream.write(raw_response.borrow());
73        if boxed_stream.is_ok() {
74            stream.flush().unwrap();
75        };
76
77        raw_response
78    }
79
80    pub fn bad_request_response(message: String) -> Vec<u8> {
81        let error_request = Request {
82            method: METHOD.get.to_string(),
83            request_uri: "".to_string(),
84            http_version: "".to_string(),
85            headers: vec![],
86            body: vec![],
87        };
88
89        let size = message.chars().count() as u64;
90        let content_range = ContentRange {
91            unit: Range::BYTES.to_string(),
92            range: Range { start: 0, end: size },
93            size: size.to_string(),
94            body: Vec::from(message.as_bytes()),
95            content_type: MimeType::TEXT_PLAIN.to_string(),
96        };
97
98        let header_list = Header::get_header_list(&error_request);
99        let error_response: Response = Response::get_response(
100            STATUS_CODE_REASON_PHRASE.n400_bad_request,
101            Some(header_list),
102            Some(vec![content_range])
103        );
104
105        let response = Response::generate_response(error_response, error_request);
106        return response;
107    }
108
109    pub fn process(mut stream: impl Read + Write + Unpin,
110                   connection: ConnectionInfo,
111                   app: impl Application) -> Result<(), String> {
112
113        let request_allocation_size = connection.request_size;
114        let mut buffer = vec![0; request_allocation_size as usize];
115        let boxed_read = stream.read(&mut buffer);
116        if boxed_read.is_err() {
117            let read_message = boxed_read.err().unwrap().to_string();
118            let raw_response = Server::bad_request_response(read_message.clone());
119            let boxed_stream = stream.write(raw_response.borrow());
120            if boxed_stream.is_ok() {
121                stream.flush().unwrap();
122            } else {
123                let write_message = boxed_stream.err().unwrap().to_string();
124                let combined_error = [read_message.clone(), SYMBOL.comma.to_string(), write_message].join(SYMBOL.empty_string);
125                return Err(combined_error);
126            };
127
128            return Err(read_message);
129        }
130
131        boxed_read.unwrap();
132        let request : &[u8] = &buffer;
133
134        // let raw_request = String::from_utf8(Vec::from(request)).unwrap();
135        // println!("\n\n______{}______\n\n", raw_request);
136
137
138        let boxed_request = Request::parse(request);
139        if boxed_request.is_err() {
140            let message = boxed_request.err().unwrap();
141
142            let raw_response = Server::bad_request_response(message.clone());
143            let boxed_stream = stream.write(raw_response.borrow());
144            if boxed_stream.is_ok() {
145                stream.flush().unwrap();
146            } else {
147                let write_message = boxed_stream.err().unwrap().to_string();
148                let combined_error = [message, SYMBOL.comma.to_string(), write_message].join(SYMBOL.empty_string);
149                return Err(combined_error);
150            };
151            return Err(message);
152        }
153
154
155        let request: Request = boxed_request.unwrap();
156
157        let app_processing = app.execute(&request, &connection);
158        if app_processing.is_err() {
159            let message = app_processing.as_ref().err().unwrap().to_string();
160            let response = Server::bad_request_response(message);
161
162            let boxed_stream = stream.write(response.borrow());
163            if boxed_stream.is_ok() {
164                stream.flush().unwrap();
165            } else {
166                let write_message = boxed_stream.err().unwrap().to_string();
167                return Err(write_message);
168            };
169        }
170        let response = app_processing.unwrap();
171
172
173        let client = connection.client;
174        let client_addr = SocketAddr::new(IpAddr::from_str(client.ip.as_str()).unwrap(), client.port as u16);
175        let log_request_response = Log::combined(&request, &response, &client_addr);
176        println!("{}", log_request_response);
177
178        let raw_response = Response::generate_response(response, request);
179
180        let boxed_stream = stream.write(raw_response.borrow());
181        if boxed_stream.is_ok() {
182            stream.flush().unwrap();
183        } else {
184            let write_message = boxed_stream.err().unwrap().to_string();
185            return Err(write_message);
186        };
187
188        Ok(())
189    }
190
191    /// Reads configuration (IP, port, thread count, TLS paths) from the layered config system
192    /// and returns a bound `TcpListener` and a sized `ThreadPool`. Call once at startup.
193    pub fn setup() -> Result<(TcpListener, ThreadPool), String> {
194        let info = Log::info("Rust Web Server");
195        println!("{}", info);
196
197        let usage_info = Log::usage_information();
198        println!("{}", usage_info);
199
200
201        println!("RWS Configuration Start: \n");
202
203        set_default_values();
204        bootstrap();
205
206        println!("\nRWS Configuration End\n\n");
207
208
209        let (ip, port, thread_count) = get_ip_port_thread_count();
210
211
212        let mut ip_readable = ip.to_string();
213
214        if ip.contains(":") {
215            ip_readable = [SYMBOL.opening_square_bracket, &ip, SYMBOL.closing_square_bracket].join("");
216        }
217
218        let bind_addr = [ip_readable, SYMBOL.colon.to_string(), port.to_string()].join(SYMBOL.empty_string);
219
220        #[cfg(feature = "http2")]
221        let protocol = {
222            let cert = std::env::var(crate::entry_point::Config::RWS_CONFIG_TLS_CERT_FILE).unwrap_or_default();
223            if cert.is_empty() { "http" } else { "https" }
224        };
225        #[cfg(not(feature = "http2"))]
226        let protocol = "http";
227
228        println!("Setting up {}://{}...", protocol, &bind_addr);
229
230        let boxed_listener = TcpListener::bind(&bind_addr);
231        if boxed_listener.is_err() {
232            let message = format!("unable to set up TCP listener: {}", boxed_listener.err().unwrap());
233            return Err(message);
234        }
235
236        let listener = boxed_listener.unwrap();
237        let pool = ThreadPool::new(thread_count as usize);
238
239
240        let server_url_thread_count = Log::server_url_thread_count(protocol, &bind_addr, thread_count);
241        println!("{}", server_url_thread_count);
242
243        Ok((listener, pool))
244    }
245
246    /// Accepts TCP connections in a loop and dispatches each to the thread pool.
247    /// Blocks forever (plain HTTP/1.1). For TLS/HTTP2/HTTP3 use [`Server::run_tls`].
248    pub fn run(listener : TcpListener,
249               pool: ThreadPool,
250               app: impl Application + New + Send + 'static + Copy) {
251        for boxed_stream in listener.incoming() {
252            if boxed_stream.is_err() {
253                eprintln!("unable to get TCP stream: {}", boxed_stream.err().unwrap());
254                return;
255            }
256
257            let stream = boxed_stream.unwrap();
258
259            print!("Connection established, ");
260
261            let boxed_local_addr = stream.local_addr();
262            if boxed_local_addr.is_ok() {
263                print!("local addr: {}", boxed_local_addr.unwrap())
264            } else {
265                eprintln!("\nunable to read local addr");
266                return;
267            }
268
269            let boxed_peer_addr = stream.peer_addr();
270            if boxed_peer_addr.is_err() {
271                eprintln!("\nunable to read peer addr");
272                return;
273            }
274            let peer_addr = boxed_peer_addr.unwrap();
275            print!(", peer addr: {}\n", peer_addr.to_string());
276
277            let (server_ip, server_port, _thread_count) = get_ip_port_thread_count();
278            let client_ip = peer_addr.ip().to_string();
279            let client_port = peer_addr.port() as i32;
280            let request_allocation_size = get_request_allocation_size();
281
282            let connection = ConnectionInfo {
283                client: Address {
284                    ip: client_ip.to_string(),
285                    port: client_port
286                },
287                server: Address {
288                    ip: server_ip,
289                    port: server_port
290                },
291                request_size: request_allocation_size,
292            };
293
294
295
296            if let Err(e) = stream.set_read_timeout(Some(Duration::from_secs(30))) {
297                eprintln!("failed to set read timeout: {}", e);
298            }
299
300            pool.execute(move || {
301                let boxed_process = Server::process(stream, connection, app);
302                if boxed_process.is_err() {
303                    let message = boxed_process.err().unwrap();
304                    eprintln!("{}", message);
305                }
306            });
307
308        }
309
310
311    }
312
313}
314
/// Network context for the current connection, passed into every [`Controller`](crate::controller::Controller).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConnectionInfo {
    /// Client (peer) address.
    pub client: Address,
    /// Server address as configured (bind IP and port). This is not necessarily the
    /// socket's actual local address.
    pub server: Address,
    /// Bytes allocated for reading the request.
    pub request_size: i64
}

/// IP address and port pair.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Address {
    /// IP address in textual form, e.g. `127.0.0.1`.
    pub ip: String,
    /// Port number. Stored as `i32`; valid TCP/UDP ports fit in `u16`.
    pub port: i32
}
332
#[cfg(feature = "http2")]
impl Server {
    /// Serves TLS connections (HTTP/1.1 and HTTP/2, selected by ALPN) on `listener` until Ctrl+C.
    ///
    /// With no TLS certificate or key path configured, it falls back to the blocking
    /// [`Server::run`] (plain HTTP/1.1). Otherwise each accepted connection is handled in its own
    /// Tokio task: ALPN `h2` goes to `h2_handler::handle_connection`, and anything else to
    /// `Server::process_h1_tls`. Setup failures are logged to stderr and the function returns.
    pub async fn run_tls(
        listener: TcpListener,
        pool: ThreadPool,
        app: impl Application + New + Send + 'static + Copy,
    ) {
        use crate::tls::create_tls_acceptor;
        use crate::h2_handler;

        let cert_path = std::env::var(crate::entry_point::Config::RWS_CONFIG_TLS_CERT_FILE)
            .unwrap_or_default();
        let key_path = std::env::var(crate::entry_point::Config::RWS_CONFIG_TLS_KEY_FILE)
            .unwrap_or_default();

        if cert_path.is_empty() || key_path.is_empty() {
            println!("No TLS certificate configured — serving plain HTTP/1.1.");
            // `Server::run` blocks forever. NOTE(review): tokio documents that `block_in_place`
            // panics on a current-thread runtime — confirm the runtime flavor used by callers.
            tokio::task::block_in_place(|| Server::run(listener, pool, app));
            return;
        }

        let tls_acceptor = match create_tls_acceptor(&cert_path, &key_path) {
            Ok(a) => a,
            Err(e) => {
                eprintln!("TLS setup failed: {}", e);
                return;
            }
        };

        // The std listener must be non-blocking before Tokio can take it over.
        if let Err(e) = listener.set_nonblocking(true) {
            eprintln!("failed to set TCP listener to non-blocking: {}", e);
            return;
        }
        let tokio_listener = match tokio::net::TcpListener::from_std(listener) {
            Ok(l) => l,
            Err(e) => {
                eprintln!("failed to convert TCP listener to tokio: {}", e);
                return;
            }
        };

        println!("Listening for TLS connections (HTTP/1.1 + HTTP/2)...");

        loop {
            tokio::select! {
                result = tokio_listener.accept() => {
                    match result {
                        Ok((tcp_stream, peer_addr)) => {
                            let acceptor = tls_acceptor.clone();
                            tokio::spawn(async move {
                                match acceptor.accept(tcp_stream).await {
                                    Ok(tls_stream) => {
                                        let protocol = tls_stream
                                            .get_ref()
                                            .1
                                            .alpn_protocol()
                                            .map(|p| p.to_vec());

                                        match protocol.as_deref() {
                                            Some(b"h2") => {
                                                if let Err(e) =
                                                    h2_handler::handle_connection(tls_stream, peer_addr, app)
                                                        .await
                                                {
                                                    eprintln!("H2 connection error: {}", e);
                                                }
                                            }
                                            _ => {
                                                if let Err(e) =
                                                    Server::process_h1_tls(tls_stream, peer_addr, app).await
                                                {
                                                    eprintln!("H1 TLS error: {}", e);
                                                }
                                            }
                                        }
                                    }
                                    Err(e) => eprintln!("TLS handshake failed: {}", e),
                                }
                            });
                        }
                        Err(e) => eprintln!("TCP accept error: {}", e),
                    }
                }
                _ = tokio::signal::ctrl_c() => {
                    println!("\nShutting down gracefully.");
                    break;
                }
            }
        }
    }

    /// Handles one HTTP/1.1 request on an established TLS stream.
    ///
    /// Reads once into a buffer of `get_request_allocation_size()` bytes, parses it and runs
    /// `app`. It then adds the header from `Header::get_hsts_header()` and an `Alt-Svc` header
    /// advertising the server port (`h3` with the `http3` feature, otherwise `h2`), logs the
    /// exchange to stdout, and writes and flushes the response.
    ///
    /// # Errors
    /// If reading, parsing or the application fails, a `400 Bad Request` is sent best-effort and
    /// the failure message is returned so the caller can log it. Write/flush errors on the final
    /// response are returned too.
    async fn process_h1_tls(
        mut stream: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
        peer_addr: std::net::SocketAddr,
        app: impl Application,
    ) -> Result<(), String> {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let (server_ip, server_port, _) = get_ip_port_thread_count();
        let request_allocation_size = get_request_allocation_size();

        let mut buffer = vec![0u8; request_allocation_size as usize];
        if let Err(e) = stream.read(&mut buffer).await {
            return Server::reject_h1_tls(&mut stream, e.to_string()).await;
        }

        let request = match Request::parse(&buffer) {
            Ok(r) => r,
            Err(message) => return Server::reject_h1_tls(&mut stream, message).await,
        };

        let connection = ConnectionInfo {
            client: Address {
                ip: peer_addr.ip().to_string(),
                port: peer_addr.port() as i32,
            },
            server: Address {
                ip: server_ip,
                port: server_port,
            },
            request_size: request_allocation_size,
        };

        let mut response = match app.execute(&request, &connection) {
            Ok(r) => r,
            Err(message) => return Server::reject_h1_tls(&mut stream, message).await,
        };

        response.headers.push(Header::get_hsts_header());

        #[cfg(feature = "http3")]
        response.headers.push(Header {
            name: Header::_ALT_SVC.to_string(),
            value: format!("h3=\":{}\"", server_port),
        });
        #[cfg(not(feature = "http3"))]
        response.headers.push(Header {
            name: Header::_ALT_SVC.to_string(),
            value: format!("h2=\":{}\"", server_port),
        });

        let log = Log::combined(&request, &response, &peer_addr);
        println!("{}", log);

        let raw = Response::generate_response(response, request);
        stream
            .write_all(&raw)
            .await
            .map_err(|e| e.to_string())?;
        stream.flush().await.map_err(|e| e.to_string())?;

        Ok(())
    }

    /// Best-effort `400 Bad Request` for the TLS HTTP/1.1 path. It sends a response carrying
    /// `message`, flushes it so it is not left buffered in the TLS layer, and returns
    /// `Err(message)`. Errors while sending are ignored; the original failure is what gets reported.
    async fn reject_h1_tls(
        stream: &mut tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
        message: String,
    ) -> Result<(), String> {
        use tokio::io::AsyncWriteExt;

        let raw = Server::bad_request_response(message.clone());
        if stream.write_all(&raw).await.is_ok() {
            let _ = stream.flush().await;
        }
        Err(message)
    }
}
491
#[cfg(feature = "http3")]
impl Server {
    /// Serves HTTP/3 over QUIC (UDP) on the configured IP and port until Ctrl+C.
    ///
    /// Returns immediately, without logging, when no TLS certificate or key path is configured.
    /// It logs to stderr and returns if the QUIC TLS config, the bind address or the endpoint
    /// cannot be set up. Each incoming connection is handled in its own Tokio task by
    /// `h3_handler::handle_connection`.
    pub async fn run_quic(
        app: impl Application + New + Send + 'static + Copy,
    ) {
        use std::convert::TryFrom;
        use crate::tls::create_quinn_server_config;
        use crate::h3_handler;

        let cert_path = std::env::var(crate::entry_point::Config::RWS_CONFIG_TLS_CERT_FILE)
            .unwrap_or_default();
        let key_path = std::env::var(crate::entry_point::Config::RWS_CONFIG_TLS_KEY_FILE)
            .unwrap_or_default();

        if cert_path.is_empty() || key_path.is_empty() {
            return;
        }

        let server_config = match create_quinn_server_config(&cert_path, &key_path) {
            Ok(c) => c,
            Err(e) => {
                eprintln!("QUIC TLS setup failed: {}", e);
                return;
            }
        };

        let (server_ip, server_port, _) = get_ip_port_thread_count();
        let bind_addr = format!("{}:{}", server_ip, server_port);
        // Build the socket address from its parts. "<ip>:<port>" does not parse for IPv6
        // literals, which need brackets; brackets already present in the config are tolerated.
        let unbracketed_ip = server_ip.trim_start_matches('[').trim_end_matches(']');
        let parsed_addr = IpAddr::from_str(unbracketed_ip)
            .map_err(|error| error.to_string())
            .and_then(|ip| {
                u16::try_from(server_port)
                    .map(|port| SocketAddr::new(ip, port))
                    .map_err(|error| error.to_string())
            });
        let addr: std::net::SocketAddr = match parsed_addr {
            Ok(a) => a,
            Err(e) => {
                eprintln!("Invalid QUIC bind address '{}': {}", bind_addr, e);
                return;
            }
        };

        let endpoint = match quinn::Endpoint::server(server_config, addr) {
            Ok(e) => e,
            Err(e) => {
                eprintln!("QUIC endpoint error: {}", e);
                return;
            }
        };

        // `SocketAddr`'s Display brackets IPv6 addresses; IPv4 output is unchanged.
        println!("Listening for QUIC/HTTP3 on UDP {}", addr);

        loop {
            tokio::select! {
                maybe = endpoint.accept() => {
                    match maybe {
                        Some(incoming) => {
                            tokio::spawn(async move {
                                match incoming.await {
                                    Ok(conn) => {
                                        let peer_addr = conn.remote_address();
                                        if let Err(e) = h3_handler::handle_connection(conn, peer_addr, app).await {
                                            eprintln!("H3 connection error: {}", e);
                                        }
                                    }
                                    Err(e) => eprintln!("QUIC connection error: {}", e),
                                }
                            });
                        }
                        None => break,
                    }
                }
                _ = tokio::signal::ctrl_c() => {
                    println!("\nShutting down QUIC.");
                    endpoint.close(0u32.into(), b"shutdown");
                    break;
                }
            }
        }
    }
}
566
567