// rust_integration_services/http/http_receiver.rs

1use std::collections::HashMap;
2use std::fs::File;
3use std::io::BufReader;
4use std::path::Path;
5use std::path::PathBuf;
6use std::pin::Pin;
7use std::sync::Arc;
8use rustls::pki_types::PrivateKeyDer;
9use rustls::ServerConfig;
10use tokio::io::AsyncRead;
11use tokio::io::AsyncWrite;
12use tokio::io::AsyncWriteExt;
13use tokio::net::TcpListener;
14use tokio::net::TcpStream;
15use tokio::signal::unix::signal;
16use tokio::signal::unix::SignalKind;
17use tokio::sync::mpsc;
18use tokio::task::JoinSet;
19use tokio_rustls::TlsAcceptor;
20use uuid::Uuid;
21
22use crate::utils::error::Error;
23
24use super::http_request::HttpRequest;
25use super::http_response::HttpResponse;
26
27pub trait AsyncStream: AsyncRead + AsyncWrite + Send + Unpin {}
28impl<T: AsyncRead + AsyncWrite + Send + Unpin> AsyncStream for T {}
29
30type RouteCallback = Arc<dyn Fn(String, HttpRequest) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>> + Send + Sync>;
31
32#[derive(Clone)]
33pub enum HttpReceiverEventSignal {
34    OnConnectionReceived(String, String),
35    OnRequestSuccess(String, HttpRequest),
36    OnRequestError(String, String),
37    OnResponseSuccess(String, HttpResponse),
38    OnResponseError(String, String),
39}
40
/// Locations of the PEM files used to enable TLS on an [`HttpReceiver`].
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// PEM file holding the certificate chain.
    cert_path: PathBuf,
    /// PEM file holding the PKCS#8 private key.
    key_path: PathBuf,
}
45
46pub struct HttpReceiver {
47    host: String,
48    routes: HashMap<String, RouteCallback>,
49    event_broadcast: mpsc::Sender<HttpReceiverEventSignal>,
50    event_receiver: Option<mpsc::Receiver<HttpReceiverEventSignal>>,
51    event_join_set: JoinSet<()>,
52    tls_config: Option<TlsConfig>,
53}
54
55impl HttpReceiver {
56    pub fn new<T: AsRef<str>>(host: T) -> Self {
57        let (event_broadcast, event_receiver) = mpsc::channel(128);
58        HttpReceiver {
59            host: host.as_ref().to_string(),
60            routes: HashMap::new(),
61            event_broadcast,
62            event_receiver: Some(event_receiver),
63            event_join_set: JoinSet::new(),
64            tls_config: None,
65        }
66    }
67
68    pub fn route<T, Fut, S>(mut self, method: S, route: S, callback: T) -> Self
69    where
70        T: Fn(String, HttpRequest) -> Fut + Send + Sync + 'static,
71        Fut: Future<Output = HttpResponse> + Send + 'static,
72        S: AsRef<str>,
73    {
74        self.routes.insert(format!("{}|{}", method.as_ref().to_uppercase(), route.as_ref()), Arc::new(move |uuid, request| Box::pin(callback(uuid, request))));
75        self
76    }
77
78    pub fn tls<T: AsRef<Path>>(mut self, cert_path: T, key_path: T) -> Self {
79        self.tls_config = Some(TlsConfig {
80            cert_path: cert_path.as_ref().to_path_buf(),
81            key_path: key_path.as_ref().to_path_buf(),
82        });
83        self
84    }
85
86    pub fn on_event<T, Fut>(mut self, handler: T) -> Self
87    where
88        T: Fn(HttpReceiverEventSignal) -> Fut + Send + Sync + 'static,
89        Fut: Future<Output = ()> + Send + 'static,
90    {
91        let mut receiver = self.event_receiver.unwrap();
92        let mut sigterm = signal(SignalKind::terminate()).expect("Failed to start SIGTERM signal receiver.");
93        let mut sigint = signal(SignalKind::interrupt()).expect("Failed to start SIGINT signal receiver.");
94        
95        self.event_join_set.spawn(async move {
96            loop {
97                tokio::select! {
98                    _ = sigterm.recv() => break,
99                    _ = sigint.recv() => break,
100                    event = receiver.recv() => {
101                        match event {
102                            Some(event) => handler(event).await,
103                            None => break,
104                        }
105                    }
106                }
107            }
108        });
109        
110        self.event_receiver = None;
111        self
112    }
113
114    pub async fn receive(mut self) -> tokio::io::Result<()> {
115        let listener = TcpListener::bind(&self.host).await?;
116        let tls_acceptor = self.tls_config.as_ref().map(|tls_cert| {
117            let config = Arc::new(Self::create_tls_config(&tls_cert.cert_path, &tls_cert.key_path).unwrap());
118            TlsAcceptor::from(config)
119        });
120
121        let routes = Arc::new(self.routes.clone());
122        let mut join_set_main = JoinSet::new();
123        let mut sigterm = signal(SignalKind::terminate())?;
124        let mut sigint = signal(SignalKind::interrupt())?;
125        
126        loop {
127            tokio::select! {
128                _ = sigterm.recv() => break,
129                _ = sigint.recv() => break,
130                result = listener.accept() => {
131                    let (mut stream, client_addr) = result.unwrap();
132                    let routes = Arc::clone(&routes);
133                    let event_broadcast = Arc::new(self.event_broadcast.clone());
134                    let tls_acceptor = tls_acceptor.clone();
135                    let uuid = Uuid::new_v4().to_string();
136
137                    event_broadcast.send(HttpReceiverEventSignal::OnConnectionReceived(uuid.clone(), client_addr.ip().to_string())).await.unwrap();
138                    join_set_main.spawn(async move {
139                        let mut stream: Box<dyn AsyncStream> = match tls_acceptor {
140                            Some(acceptor) => {
141                                match Self::is_connection_tls(&stream).await {
142                                    Ok(_) => {},
143                                    Err(err) => {
144                                        event_broadcast.send(HttpReceiverEventSignal::OnRequestError(uuid.clone(), err.to_string())).await.unwrap();
145                                        let response = HttpResponse::internal_server_error();
146                                        match stream.write_all(&response.to_bytes()).await {
147                                            Ok(_) => event_broadcast.send(HttpReceiverEventSignal::OnResponseSuccess(uuid.clone(), response.clone())).await.unwrap(),
148                                            Err(err) => event_broadcast.send(HttpReceiverEventSignal::OnResponseError(uuid.clone(), err.to_string())).await.unwrap(),
149                                        };
150                                        return;
151                                    },
152                                };
153
154                                match acceptor.accept(&mut stream).await {
155                                    Ok(tls_stream) => Box::new(tls_stream),
156                                    Err(err) => {
157                                        let err = format!("TLS handshake failed: {}", err.to_string());
158                                        event_broadcast.send(HttpReceiverEventSignal::OnRequestError(uuid.clone(), err.to_string())).await.unwrap();
159                                        let response = HttpResponse::internal_server_error();
160                                        match stream.write_all(&response.to_bytes()).await {
161                                            Ok(_) => event_broadcast.send(HttpReceiverEventSignal::OnResponseSuccess(uuid.clone(), response.clone())).await.unwrap(),
162                                            Err(err) => event_broadcast.send(HttpReceiverEventSignal::OnResponseError(uuid.clone(), err.to_string())).await.unwrap(),
163                                        };
164                                        return;
165                                    },
166                                }
167                            },
168                            None => Box::new(stream),
169                        };
170
171                        let request = match HttpRequest::from_stream(&mut stream).await {
172                            Ok(request) => request.ip(client_addr.ip().to_string()),
173                            Err(err) => {
174                                event_broadcast.send(HttpReceiverEventSignal::OnRequestError(uuid.clone(), err.to_string())).await.unwrap();
175                                let response = HttpResponse::internal_server_error();
176                                match stream.write_all(&response.to_bytes()).await {
177                                    Ok(_) => event_broadcast.send(HttpReceiverEventSignal::OnResponseSuccess(uuid.clone(), response.clone())).await.unwrap(),
178                                    Err(err) => event_broadcast.send(HttpReceiverEventSignal::OnResponseError(uuid.clone(), err.to_string())).await.unwrap(),
179                                };
180                                return;
181                            }
182                        };
183
184                        match routes.get(&format!("{}|{}", &request.method, &request.path)) {
185                            None => {
186                                event_broadcast.send(HttpReceiverEventSignal::OnRequestSuccess(uuid.clone(), request.clone())).await.unwrap();
187                                let response = HttpResponse::not_found();
188                                match stream.write_all(&response.to_bytes()).await {
189                                    Ok(_) => event_broadcast.send(HttpReceiverEventSignal::OnResponseSuccess(uuid.clone(), response.clone())).await.unwrap(),
190                                    Err(err) => event_broadcast.send(HttpReceiverEventSignal::OnResponseError(uuid.clone(), err.to_string())).await.unwrap(),
191                                };
192                            },
193                            Some(callback) => {
194                                event_broadcast.send(HttpReceiverEventSignal::OnRequestSuccess(uuid.clone(), request.clone())).await.unwrap();
195                                let mut response = callback(uuid.clone(), request).await;
196                                if !response.body.is_empty() {
197                                    response.headers.insert(String::from("Content-Length"), response.body.len().to_string());
198                                }
199                                match stream.write_all(&response.to_bytes()).await {
200                                    Ok(_) => event_broadcast.send(HttpReceiverEventSignal::OnResponseSuccess(uuid.clone(), response.clone())).await.unwrap(),
201                                    Err(err) => event_broadcast.send(HttpReceiverEventSignal::OnResponseError(uuid.clone(), err.to_string())).await.unwrap(),
202                                };
203                            }
204                        }
205                    });
206                }
207            }
208        }
209
210        while let Some(_) = join_set_main.join_next().await {}
211        while let Some(_) = self.event_join_set.join_next().await {}
212
213        Ok(())
214    }
215
216    async fn is_connection_tls(stream: &TcpStream) -> tokio::io::Result<()> {
217        let mut peek_buffer = [0u8; 8];
218        match stream.peek(&mut peek_buffer).await {
219            Ok(len) if len >= 3 => {
220                // Check for TLS ClientHello Signature.
221                let is_tls_client_sig = peek_buffer[0] == 0x16 && peek_buffer[1] == 0x03 && (0x01..=0x03).contains(&peek_buffer[2]);
222                if is_tls_client_sig {
223                    return Ok(())
224                }
225                Err(Error::tokio_io("Non-TLS request on TLS receiver."))
226            },
227            Ok(_) => Err(Error::tokio_io("Could not determine TLS signature.")),
228            Err(err) => Err(err),
229        }
230    }
231    
232    fn create_tls_config<T: AsRef<Path>>(cert_path: T, key_path: T) -> std::io::Result<ServerConfig> {
233        let cert_file = File::open(cert_path)?;
234        let mut cert_reader = BufReader::new(cert_file);
235        let certs = rustls_pemfile::certs(&mut cert_reader)
236            .collect::<Result<Vec<_>, _>>()
237            .map_err(|_| Error::std_io("Invalid certificate"))?;
238
239        let key_file = File::open(key_path)?;
240        let mut key_reader = BufReader::new(key_file);
241        let mut keys = rustls_pemfile::pkcs8_private_keys(&mut key_reader)
242            .collect::<Result<Vec<_>, _>>()
243            .map_err(|_| Error::std_io("Invalid private key"))?;
244
245        let key = keys.pop().unwrap();
246        let config = ServerConfig::builder()
247            .with_no_client_auth()
248            .with_single_cert(certs, PrivateKeyDer::Pkcs8(key))
249            .map_err(|err| Error::std_io(err.to_string()))?;
250
251        Ok(config)
252    }
253}