small_http/
serve.rs

1/*
2 * Copyright (c) 2023-2025 Bastiaan van der Plaat
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7use std::io::Write;
8use std::net::TcpListener;
9use std::time::Duration;
10
11use crate::request::Request;
12use crate::response::Response;
13
14/// Start HTTP server single threaded
15pub fn serve_single_threaded<F>(listener: TcpListener, handler: F)
16where
17    F: Fn(&Request) -> Response + Clone + Send + 'static,
18{
19    // Listen for incoming tcp clients
20    for stream in listener.incoming() {
21        let mut stream = stream.expect("Failed to accept connection");
22        stream
23            .set_read_timeout(Some(Duration::from_secs(1)))
24            .expect("Can't set read timeout");
25
26        // Read incoming request
27        let client_addr = stream
28            .peer_addr()
29            .expect("Can't get tcp stream client addr");
30
31        match Request::read_from_stream(&mut stream, client_addr) {
32            Ok(request) => {
33                // Handle request and write response
34                let mut response = handler(&request);
35                response.write_to_stream(
36                    &mut stream,
37                    &request,
38                    request.headers.get("Connection").is_some(),
39                );
40
41                // If the response has a takeover function, start thread and move tcp stream
42                if let Some(takeover) = response.takeover.take() {
43                    std::thread::spawn(move || takeover(stream));
44                }
45            }
46            Err(err) => {
47                // Invalid request received
48                _ = write!(stream, "HTTP/1.0 400 Bad Request\r\n\r\n");
49                println!("Error: Invalid http request: {:?}", err);
50            }
51        }
52    }
53}
54
/// Start HTTP server
///
/// Accepts connections from `listener` and hands each one to a worker of a
/// thread pool. The pool size is `available_parallelism() * 64`, or 64 when the
/// parallelism cannot be determined. Each worker serves requests on its
/// connection in a loop until one of the following happens:
/// - the peer closes the connection;
/// - no new data arrives within `crate::KEEP_ALIVE_TIMEOUT` (the socket read
///   timeout);
/// - the request is HTTP/1.0, or its `Connection` header contains the `close`
///   option (matched case-insensitively, as one of a comma-separated list);
/// - the request cannot be parsed (a `400 Bad Request` is written first);
/// - the response has a takeover function, which then receives the stream on
///   a newly spawned thread.
///
/// Failures that only concern a single connection drop that connection and
/// are logged; they do not stop the accept loop. These are:
/// - a failed `accept`,
/// - socket configuration errors,
/// - a peer whose address can no longer be read.
///
/// This function does not return under normal operation:
/// `TcpListener::incoming` never yields `None`.
#[cfg(feature = "multi-threaded")]
pub fn serve<F>(listener: TcpListener, handler: F)
where
    F: Fn(&Request) -> Response + Clone + Send + 'static,
{
    // Create thread pool with workers
    let num_threads = std::thread::available_parallelism().map_or(1, |n| n.get());
    let pool = threadpool::ThreadPool::new(num_threads * 64);

    // Listen for incoming tcp clients
    for stream in listener.incoming() {
        // A failed accept is transient; panicking here would kill the server.
        let mut stream = match stream {
            Ok(stream) => stream,
            Err(err) => {
                println!("Error: Failed to accept connection: {:?}", err);
                continue;
            }
        };
        // The read timeout doubles as the keep-alive idle timeout for `peek` below.
        if let Err(err) = stream.set_read_timeout(Some(crate::KEEP_ALIVE_TIMEOUT)) {
            println!("Error: Can't set read timeout: {:?}", err);
            continue;
        }

        let handler = handler.clone();
        pool.execute(move || {
            // The peer address does not change during the connection, so look it
            // up once. It can fail, e.g. if the client already disconnected.
            let client_addr = match stream.peer_addr() {
                Ok(addr) => addr,
                Err(err) => {
                    println!("Error: Can't get tcp stream client addr: {:?}", err);
                    return;
                }
            };

            loop {
                // Wait for data to be available
                let mut buffer = [0; 1];
                match stream.peek(&mut buffer) {
                    // Peer closed the connection
                    Ok(0) => return,
                    // Data available, continue
                    Ok(_) => {}
                    Err(e) => {
                        // An idle keep-alive connection timing out is expected; only
                        // report other errors.
                        if e.kind() != std::io::ErrorKind::WouldBlock
                            && e.kind() != std::io::ErrorKind::TimedOut
                        {
                            println!("Error: {:?}", e);
                        }
                        return;
                    }
                }

                // Read incoming request
                let request = match Request::read_from_stream(&mut stream, client_addr) {
                    Ok(request) => request,
                    Err(err) => {
                        // Invalid request received
                        _ = write!(stream, "HTTP/1.0 400 Bad Request\r\n\r\n");
                        println!("Error: Invalid http request: {:?}", err);
                        return;
                    }
                };

                // Handle request and write response
                let mut response = handler(&request);
                response.write_to_stream(&mut stream, &request, true);

                // If the response has a takeover function, start thread and move tcp stream
                if let Some(takeover) = response.takeover.take() {
                    std::thread::spawn(move || takeover(stream));
                    return;
                }

                // Connection options are a comma-separated list of case-insensitive
                // tokens, so `Close` or `keep-alive, close` must also end the
                // connection. Otherwise the worker would sit idle until the
                // keep-alive timeout expires.
                let client_wants_close = match request.headers.get("Connection") {
                    Some(value) => value
                        .as_str()
                        .split(',')
                        .any(|option| option.trim().eq_ignore_ascii_case("close")),
                    None => false,
                };

                // Close connection if HTTP/1.0 or Connection: close
                if request.version == crate::enums::Version::Http1_0 || client_wants_close {
                    return;
                }
            }
        });
    }
}
124
// MARK: Tests
#[cfg(test)]
mod test {
    use std::io::Read;
    use std::net::{Ipv4Addr, SocketAddr, TcpStream};
    use std::thread;

    use super::*;
    use crate::enums::Status;

    /// Upper bound on how long a test client waits for the server. If the
    /// server never closes the connection, the test fails instead of hanging
    /// the whole test run.
    const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);

    /// Connect to `addr`, send the raw `request` bytes, and read until the
    /// server closes the connection (or `CLIENT_TIMEOUT` elapses, which panics).
    fn send_request(addr: SocketAddr, request: &[u8]) -> Vec<u8> {
        let mut stream = TcpStream::connect(addr).expect("Failed to connect to server");
        stream
            .set_read_timeout(Some(CLIENT_TIMEOUT))
            .expect("Can't set read timeout");
        stream
            .write_all(request)
            .expect("Failed to write to stream");

        let mut response = Vec::new();
        stream
            .read_to_end(&mut response)
            .expect("Failed to read from stream");
        response
    }

    #[test]
    fn test_serve_single_threaded() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).expect("Failed to bind address");
        let addr = listener.local_addr().unwrap();

        thread::spawn(move || {
            serve_single_threaded(listener, |_req| Response::with_status(Status::Ok));
        });

        let response = send_request(addr, b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n");
        assert!(response.starts_with(b"HTTP/1.1 200 OK"));
    }

    #[test]
    #[cfg(feature = "multi-threaded")]
    fn test_serve() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).expect("Failed to bind address");
        let addr = listener.local_addr().unwrap();

        thread::spawn(move || {
            serve(listener, |_req| Response::with_status(Status::Ok));
        });

        for _ in 0..10 {
            let response = send_request(
                addr,
                b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n",
            );
            assert!(response.starts_with(b"HTTP/1.1 200 OK"));
        }
    }
}