small_http/
serve.rs

1/*
2 * Copyright (c) 2023-2025 Bastiaan van der Plaat
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7use std::io::{self, Write};
8use std::net::TcpListener;
9use std::thread;
10use std::time::Duration;
11
12use threadpool::ThreadPool;
13
14use crate::enums::Version;
15use crate::request::Request;
16use crate::response::Response;
17
/// Worker threads allocated per available CPU core when sizing the pool.
const WORK_THREAD_PER_CORE: usize = 64;

/// Read timeout set on every accepted client connection; a connection that
/// sends no data for this long is dropped.
pub(crate) const KEEP_ALIVE_TIMEOUT: Duration = Duration::from_millis(5_000);
20
21/// Start HTTP server
22pub fn serve<F>(listener: TcpListener, handler: F)
23where
24    F: Fn(&Request) -> Response + Clone + Send + 'static,
25{
26    // Create thread pool with workers
27    let num_cores = thread::available_parallelism()
28        .map(|n| n.get())
29        .unwrap_or(1);
30    let pool = ThreadPool::new(num_cores * WORK_THREAD_PER_CORE);
31
32    // Listen for incoming tcp clients
33    for mut stream in listener.incoming().flatten() {
34        stream
35            .set_read_timeout(Some(KEEP_ALIVE_TIMEOUT))
36            .expect("Can't set read timeout");
37
38        let handler = handler.clone();
39        pool.execute(move || loop {
40            // Wait for data to be available
41            let mut buffer = [0; 1];
42            match stream.peek(&mut buffer) {
43                Ok(0) => {
44                    return;
45                }
46                Ok(_) => {} // Data available continue
47                Err(e) => {
48                    if e.kind() != io::ErrorKind::WouldBlock && e.kind() != io::ErrorKind::TimedOut
49                    {
50                        println!("Error: {:?}", e);
51                    }
52                    return;
53                }
54            }
55
56            // Read incoming request
57            let client_addr = stream
58                .peer_addr()
59                .expect("Can't get tcp stream client addr");
60            match Request::read_from_stream(&mut stream, client_addr) {
61                Ok(req) => {
62                    // Handle request
63                    handler(&req).write_to_stream(&mut stream, &req);
64
65                    // Close connection if HTTP/1.0 or Connection: close
66                    if req.version == Version::Http1_0
67                        || req.headers.get("Connection").map(|v| v.as_str()) == Some("close")
68                    {
69                        return;
70                    }
71                }
72                Err(err) => {
73                    // Invalid request received
74                    _ = write!(stream, "HTTP/1.0 400 Bad Request\r\n\r\n");
75                    println!("Error: Invalid http request: {:?}", err);
76                    return;
77                }
78            }
79        });
80    }
81}
82
// MARK: Tests
#[cfg(test)]
mod test {
    use std::io::Read;
    use std::net::{Ipv4Addr, TcpStream};
    use std::thread;

    use super::*;
    use crate::enums::Status;

    /// HTTP/1.1 request asking the server to close the connection after
    /// responding, so that `read_to_end` on the client side terminates.
    const REQUEST: &[u8] = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";

    /// Runs `serve` on an ephemeral localhost port and checks that ten
    /// sequential clients each receive a `200 OK` status line.
    #[test]
    fn test_serve() {
        let listener =
            TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).expect("Failed to bind address");
        let server_addr = listener.local_addr().unwrap();

        thread::spawn(move || serve(listener, |_req| Response::with_status(Status::Ok)));

        for _attempt in 0..10 {
            let mut client =
                TcpStream::connect(server_addr).expect("Failed to connect to server");
            client.write_all(REQUEST).expect("Failed to write to stream");

            let mut reply = Vec::new();
            client
                .read_to_end(&mut reply)
                .expect("Failed to read from stream");
            assert!(reply.starts_with(b"HTTP/1.1 200 OK"));
        }
    }
}