1use std::io::Write;
8use std::net::TcpListener;
9use std::time::Duration;
10
11use crate::request::Request;
12use crate::response::Response;
13
14pub fn serve_single_threaded<F>(listener: TcpListener, handler: F)
16where
17 F: Fn(&Request) -> Response + Clone + Send + 'static,
18{
19 for stream in listener.incoming() {
21 let mut stream = stream.expect("Failed to accept connection");
22 stream
23 .set_read_timeout(Some(Duration::from_secs(1)))
24 .expect("Can't set read timeout");
25
26 let client_addr = stream
28 .peer_addr()
29 .expect("Can't get tcp stream client addr");
30
31 match Request::read_from_stream(&mut stream, client_addr) {
32 Ok(request) => {
33 let mut response = handler(&request);
35 response.write_to_stream(
36 &mut stream,
37 &request,
38 request.headers.get("Connection").is_some(),
39 );
40
41 if let Some(takeover) = response.takeover.take() {
43 std::thread::spawn(move || takeover(stream));
44 }
45 }
46 Err(err) => {
47 _ = write!(stream, "HTTP/1.0 400 Bad Request\r\n\r\n");
49 println!("Error: Invalid http request: {:?}", err);
50 }
51 }
52 }
53}
54
/// Serves HTTP requests from `listener` on a thread pool, with keep-alive support.
///
/// The pool has `64 * available_parallelism()` workers (64 if the parallelism cannot be
/// queried). Each accepted connection is handed to one worker, which loops: wait for data,
/// read one request, pass it to a clone of `handler`, write the response. The worker stops
/// serving the connection when:
///
/// * the peer closes it, or no data arrives within `crate::KEEP_ALIVE_TIMEOUT`;
/// * the request is HTTP/1.0, or its `Connection` header contains the `close` token
///   (matched case-insensitively, as HTTP defines connection options);
/// * the response carries a `takeover` callback, which then receives the stream on a new
///   thread;
/// * the request cannot be parsed, in which case `HTTP/1.0 400 Bad Request` is written.
///
/// Per-connection failures — accept errors, sockets whose timeout cannot be set, and peers
/// that disconnect before their address can be read — are logged and that connection is
/// skipped, rather than panicking and taking down the accept loop.
#[cfg(feature = "multi-threaded")]
pub fn serve<F>(listener: TcpListener, handler: F)
where
    F: Fn(&Request) -> Response + Clone + Send + 'static,
{
    let num_threads = std::thread::available_parallelism().map_or(1, |n| n.get());
    // Idle keep-alive connections each occupy a worker while waiting in `peek`, hence the
    // heavy oversubscription relative to the core count.
    let pool = threadpool::ThreadPool::new(num_threads * 64);

    for stream in listener.incoming() {
        // Accept can fail transiently (e.g. aborted handshake, fd exhaustion); keep serving.
        let mut stream = match stream {
            Ok(stream) => stream,
            Err(err) => {
                println!("Error: Failed to accept connection: {:?}", err);
                continue;
            }
        };
        // The read timeout doubles as the keep-alive idle timeout for `peek` below.
        if let Err(err) = stream.set_read_timeout(Some(crate::KEEP_ALIVE_TIMEOUT)) {
            println!("Error: Can't set read timeout: {:?}", err);
            continue;
        }
        // The peer address cannot change during the connection, so read it once here
        // instead of on every keep-alive iteration; it fails if the peer already left.
        let client_addr = match stream.peer_addr() {
            Ok(addr) => addr,
            Err(err) => {
                println!("Error: Can't get tcp stream client addr: {:?}", err);
                continue;
            }
        };

        let handler = handler.clone();
        pool.execute(move || loop {
            // Wait for the next request to start arriving without consuming any bytes.
            let mut buffer = [0; 1];
            match stream.peek(&mut buffer) {
                // Orderly shutdown by the peer.
                Ok(0) => return,
                Ok(_) => {}
                Err(e) => {
                    // A read timeout surfaces as WouldBlock or TimedOut depending on the
                    // platform; that is the normal keep-alive expiry and is not logged.
                    if e.kind() != std::io::ErrorKind::WouldBlock
                        && e.kind() != std::io::ErrorKind::TimedOut
                    {
                        println!("Error: {:?}", e);
                    }
                    return;
                }
            }

            match Request::read_from_stream(&mut stream, client_addr) {
                Ok(request) => {
                    let mut response = handler(&request);
                    // NOTE(review): always `true`, even for requests that end the
                    // connection just below — confirm what this flag controls.
                    response.write_to_stream(&mut stream, &request, true);

                    if let Some(takeover) = response.takeover.take() {
                        std::thread::spawn(move || takeover(stream));
                        return;
                    }

                    let close_requested = matches!(
                        request.headers.get("Connection").map(|v| v.as_str()),
                        Some(value) if connection_requests_close(value)
                    );
                    if request.version == crate::enums::Version::Http1_0 || close_requested {
                        return;
                    }
                }
                Err(err) => {
                    // Best effort: the client may already be gone, so a write error is ignored.
                    _ = write!(stream, "HTTP/1.0 400 Bad Request\r\n\r\n");
                    println!("Error: Invalid http request: {:?}", err);
                    return;
                }
            }
        });
    }
}

/// Returns `true` if a `Connection` header value contains the `close` token.
///
/// The header value is a comma-separated list of connection options, which HTTP defines
/// as case-insensitive, so `"Close"` and `"TE, close"` both request closing.
#[cfg_attr(not(feature = "multi-threaded"), allow(dead_code))]
fn connection_requests_close(value: &str) -> bool {
    value
        .split(',')
        .any(|token| token.trim().eq_ignore_ascii_case("close"))
}
124
#[cfg(test)]
mod test {
    use std::io::{Read, Write};
    use std::net::{Ipv4Addr, SocketAddr, TcpStream};
    use std::thread;
    use std::time::Duration;

    use super::*;
    use crate::enums::Status;

    /// How long a test client waits on a read before failing. Without it, a server that
    /// never closes the connection would block `read_to_end`, and the suite, forever.
    const CLIENT_READ_TIMEOUT: Duration = Duration::from_secs(5);

    /// Binds a listener on an OS-assigned localhost port and returns it with its address.
    fn bind_localhost() -> (TcpListener, SocketAddr) {
        let listener =
            TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).expect("Failed to bind address");
        let addr = listener.local_addr().unwrap();
        (listener, addr)
    }

    /// Sends `raw` on a fresh connection to `addr` and returns every byte the server writes
    /// back before closing the connection.
    fn exchange(addr: SocketAddr, raw: &[u8]) -> Vec<u8> {
        let mut stream = TcpStream::connect(addr).expect("Failed to connect to server");
        stream
            .set_read_timeout(Some(CLIENT_READ_TIMEOUT))
            .expect("Can't set client read timeout");
        stream.write_all(raw).expect("Failed to write to stream");

        let mut response = Vec::new();
        stream
            .read_to_end(&mut response)
            .expect("Failed to read from stream");
        response
    }

    #[test]
    fn test_serve_single_threaded() {
        let (listener, addr) = bind_localhost();

        thread::spawn(move || {
            serve_single_threaded(listener, |_req| Response::with_status(Status::Ok));
        });

        let response = exchange(addr, b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n");
        assert!(
            response.starts_with(b"HTTP/1.1 200 OK"),
            "unexpected response: {:?}",
            String::from_utf8_lossy(&response)
        );
    }

    #[test]
    #[cfg(feature = "multi-threaded")]
    fn test_serve() {
        let (listener, addr) = bind_localhost();

        thread::spawn(move || {
            serve(listener, |_req| Response::with_status(Status::Ok));
        });

        for _ in 0..10 {
            let response = exchange(
                addr,
                b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n",
            );
            assert!(
                response.starts_with(b"HTTP/1.1 200 OK"),
                "unexpected response: {:?}",
                String::from_utf8_lossy(&response)
            );
        }
    }
}