// rust_integration_services/http/http_receiver.rs

use std::{collections::HashMap, convert::Infallible, path::Path, pin::Pin, sync::Arc};

use http_body_util::{BodyExt, Full};
use hyper::{body::{Bytes, Incoming}, header::{HeaderName, HeaderValue, COOKIE}, service::service_fn, Request, Response};
use hyper_util::rt::TokioIo;
use rustls::ServerConfig;
use tokio::{net::TcpListener, net::TcpStream, signal::unix::{signal, SignalKind}, sync::mpsc::{self, Sender}, task::JoinSet};
use tokio_rustls::TlsAcceptor;
use uuid::Uuid;

use crate::{http::{http_executor::HttpExecutor, http_method::HttpMethod, http_request::HttpRequest, http_response::HttpResponse}, utils::{crypto::Crypto, result::ResultDyn}};
12
13type RouteCallback = Arc<dyn Fn(String, HttpRequest) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>> + Send + Sync>;
14
15#[derive(Clone)]
16pub enum HttpReceiverEventSignal {
17    OnConnectionOpened(String, String),
18    OnRequest(String, HttpRequest),
19    OnResponse(String, HttpResponse),
20    OnConnectionFailed(String, String),
21}
22
23pub struct HttpReceiver {
24    host: String,
25    routes: HashMap<String, RouteCallback>,
26    event_broadcast: mpsc::Sender<HttpReceiverEventSignal>,
27    event_receiver: Option<mpsc::Receiver<HttpReceiverEventSignal>>,
28    event_join_set: JoinSet<()>,
29    tls_config: Option<ServerConfig>,
30}
31
32impl HttpReceiver {
33    /// Creates a new `HttpReceiver` instance bound to the specified host address. Example: `127.0.0.1:8080`.
34    pub fn new<T: AsRef<str>>(host: T) -> Self {
35        let (event_broadcast, event_receiver) = mpsc::channel(128);
36        HttpReceiver {
37            host: host.as_ref().to_string(),
38            routes: HashMap::new(),
39            event_broadcast,
40            event_receiver: Some(event_receiver),
41            event_join_set: JoinSet::new(),
42            tls_config: None,
43        }
44    }
45
46    /// Enables TLS for incoming connections using the provided server certificate and private key in `.pem` format and
47    /// configures the TLS context and sets supported ALPN protocols to allow HTTP/2 and HTTP/1.1.
48    pub fn tls<T: AsRef<Path>>(mut self, tls_server_cert_path: T, tls_server_key_path: T) -> Self {
49        let mut tls_config = Self::create_tls_config(tls_server_cert_path, tls_server_key_path).expect("Failed to create TLS config");
50        tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
51
52        self.tls_config = Some(tls_config);
53        self
54    }
55
56    fn create_tls_config<T: AsRef<Path>>(cert_path: T, key_path: T) -> ResultDyn<ServerConfig> {
57        let certs = Crypto::pem_load_certs(cert_path)?;
58        let key = Crypto::pem_load_private_key(key_path)?;
59        let config = ServerConfig::builder()
60            .with_no_client_auth()
61            .with_single_cert(certs, key)?;
62
63        Ok(config)
64    }
65
66    /// Registers a route with a specific HTTP method and path, associating it with a handler callback.
67    pub fn route<T, Fut, S>(mut self, method: S, path: S, callback: T) -> Self
68    where
69        T: Fn(String, HttpRequest) -> Fut + Send + Sync + 'static,
70        Fut: Future<Output = HttpResponse> + Send + 'static,
71        S: AsRef<str>,
72    {
73        self.routes.insert(format!("{}|{}", method.as_ref().to_uppercase(), path.as_ref()), Arc::new(move |uuid, request| Box::pin(callback(uuid, request))));
74        self
75    }
76
77    /// Registers an asynchronous event handler callback for incoming `HttpReceiverEventSignal`s.
78    ///
79    /// This sets up a background task that listens for system signals (SIGTERM, SIGINT)
80    /// and incoming events from the internal event channel.
81    pub fn on_event<T, Fut>(mut self, handler: T) -> Self
82    where
83        T: Fn(HttpReceiverEventSignal) -> Fut + Send + Sync + 'static,
84        Fut: Future<Output = ()> + Send + 'static,
85    {
86        let mut receiver = self.event_receiver.unwrap();
87        let mut sigterm = signal(SignalKind::terminate()).expect("Failed to start SIGTERM signal receiver");
88        let mut sigint = signal(SignalKind::interrupt()).expect("Failed to start SIGINT signal receiver");
89        
90        self.event_join_set.spawn(async move {
91            loop {
92                tokio::select! {
93                    _ = sigterm.recv() => break,
94                    _ = sigint.recv() => break,
95                    event = receiver.recv() => {
96                        match event {
97                            Some(event) => handler(event).await,
98                            None => break,
99                        }
100                    }
101                }
102            }
103        });
104        
105        self.event_receiver = None;
106        self
107    }
108
109    async fn incoming_request(req: Request<Incoming>, uuid: String, routes: Arc<HashMap<String, RouteCallback>>, event_broadcast: Arc<Sender<HttpReceiverEventSignal>>) -> Result<Response<Full<Bytes>>, Infallible> {
110        let request = Self::build_http_request(req).await;
111        event_broadcast.send(HttpReceiverEventSignal::OnRequest(uuid.clone(), request.clone())).await.unwrap();
112
113        match routes.get(&format!("{}|{}", &request.method.as_str(), &request.path)) {
114            Some(callback) => {
115                let response = callback(uuid.clone(), request.clone()).await;
116                let res = Self::build_http_response(response.clone()).await;
117                event_broadcast.send(HttpReceiverEventSignal::OnResponse(uuid.clone(), response)).await.unwrap();
118                Ok(res)
119            },
120            None => {
121                let response = HttpResponse::not_found();
122                let res = Self::build_http_response(response.clone()).await;
123                event_broadcast.send(HttpReceiverEventSignal::OnResponse(uuid.clone(), response)).await.unwrap();
124                Ok(res)
125            },
126        }
127    }
128
129    /// Starts the HTTP server and begins listening for incoming TCP connections (optionally over TLS).
130    ///
131    /// This method binds to the configured host address and enters a loop to accept new TCP connections.
132    /// It also listens for system termination signals (SIGINT, SIGTERM) to gracefully shut down the server.
133    pub async fn receive(mut self) {
134        let tls_acceptor = self.tls_config.map(|tls_config| {
135            TlsAcceptor::from(Arc::new(tls_config))
136        });
137        let listener = TcpListener::bind(&self.host).await.expect("Failed to start TCP Listener");
138        let mut sigterm = signal(SignalKind::terminate()).expect("Failed to start SIGTERM signal receiver");
139        let mut sigint = signal(SignalKind::interrupt()).expect("Failed to start SIGINT signal receiver");
140        let mut join_set = JoinSet::new();
141        
142        loop {
143            tokio::select! {
144                _ = sigterm.recv() => break,
145                _ = sigint.recv() => break,
146                result = listener.accept() => {
147                    let (tcp_stream, client_addr) = result.unwrap();
148                    let uuid = Uuid::new_v4().to_string();
149                    let event_broadcast = Arc::new(self.event_broadcast.clone());
150                    let tls_acceptor = tls_acceptor.clone();
151                    let routes = Arc::new(self.routes.clone());
152                    
153                    event_broadcast.send(HttpReceiverEventSignal::OnConnectionOpened(uuid.clone(), client_addr.ip().to_string())).await.unwrap();
154                    match tls_acceptor {
155                        Some(acceptor) => {
156                            join_set.spawn(Self::tls_connection(acceptor, tcp_stream, uuid, routes, event_broadcast));
157                        },
158                        None => {
159                            join_set.spawn(Self::tcp_connection(tcp_stream, uuid, routes, event_broadcast));
160                        },
161                    }
162                }
163            }
164        }
165
166        while let Some(_) = join_set.join_next().await {}
167        while let Some(_) = self.event_join_set.join_next().await {}
168    }
169
170    async fn tcp_connection(tcp_stream: TcpStream, uuid: String, routes: Arc<HashMap<String, RouteCallback>>, event_broadcast: Arc<Sender<HttpReceiverEventSignal>>) {
171        let uuid_clone = uuid.clone();
172        let event_broadcast_clone = event_broadcast.clone();
173        let io = TokioIo::new(tcp_stream);
174        let service = service_fn(move |req| {
175            Self::incoming_request(req, uuid_clone.to_owned(), routes.clone(), event_broadcast_clone.to_owned())
176        });
177        
178        if let Err(err) = hyper::server::conn::http1::Builder::new().serve_connection(io, service).await {
179            event_broadcast.send(HttpReceiverEventSignal::OnConnectionFailed(uuid, err.to_string())).await.unwrap();
180        }
181    }
182
183    async fn tls_connection(tls_acceptor: TlsAcceptor, tcp_stream: TcpStream, uuid: String, routes: Arc<HashMap<String, RouteCallback>>, event_broadcast: Arc<Sender<HttpReceiverEventSignal>>) {
184        let tls_stream = match tls_acceptor.accept(tcp_stream).await {
185            Ok(stream) => stream,
186            Err(err) => {
187                event_broadcast.send(HttpReceiverEventSignal::OnConnectionFailed(uuid, format!("TLS handshake failed: {:?}", err))).await.unwrap();
188                return;
189            },
190        };
191        
192        let uuid_clone = uuid.clone();
193        let event_broadcast_clone = event_broadcast.clone();
194        let service = service_fn(move |req| {
195            Self::incoming_request(req, uuid_clone.to_owned(), routes.clone(), event_broadcast_clone.to_owned())
196        });
197        
198        let io = TokioIo::new(tls_stream);
199        let protocol = io.inner().get_ref().1.alpn_protocol();
200
201        match protocol.as_deref() {
202            Some(b"h2") => {
203                if let Err(err) = hyper::server::conn::http2::Builder::new(HttpExecutor).serve_connection(io, service).await {
204                    event_broadcast.send(HttpReceiverEventSignal::OnConnectionFailed(uuid, format!("Connection failed: {:?}", err))).await.unwrap();
205                }
206            }
207            _ => {
208                if let Err(err) = hyper::server::conn::http1::Builder::new().serve_connection(io, service).await {
209                    event_broadcast.send(HttpReceiverEventSignal::OnConnectionFailed(uuid, err.to_string())).await.unwrap();
210                }
211            }
212        }
213
214    }
215
216    async fn build_http_request(req: Request<Incoming>) -> HttpRequest {
217        let (parts, body) = req.into_parts();
218        let mut request = HttpRequest::new();
219        request.method = HttpMethod::from_str(parts.method.as_str()).unwrap();
220        request.path = parts.uri.path().to_string();
221        for (key, value) in parts.headers {
222            if let (Some(key), Ok(value)) = (key, value.to_str()) {
223                request.headers.insert(key.to_string(), value.to_string());
224            }
225        }
226        request.body = body.collect().await.unwrap().to_bytes().to_vec();
227        request
228    }
229
230    async fn build_http_response(res: HttpResponse) -> Response<Full<Bytes>> {
231        let mut response: Response<Full<Bytes>> = Response::builder().status(res.status.code()).body(res.body.into()).unwrap();
232        for (key, value) in res.headers {
233            let header_name = HeaderName::from_bytes(key.as_bytes()).unwrap();
234            let header_value = HeaderValue::from_str(&value).unwrap();
235            response.headers_mut().insert(header_name, header_value);
236        }
237        response
238    }
239}