Skip to main content

rust_integration_services/http/server/
http_server.rs

1use std::{collections::HashMap, convert::Infallible, panic::AssertUnwindSafe, pin::Pin, sync::Arc};
2
3use futures::FutureExt;
4use http_body_util::{BodyExt, combinators::BoxBody};
5use hyper::{Request, Response, body::{Bytes, Incoming}, service::service_fn};
6use hyper_util::rt::TokioIo;
7use matchit::Router;
8use tokio::{net::{TcpListener, TcpStream}, signal::unix::{signal, SignalKind}, task::JoinSet};
9use tokio_rustls::TlsAcceptor;
10
11use crate::http::{executor::Executor, http_request::HttpRequest, http_response::HttpResponse, server::http_server_config::HttpServerConfig};
12
13type RouteCallback = Arc<dyn Fn(HttpRequest) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>> + Send + Sync>;
14
15pub struct HttpServer {
16    config: HttpServerConfig,
17    router: Router<RouteCallback>,
18}
19
20impl HttpServer {
21    pub fn new(config: HttpServerConfig) -> Self {
22        HttpServer {
23            config,
24            router: Router::new(),
25        }
26    }
27
28    /// Registers a route with a path, associating it with a handler callback.
29    pub fn route<T, Fut>(mut self, path: impl Into<String>, callback: T) -> Self
30    where
31        T: Fn(HttpRequest) -> Fut + Send + Sync + 'static,
32        Fut: Future<Output = HttpResponse> + Send + 'static,
33    {
34        self.router.insert(path.into(), Arc::new(move |request| Box::pin(callback(request)))).unwrap();
35        self
36    }
37
38    /// Run the HTTP server and begins listening for incoming TCP connections (optionally over TLS).
39    ///
40    /// This method binds to the configured host address and enters a loop to accept new TCP connections.
41    /// It also listens for system termination signals (SIGINT, SIGTERM) to gracefully shut down the server.
42    pub async fn run(self) {
43        let tls_acceptor = self.config.tls_config.map(|tls_config| {
44            TlsAcceptor::from(Arc::new(tls_config))
45        });
46
47        let host = format!("{}:{}", self.config.ip, self.config.port);
48        let listener = TcpListener::bind(&host).await.expect("Failed to start TCP Listener");
49        let mut sigterm = signal(SignalKind::terminate()).expect("Failed to start SIGTERM signal receiver");
50        let mut sigint = signal(SignalKind::interrupt()).expect("Failed to start SIGINT signal receiver");
51        let mut receiver_join_set = JoinSet::new();
52        let router = Arc::new(self.router);
53        
54        tracing::trace!("Started on {}", &host);
55        loop {
56            tokio::select! {
57                _ = sigterm.recv() => {
58                    drop(listener);
59                    break;
60                },
61                _ = sigint.recv() => {
62                    drop(listener);
63                    break;
64                },
65                result = listener.accept() => {
66                    let tls_acceptor = tls_acceptor.clone();
67                    let router = router.clone();
68                    let (tcp_stream, client_addr) = match result {
69                        Ok(pair) => pair,
70                        Err(err) => {
71                            tracing::error!("{:?}", err);
72                            continue;
73                        },
74                    };
75
76                    tracing::trace!("Connection {:?}", client_addr);
77                    match tls_acceptor {
78                        Some(acceptor) => {
79                            receiver_join_set.spawn(Self::tls_connection(acceptor, tcp_stream, router));
80                        },
81                        None => {
82                            receiver_join_set.spawn(Self::tcp_connection(tcp_stream, router));
83                        },
84                    }
85                }
86            }
87        }
88
89        tracing::trace!("Shut down pending...");
90        while let Some(_) = receiver_join_set.join_next().await {}
91        tracing::trace!("Shut down complete");
92    }
93
94    async fn tcp_connection(tcp_stream: TcpStream, router: Arc<Router<RouteCallback>>) {
95        let service = {
96            let router = router.clone();
97            service_fn(move |req| {
98                Self::incoming_request(req, router.clone())
99            })
100        };
101        
102        let io = TokioIo::new(tcp_stream);
103        if let Err(err) = hyper::server::conn::http1::Builder::new().serve_connection(io, service).await {
104            tracing::error!("{:?}", err);
105        }
106    }
107
108    async fn tls_connection(tls_acceptor: TlsAcceptor, tcp_stream: TcpStream, router: Arc<Router<RouteCallback>>) {
109        let tls_stream = match tls_acceptor.accept(tcp_stream).await {
110            Ok(stream) => stream,
111            Err(err) => {
112                tracing::error!("TLS handshake failed {:?}", err);
113                return;
114            },
115        };
116        
117        let service = {
118            let router = router.clone();
119            service_fn(move |req| {
120                Self::incoming_request(req, router.clone())
121            })
122        };
123        
124        let io = TokioIo::new(tls_stream);
125        let protocol = io.inner().get_ref().1.alpn_protocol();
126        match protocol.as_deref() {
127            Some(b"h2") => {
128                if let Err(err) = hyper::server::conn::http2::Builder::new(Executor).serve_connection(io, service).await {
129                    tracing::error!("TLS handshake failed {:?}", err);
130                }
131            }
132            _ => {
133                if let Err(err) = hyper::server::conn::http1::Builder::new().keep_alive(false).serve_connection(io, service).await {
134                    tracing::error!("{:?}", err);
135                }
136            }
137        }
138    }
139
140    async fn incoming_request(request: Request<Incoming>, router: Arc<Router<RouteCallback>>) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Infallible> {
141        match router.at(&request.uri().path()) {
142            Ok(matched) => {
143                let params: HashMap<String, String> = matched.params.iter().map(|(key, value)| (key.to_string(), value.to_string())).collect();
144                let callback = matched.value;
145                let (parts, body) = request.into_parts();
146                let req = HttpRequest::from_parts_with_params(body.boxed(), parts, params);
147                let callback_fut = callback(req);
148                let result = AssertUnwindSafe(callback_fut).catch_unwind().await;
149                let response = match result {
150                    Ok(res) => res,
151                    Err(err) => {
152                        tracing::error!("{:?}", err);
153                        HttpResponse::builder().status(500).body_empty().unwrap()
154                    }
155                };
156
157                Ok(Response::from(response))
158            },
159            Err(_) => {
160                let response = HttpResponse::builder().status(404).body_empty().unwrap();
161                Ok(Response::from(response))
162            },
163        }
164    }
165}