rust_integration_services/http/
http_receiver.rs1use std::{collections::HashMap, convert::Infallible, path::Path, pin::Pin, sync::Arc};
2
3use http_body_util::{BodyExt, Full};
4use hyper::{body::{Bytes, Incoming}, header::{HeaderName, HeaderValue}, service::service_fn, Request, Response};
5use hyper_util::rt::TokioIo;
6use rustls::ServerConfig;
7use tokio::{net::TcpListener, net::TcpStream, signal::unix::{signal, SignalKind}, sync::mpsc::{self, Sender}, task::JoinSet};
8use tokio_rustls::TlsAcceptor;
9use uuid::Uuid;
10
11use crate::{http::{http_executor::HttpExecutor, http_method::HttpMethod, http_request::HttpRequest, http_response::HttpResponse}, utils::{crypto::Crypto, result::ResultDyn}};
12
13type RouteCallback = Arc<dyn Fn(String, HttpRequest) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>> + Send + Sync>;
14
15#[derive(Clone)]
16pub enum HttpReceiverEventSignal {
17 OnConnectionOpened(String, String),
18 OnRequest(String, HttpRequest),
19 OnResponse(String, HttpResponse),
20 OnConnectionFailed(String, String),
21}
22
23pub struct HttpReceiver {
24 host: String,
25 routes: HashMap<String, RouteCallback>,
26 event_broadcast: mpsc::Sender<HttpReceiverEventSignal>,
27 event_receiver: Option<mpsc::Receiver<HttpReceiverEventSignal>>,
28 event_join_set: JoinSet<()>,
29 tls_config: Option<ServerConfig>,
30}
31
32impl HttpReceiver {
33 pub fn new<T: AsRef<str>>(host: T) -> Self {
35 let (event_broadcast, event_receiver) = mpsc::channel(128);
36 HttpReceiver {
37 host: host.as_ref().to_string(),
38 routes: HashMap::new(),
39 event_broadcast,
40 event_receiver: Some(event_receiver),
41 event_join_set: JoinSet::new(),
42 tls_config: None,
43 }
44 }
45
46 pub fn tls<T: AsRef<Path>>(mut self, tls_server_cert_path: T, tls_server_key_path: T) -> Self {
49 let mut tls_config = Self::create_tls_config(tls_server_cert_path, tls_server_key_path).expect("Failed to create TLS config");
50 tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
51
52 self.tls_config = Some(tls_config);
53 self
54 }
55
56 fn create_tls_config<T: AsRef<Path>>(cert_path: T, key_path: T) -> ResultDyn<ServerConfig> {
57 let certs = Crypto::pem_load_certs(cert_path)?;
58 let key = Crypto::pem_load_private_key(key_path)?;
59 let config = ServerConfig::builder()
60 .with_no_client_auth()
61 .with_single_cert(certs, key)?;
62
63 Ok(config)
64 }
65
66 pub fn route<T, Fut, S>(mut self, method: S, path: S, callback: T) -> Self
68 where
69 T: Fn(String, HttpRequest) -> Fut + Send + Sync + 'static,
70 Fut: Future<Output = HttpResponse> + Send + 'static,
71 S: AsRef<str>,
72 {
73 self.routes.insert(format!("{}|{}", method.as_ref().to_uppercase(), path.as_ref()), Arc::new(move |uuid, request| Box::pin(callback(uuid, request))));
74 self
75 }
76
77 pub fn on_event<T, Fut>(mut self, handler: T) -> Self
82 where
83 T: Fn(HttpReceiverEventSignal) -> Fut + Send + Sync + 'static,
84 Fut: Future<Output = ()> + Send + 'static,
85 {
86 let mut receiver = self.event_receiver.unwrap();
87 let mut sigterm = signal(SignalKind::terminate()).expect("Failed to start SIGTERM signal receiver");
88 let mut sigint = signal(SignalKind::interrupt()).expect("Failed to start SIGINT signal receiver");
89
90 self.event_join_set.spawn(async move {
91 loop {
92 tokio::select! {
93 _ = sigterm.recv() => break,
94 _ = sigint.recv() => break,
95 event = receiver.recv() => {
96 match event {
97 Some(event) => handler(event).await,
98 None => break,
99 }
100 }
101 }
102 }
103 });
104
105 self.event_receiver = None;
106 self
107 }
108
109 async fn incoming_request(req: Request<Incoming>, uuid: String, routes: Arc<HashMap<String, RouteCallback>>, event_broadcast: Arc<Sender<HttpReceiverEventSignal>>) -> Result<Response<Full<Bytes>>, Infallible> {
110 let request = Self::build_http_request(req).await;
111 event_broadcast.send(HttpReceiverEventSignal::OnRequest(uuid.clone(), request.clone())).await.unwrap();
112
113 match routes.get(&format!("{}|{}", &request.method.as_str(), &request.path)) {
114 Some(callback) => {
115 let response = callback(uuid.clone(), request.clone()).await;
116 let res = Self::build_http_response(response.clone()).await;
117 event_broadcast.send(HttpReceiverEventSignal::OnResponse(uuid.clone(), response)).await.unwrap();
118 Ok(res)
119 },
120 None => {
121 let response = HttpResponse::not_found();
122 let res = Self::build_http_response(response.clone()).await;
123 event_broadcast.send(HttpReceiverEventSignal::OnResponse(uuid.clone(), response)).await.unwrap();
124 Ok(res)
125 },
126 }
127 }
128
129 pub async fn receive(mut self) {
134 let tls_acceptor = self.tls_config.map(|tls_config| {
135 TlsAcceptor::from(Arc::new(tls_config))
136 });
137 let listener = TcpListener::bind(&self.host).await.expect("Failed to start TCP Listener");
138 let mut sigterm = signal(SignalKind::terminate()).expect("Failed to start SIGTERM signal receiver");
139 let mut sigint = signal(SignalKind::interrupt()).expect("Failed to start SIGINT signal receiver");
140 let mut join_set = JoinSet::new();
141
142 loop {
143 tokio::select! {
144 _ = sigterm.recv() => break,
145 _ = sigint.recv() => break,
146 result = listener.accept() => {
147 let (tcp_stream, client_addr) = result.unwrap();
148 let uuid = Uuid::new_v4().to_string();
149 let event_broadcast = Arc::new(self.event_broadcast.clone());
150 let tls_acceptor = tls_acceptor.clone();
151 let routes = Arc::new(self.routes.clone());
152
153 event_broadcast.send(HttpReceiverEventSignal::OnConnectionOpened(uuid.clone(), client_addr.ip().to_string())).await.unwrap();
154 match tls_acceptor {
155 Some(acceptor) => {
156 join_set.spawn(Self::tls_connection(acceptor, tcp_stream, uuid, routes, event_broadcast));
157 },
158 None => {
159 join_set.spawn(Self::tcp_connection(tcp_stream, uuid, routes, event_broadcast));
160 },
161 }
162 }
163 }
164 }
165
166 while let Some(_) = join_set.join_next().await {}
167 while let Some(_) = self.event_join_set.join_next().await {}
168 }
169
170 async fn tcp_connection(tcp_stream: TcpStream, uuid: String, routes: Arc<HashMap<String, RouteCallback>>, event_broadcast: Arc<Sender<HttpReceiverEventSignal>>) {
171 let uuid_clone = uuid.clone();
172 let event_broadcast_clone = event_broadcast.clone();
173 let io = TokioIo::new(tcp_stream);
174 let service = service_fn(move |req| {
175 Self::incoming_request(req, uuid_clone.to_owned(), routes.clone(), event_broadcast_clone.to_owned())
176 });
177
178 if let Err(err) = hyper::server::conn::http1::Builder::new().serve_connection(io, service).await {
179 event_broadcast.send(HttpReceiverEventSignal::OnConnectionFailed(uuid, err.to_string())).await.unwrap();
180 }
181 }
182
183 async fn tls_connection(tls_acceptor: TlsAcceptor, tcp_stream: TcpStream, uuid: String, routes: Arc<HashMap<String, RouteCallback>>, event_broadcast: Arc<Sender<HttpReceiverEventSignal>>) {
184 let tls_stream = match tls_acceptor.accept(tcp_stream).await {
185 Ok(stream) => stream,
186 Err(err) => {
187 event_broadcast.send(HttpReceiverEventSignal::OnConnectionFailed(uuid, format!("TLS handshake failed: {:?}", err))).await.unwrap();
188 return;
189 },
190 };
191
192 let uuid_clone = uuid.clone();
193 let event_broadcast_clone = event_broadcast.clone();
194 let service = service_fn(move |req| {
195 Self::incoming_request(req, uuid_clone.to_owned(), routes.clone(), event_broadcast_clone.to_owned())
196 });
197
198 let io = TokioIo::new(tls_stream);
199 let protocol = io.inner().get_ref().1.alpn_protocol();
200
201 match protocol.as_deref() {
202 Some(b"h2") => {
203 if let Err(err) = hyper::server::conn::http2::Builder::new(HttpExecutor).serve_connection(io, service).await {
204 event_broadcast.send(HttpReceiverEventSignal::OnConnectionFailed(uuid, format!("Connection failed: {:?}", err))).await.unwrap();
205 }
206 }
207 _ => {
208 if let Err(err) = hyper::server::conn::http1::Builder::new().serve_connection(io, service).await {
209 event_broadcast.send(HttpReceiverEventSignal::OnConnectionFailed(uuid, err.to_string())).await.unwrap();
210 }
211 }
212 }
213
214 }
215
216 async fn build_http_request(req: Request<Incoming>) -> HttpRequest {
217 let (parts, body) = req.into_parts();
218 let mut request = HttpRequest::new();
219 request.method = HttpMethod::from_str(parts.method.as_str()).unwrap();
220 request.path = parts.uri.path().to_string();
221 for (key, value) in parts.headers {
222 if let (Some(key), Ok(value)) = (key, value.to_str()) {
223 request.headers.insert(key.to_string(), value.to_string());
224 }
225 }
226 request.body = body.collect().await.unwrap().to_bytes().to_vec();
227 request
228 }
229
230 async fn build_http_response(res: HttpResponse) -> Response<Full<Bytes>> {
231 let mut response: Response<Full<Bytes>> = Response::builder().status(res.status.code()).body(res.body.into()).unwrap();
232 for (key, value) in res.headers {
233 let header_name = HeaderName::from_bytes(key.as_bytes()).unwrap();
234 let header_value = HeaderValue::from_str(&value).unwrap();
235 response.headers_mut().insert(header_name, header_value);
236 }
237 response
238 }
239}