tosic_http/server/
mod.rs

1//! Main entry point for the HTTP server.
2
3use crate::body::message_body::MessageBody;
4use crate::error::{Error, ServerError};
5use crate::handlers::Handlers;
6use crate::request::{HttpPayload, HttpRequest};
7use crate::response::HttpResponse;
8use crate::route::HandlerFn;
9use crate::server::builder::HttpServerBuilder;
10use crate::state::State;
11use http::HeaderMap;
12use std::fmt::Debug;
13use tokio::io;
14use tokio::io::BufReader;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::ToSocketAddrs;
17use tower::layer::util::Identity;
18use tower::{Layer, Service, ServiceBuilder, ServiceExt};
19#[cfg(feature = "trace")]
20use tracing::trace;
21use tracing::{debug, error, info};
22
23pub mod builder;
24mod test;
25
26/// Represents a running HTTP server.
27///
28/// To construct a server, use [`HttpServer::builder`] or the builder struct directly [`HttpServerBuilder`].
29pub struct HttpServer<L>
30where
31    L: Layer<HandlerFn> + Clone + Send + 'static,
32{
33    listener: tokio::net::TcpListener,
34    handlers: Handlers,
35    app_state: State,
36    service_builder: ServiceBuilder<L>,
37}
38
39impl<L> HttpServer<L>
40where
41    L: Layer<HandlerFn> + Clone + Send + 'static,
42    L::Service: Service<(HttpRequest, HttpPayload), Response = HttpResponse, Error = Error>
43        + Send
44        + 'static,
45    <L::Service as Service<(HttpRequest, HttpPayload)>>::Future: Send + 'static,
46{
47    #[cfg_attr(
48        feature = "trace",
49        tracing::instrument(level = "trace", skip(service_builder))
50    )]
51    /// Create a new [`HttpServer`] instance and binds the server to the provided address.
52    ///
53    /// This meant to be called from [`HttpServerBuilder`] and not externally
54    pub(crate) async fn new<T: ToSocketAddrs + Debug>(
55        addr: T,
56        handlers: Handlers,
57        app_state: State,
58        service_builder: ServiceBuilder<L>,
59    ) -> io::Result<Self> {
60        let listener = tokio::net::TcpListener::bind(addr).await?;
61
62        #[cfg(feature = "trace")]
63        trace!("Server Bound to {}", listener.local_addr()?);
64
65        Ok(Self {
66            listener,
67            handlers,
68            app_state,
69            service_builder,
70        })
71    }
72
73    /// Returns a new [`HttpServerBuilder`] for configuring and building an [`HttpServer`].
74    pub fn builder<T: ToSocketAddrs + Default + Debug + Clone>() -> HttpServerBuilder<T, Identity> {
75        HttpServerBuilder::<T, Identity>::new()
76    }
77
78    /// Starts the server and listens for incoming connections.
79    pub async fn serve(self) -> Result<(), ServerError> {
80        info!("Listening on {}", self.listener.local_addr()?);
81        loop {
82            match self.listener.accept().await {
83                Ok((stream, socket)) => {
84                    #[cfg(feature = "trace")]
85                    trace!("Accepted connection from {}", socket);
86                    self.accept_connection(stream, socket)?;
87                }
88                Err(err) => {
89                    error!("Failed to accept connection: {}", err);
90                    continue;
91                }
92            }
93        }
94    }
95
96    /// Main entry point for an incoming connection.
97    ///
98    /// In this step we spawn a new thread and handle the connection inside it to not block the main thread.
99    fn accept_connection(
100        &self,
101        stream: tokio::net::TcpStream,
102        socket: std::net::SocketAddr,
103    ) -> Result<(), ServerError> {
104        let handlers = self.handlers.clone();
105        let state = self.app_state.clone();
106        let service_builder = self.service_builder.clone();
107
108        tokio::spawn(async move {
109            if let Err(e) = Self::handle_connection(
110                stream,
111                #[cfg(feature = "trace")]
112                socket,
113                handlers,
114                state,
115                service_builder,
116            )
117            .await
118            {
119                error!("Error handling connection from {}: {:?}", socket, e);
120            }
121        });
122
123        Ok(())
124    }
125
126    #[cfg_attr(feature = "trace", tracing::instrument(level = "trace", skip_all))]
127    /// Handles an incoming connection by reading the request, processing it, and sending the response
128    async fn handle_connection(
129        stream: tokio::net::TcpStream,
130        #[cfg(feature = "trace")] socket: std::net::SocketAddr,
131        handlers: Handlers,
132        state: State,
133        service_builder: ServiceBuilder<L>,
134    ) -> Result<(), ServerError> {
135        #[cfg(feature = "trace")]
136        trace!("Accepted connection from {}", socket);
137
138        let mut reader = BufReader::new(stream);
139
140        let request_buffer = match Self::read_request(&mut reader).await {
141            Ok(buffer) => buffer,
142            Err(e) => {
143                error!("Failed to read request: {}", e);
144                return Err(e);
145            }
146        };
147
148        let (mut request, payload) = match HttpRequest::from_bytes(&request_buffer) {
149            Ok(req) => req,
150            Err(e) => {
151                error!("Failed to parse request: {}", e);
152                return Err(e);
153            }
154        };
155
156        request.data = state;
157
158        #[cfg(feature = "trace")]
159        trace!("Request: {:?}", request);
160
161        let handler = handlers.get_handler(request.method(), request.uri().path());
162
163        request.params_mut().extend(handler.1.clone());
164
165        let mut service = service_builder.service(handler.handler());
166
167        match service.ready().await {
168            Ok(_) => {}
169            Err(e) => {
170                error!("Failed to construct service: {}", e);
171                return Err(ServerError::ServiceConstructionFailed);
172            }
173        };
174
175        let response = service.call((request, payload)).await.unwrap_or_else(|e| {
176            error!("Failed to process request: {}", e);
177            e.error_response()
178        });
179
180        Self::send_response(reader, response).await
181    }
182
183    #[cfg_attr(feature = "trace", tracing::instrument(level = "trace", skip(reader)))]
184    /// Sends the response back to the client
185    async fn send_response(
186        reader: BufReader<tokio::net::TcpStream>,
187        mut response: HttpResponse,
188    ) -> Result<(), ServerError> {
189        let content_length = response
190            .body
191            .clone()
192            .try_into_bytes()
193            .unwrap_or_default()
194            .len() as u64;
195
196        Self::insert_content_length(response.headers_mut(), content_length);
197
198        response
199            .headers_mut()
200            .insert("Connection", "close".parse()?);
201
202        let response_bytes = response.to_bytes()?;
203
204        let mut stream = reader.into_inner();
205        stream.write_all(&response_bytes).await?;
206        stream.flush().await?;
207
208        Ok(())
209    }
210
211    fn insert_content_length(headers: &mut HeaderMap, content_length: u64) {
212        headers.insert(
213            "Content-Length",
214            content_length.to_string().parse().unwrap(),
215        );
216    }
217
218    #[cfg_attr(feature = "trace", tracing::instrument(level = "trace", skip(reader)))]
219    /// Reads the request body and returns it as a vector of bytes
220    async fn read_request(
221        reader: &mut BufReader<tokio::net::TcpStream>,
222    ) -> Result<Vec<u8>, ServerError> {
223        let mut request_buffer = Vec::new();
224        let mut headers_read = false;
225        let mut content_length = 0;
226
227        loop {
228            let mut buf = [0; 1024];
229            let n = reader.read(&mut buf).await?;
230
231            if n == 0 {
232                debug!("Connection closed by the client.");
233                return Err(ServerError::ConnectionClosed);
234            }
235
236            request_buffer.extend_from_slice(&buf[..n]);
237
238            if !headers_read {
239                if let Some(headers_end) = Self::find_headers_end(&request_buffer) {
240                    headers_read = true;
241
242                    let headers = &request_buffer[..headers_end];
243                    let headers_str = String::from_utf8_lossy(headers);
244
245                    for line in headers_str.lines() {
246                        if line.to_lowercase().starts_with("content-length:") {
247                            if let Some(length_str) = line.split(':').nth(1) {
248                                content_length = length_str.trim().parse::<usize>().unwrap_or(0);
249                            }
250                        }
251                    }
252
253                    let body_bytes_read = request_buffer.len() - headers_end;
254                    if body_bytes_read >= content_length {
255                        break;
256                    }
257                }
258            } else {
259                let total_bytes = request_buffer.len();
260                let headers_end = Self::find_headers_end(&request_buffer).unwrap_or(0);
261                let body_bytes_read = total_bytes - headers_end;
262
263                if body_bytes_read >= content_length {
264                    break;
265                }
266            }
267        }
268
269        Ok(request_buffer)
270    }
271
272    #[inline]
273    /// Find the end of the request headers
274    fn find_headers_end(buffer: &[u8]) -> Option<usize> {
275        buffer
276            .windows(4)
277            .position(|window| window == b"\r\n\r\n")
278            .map(|pos| pos + 4)
279    }
280}