rust_integration_services/http/
http_receiver.rs1use std::collections::HashMap;
2use std::fs::File;
3use std::io::BufReader;
4use std::path::Path;
5use std::path::PathBuf;
6use std::pin::Pin;
7use std::sync::Arc;
8use rustls::pki_types::PrivateKeyDer;
9use rustls::ServerConfig;
10use tokio::io::AsyncRead;
11use tokio::io::AsyncWrite;
12use tokio::io::AsyncWriteExt;
13use tokio::net::TcpListener;
14use tokio::net::TcpStream;
15use tokio::signal::unix::signal;
16use tokio::signal::unix::SignalKind;
17use tokio::sync::mpsc;
18use tokio::task::JoinSet;
19use tokio_rustls::TlsAcceptor;
20use uuid::Uuid;
21
22use crate::utils::error::Error;
23
24use super::http_request::HttpRequest;
25use super::http_response::HttpResponse;
26
27pub trait AsyncStream: AsyncRead + AsyncWrite + Send + Unpin {}
28impl<T: AsyncRead + AsyncWrite + Send + Unpin> AsyncStream for T {}
29
30type RouteCallback = Arc<dyn Fn(String, HttpRequest) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>> + Send + Sync>;
31
32#[derive(Clone)]
33pub enum HttpReceiverEventSignal {
34 OnConnectionReceived(String, String),
35 OnRequestSuccess(String, HttpRequest),
36 OnRequestError(String, String),
37 OnResponseSuccess(String, HttpResponse),
38 OnResponseError(String, String),
39}
40
/// Filesystem locations of the PEM files used to enable TLS.
pub struct TlsConfig {
    /// Path of the PEM file holding the certificate(s).
    cert_path: PathBuf,
    /// Path of the PEM file holding the PKCS#8 private key.
    key_path: PathBuf,
}
45
46pub struct HttpReceiver {
47 host: String,
48 routes: HashMap<String, RouteCallback>,
49 event_broadcast: mpsc::Sender<HttpReceiverEventSignal>,
50 event_receiver: Option<mpsc::Receiver<HttpReceiverEventSignal>>,
51 event_join_set: JoinSet<()>,
52 tls_config: Option<TlsConfig>,
53}
54
55impl HttpReceiver {
56 pub fn new<T: AsRef<str>>(host: T) -> Self {
57 let (event_broadcast, event_receiver) = mpsc::channel(128);
58 HttpReceiver {
59 host: host.as_ref().to_string(),
60 routes: HashMap::new(),
61 event_broadcast,
62 event_receiver: Some(event_receiver),
63 event_join_set: JoinSet::new(),
64 tls_config: None,
65 }
66 }
67
68 pub fn route<T, Fut, S>(mut self, method: S, route: S, callback: T) -> Self
69 where
70 T: Fn(String, HttpRequest) -> Fut + Send + Sync + 'static,
71 Fut: Future<Output = HttpResponse> + Send + 'static,
72 S: AsRef<str>,
73 {
74 self.routes.insert(format!("{}|{}", method.as_ref().to_uppercase(), route.as_ref()), Arc::new(move |uuid, request| Box::pin(callback(uuid, request))));
75 self
76 }
77
78 pub fn tls<T: AsRef<Path>>(mut self, cert_path: T, key_path: T) -> Self {
79 self.tls_config = Some(TlsConfig {
80 cert_path: cert_path.as_ref().to_path_buf(),
81 key_path: key_path.as_ref().to_path_buf(),
82 });
83 self
84 }
85
86 pub fn on_event<T, Fut>(mut self, handler: T) -> Self
87 where
88 T: Fn(HttpReceiverEventSignal) -> Fut + Send + Sync + 'static,
89 Fut: Future<Output = ()> + Send + 'static,
90 {
91 let mut receiver = self.event_receiver.unwrap();
92 let mut sigterm = signal(SignalKind::terminate()).expect("Failed to start SIGTERM signal receiver.");
93 let mut sigint = signal(SignalKind::interrupt()).expect("Failed to start SIGINT signal receiver.");
94
95 self.event_join_set.spawn(async move {
96 loop {
97 tokio::select! {
98 _ = sigterm.recv() => break,
99 _ = sigint.recv() => break,
100 event = receiver.recv() => {
101 match event {
102 Some(event) => handler(event).await,
103 None => break,
104 }
105 }
106 }
107 }
108 });
109
110 self.event_receiver = None;
111 self
112 }
113
114 pub async fn receive(mut self) -> tokio::io::Result<()> {
115 let listener = TcpListener::bind(&self.host).await?;
116 let tls_acceptor = self.tls_config.as_ref().map(|tls_cert| {
117 let config = Arc::new(Self::create_tls_config(&tls_cert.cert_path, &tls_cert.key_path).unwrap());
118 TlsAcceptor::from(config)
119 });
120
121 let routes = Arc::new(self.routes.clone());
122 let mut join_set_main = JoinSet::new();
123 let mut sigterm = signal(SignalKind::terminate())?;
124 let mut sigint = signal(SignalKind::interrupt())?;
125
126 loop {
127 tokio::select! {
128 _ = sigterm.recv() => break,
129 _ = sigint.recv() => break,
130 result = listener.accept() => {
131 let (mut stream, client_addr) = result.unwrap();
132 let routes = Arc::clone(&routes);
133 let event_broadcast = Arc::new(self.event_broadcast.clone());
134 let tls_acceptor = tls_acceptor.clone();
135 let uuid = Uuid::new_v4().to_string();
136
137 event_broadcast.send(HttpReceiverEventSignal::OnConnectionReceived(uuid.clone(), client_addr.ip().to_string())).await.unwrap();
138 join_set_main.spawn(async move {
139 let mut stream: Box<dyn AsyncStream> = match tls_acceptor {
140 Some(acceptor) => {
141 match Self::is_connection_tls(&stream).await {
142 Ok(_) => {},
143 Err(err) => {
144 event_broadcast.send(HttpReceiverEventSignal::OnRequestError(uuid.clone(), err.to_string())).await.unwrap();
145 let response = HttpResponse::internal_server_error();
146 match stream.write_all(&response.to_bytes()).await {
147 Ok(_) => event_broadcast.send(HttpReceiverEventSignal::OnResponseSuccess(uuid.clone(), response.clone())).await.unwrap(),
148 Err(err) => event_broadcast.send(HttpReceiverEventSignal::OnResponseError(uuid.clone(), err.to_string())).await.unwrap(),
149 };
150 return;
151 },
152 };
153
154 match acceptor.accept(&mut stream).await {
155 Ok(tls_stream) => Box::new(tls_stream),
156 Err(err) => {
157 let err = format!("TLS handshake failed: {}", err.to_string());
158 event_broadcast.send(HttpReceiverEventSignal::OnRequestError(uuid.clone(), err.to_string())).await.unwrap();
159 let response = HttpResponse::internal_server_error();
160 match stream.write_all(&response.to_bytes()).await {
161 Ok(_) => event_broadcast.send(HttpReceiverEventSignal::OnResponseSuccess(uuid.clone(), response.clone())).await.unwrap(),
162 Err(err) => event_broadcast.send(HttpReceiverEventSignal::OnResponseError(uuid.clone(), err.to_string())).await.unwrap(),
163 };
164 return;
165 },
166 }
167 },
168 None => Box::new(stream),
169 };
170
171 let request = match HttpRequest::from_stream(&mut stream).await {
172 Ok(request) => request.ip(client_addr.ip().to_string()),
173 Err(err) => {
174 event_broadcast.send(HttpReceiverEventSignal::OnRequestError(uuid.clone(), err.to_string())).await.unwrap();
175 let response = HttpResponse::internal_server_error();
176 match stream.write_all(&response.to_bytes()).await {
177 Ok(_) => event_broadcast.send(HttpReceiverEventSignal::OnResponseSuccess(uuid.clone(), response.clone())).await.unwrap(),
178 Err(err) => event_broadcast.send(HttpReceiverEventSignal::OnResponseError(uuid.clone(), err.to_string())).await.unwrap(),
179 };
180 return;
181 }
182 };
183
184 match routes.get(&format!("{}|{}", &request.method, &request.path)) {
185 None => {
186 event_broadcast.send(HttpReceiverEventSignal::OnRequestSuccess(uuid.clone(), request.clone())).await.unwrap();
187 let response = HttpResponse::not_found();
188 match stream.write_all(&response.to_bytes()).await {
189 Ok(_) => event_broadcast.send(HttpReceiverEventSignal::OnResponseSuccess(uuid.clone(), response.clone())).await.unwrap(),
190 Err(err) => event_broadcast.send(HttpReceiverEventSignal::OnResponseError(uuid.clone(), err.to_string())).await.unwrap(),
191 };
192 },
193 Some(callback) => {
194 event_broadcast.send(HttpReceiverEventSignal::OnRequestSuccess(uuid.clone(), request.clone())).await.unwrap();
195 let mut response = callback(uuid.clone(), request).await;
196 if !response.body.is_empty() {
197 response.headers.insert(String::from("Content-Length"), response.body.len().to_string());
198 }
199 match stream.write_all(&response.to_bytes()).await {
200 Ok(_) => event_broadcast.send(HttpReceiverEventSignal::OnResponseSuccess(uuid.clone(), response.clone())).await.unwrap(),
201 Err(err) => event_broadcast.send(HttpReceiverEventSignal::OnResponseError(uuid.clone(), err.to_string())).await.unwrap(),
202 };
203 }
204 }
205 });
206 }
207 }
208 }
209
210 while let Some(_) = join_set_main.join_next().await {}
211 while let Some(_) = self.event_join_set.join_next().await {}
212
213 Ok(())
214 }
215
216 async fn is_connection_tls(stream: &TcpStream) -> tokio::io::Result<()> {
217 let mut peek_buffer = [0u8; 8];
218 match stream.peek(&mut peek_buffer).await {
219 Ok(len) if len >= 3 => {
220 let is_tls_client_sig = peek_buffer[0] == 0x16 && peek_buffer[1] == 0x03 && (0x01..=0x03).contains(&peek_buffer[2]);
222 if is_tls_client_sig {
223 return Ok(())
224 }
225 Err(Error::tokio_io("Non-TLS request on TLS receiver."))
226 },
227 Ok(_) => Err(Error::tokio_io("Could not determine TLS signature.")),
228 Err(err) => Err(err),
229 }
230 }
231
232 fn create_tls_config<T: AsRef<Path>>(cert_path: T, key_path: T) -> std::io::Result<ServerConfig> {
233 let cert_file = File::open(cert_path)?;
234 let mut cert_reader = BufReader::new(cert_file);
235 let certs = rustls_pemfile::certs(&mut cert_reader)
236 .collect::<Result<Vec<_>, _>>()
237 .map_err(|_| Error::std_io("Invalid certificate"))?;
238
239 let key_file = File::open(key_path)?;
240 let mut key_reader = BufReader::new(key_file);
241 let mut keys = rustls_pemfile::pkcs8_private_keys(&mut key_reader)
242 .collect::<Result<Vec<_>, _>>()
243 .map_err(|_| Error::std_io("Invalid private key"))?;
244
245 let key = keys.pop().unwrap();
246 let config = ServerConfig::builder()
247 .with_no_client_auth()
248 .with_single_cert(certs, PrivateKeyDer::Pkcs8(key))
249 .map_err(|err| Error::std_io(err.to_string()))?;
250
251 Ok(config)
252 }
253}