rust_integration_services/http/server/
http_server.rs1use std::{collections::HashMap, convert::Infallible, panic::AssertUnwindSafe, pin::Pin, sync::Arc};
2
3use futures::FutureExt;
4use http_body_util::{BodyExt, combinators::BoxBody};
5use hyper::{Request, Response, body::{Bytes, Incoming}, service::service_fn};
6use hyper_util::rt::TokioIo;
7use matchit::Router;
8use tokio::{net::{TcpListener, TcpStream}, signal::unix::{signal, SignalKind}, task::JoinSet};
9use tokio_rustls::TlsAcceptor;
10
11use crate::http::{executor::Executor, http_request::HttpRequest, http_response::HttpResponse, server::http_server_config::HttpServerConfig};
12
13type RouteCallback = Arc<dyn Fn(HttpRequest) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>> + Send + Sync>;
14
15pub struct HttpServer {
16 config: HttpServerConfig,
17 router: Router<RouteCallback>,
18}
19
20impl HttpServer {
21 pub fn new(config: HttpServerConfig) -> Self {
22 HttpServer {
23 config,
24 router: Router::new(),
25 }
26 }
27
28 pub fn route<T, Fut>(mut self, path: impl Into<String>, callback: T) -> Self
30 where
31 T: Fn(HttpRequest) -> Fut + Send + Sync + 'static,
32 Fut: Future<Output = HttpResponse> + Send + 'static,
33 {
34 self.router.insert(path.into(), Arc::new(move |request| Box::pin(callback(request)))).unwrap();
35 self
36 }
37
38 pub async fn run(self) {
43 let tls_acceptor = self.config.tls_config.map(|tls_config| {
44 TlsAcceptor::from(Arc::new(tls_config))
45 });
46
47 let host = format!("{}:{}", self.config.ip, self.config.port);
48 let listener = TcpListener::bind(&host).await.expect("Failed to start TCP Listener");
49 let mut sigterm = signal(SignalKind::terminate()).expect("Failed to start SIGTERM signal receiver");
50 let mut sigint = signal(SignalKind::interrupt()).expect("Failed to start SIGINT signal receiver");
51 let mut receiver_join_set = JoinSet::new();
52 let router = Arc::new(self.router);
53
54 tracing::trace!("Started on {}", &host);
55 loop {
56 tokio::select! {
57 _ = sigterm.recv() => {
58 drop(listener);
59 break;
60 },
61 _ = sigint.recv() => {
62 drop(listener);
63 break;
64 },
65 result = listener.accept() => {
66 let tls_acceptor = tls_acceptor.clone();
67 let router = router.clone();
68 let (tcp_stream, client_addr) = match result {
69 Ok(pair) => pair,
70 Err(err) => {
71 tracing::error!("{:?}", err);
72 continue;
73 },
74 };
75
76 tracing::trace!("Connection {:?}", client_addr);
77 match tls_acceptor {
78 Some(acceptor) => {
79 receiver_join_set.spawn(Self::tls_connection(acceptor, tcp_stream, router));
80 },
81 None => {
82 receiver_join_set.spawn(Self::tcp_connection(tcp_stream, router));
83 },
84 }
85 }
86 }
87 }
88
89 tracing::trace!("Shut down pending...");
90 while let Some(_) = receiver_join_set.join_next().await {}
91 tracing::trace!("Shut down complete");
92 }
93
94 async fn tcp_connection(tcp_stream: TcpStream, router: Arc<Router<RouteCallback>>) {
95 let service = {
96 let router = router.clone();
97 service_fn(move |req| {
98 Self::incoming_request(req, router.clone())
99 })
100 };
101
102 let io = TokioIo::new(tcp_stream);
103 if let Err(err) = hyper::server::conn::http1::Builder::new().serve_connection(io, service).await {
104 tracing::error!("{:?}", err);
105 }
106 }
107
108 async fn tls_connection(tls_acceptor: TlsAcceptor, tcp_stream: TcpStream, router: Arc<Router<RouteCallback>>) {
109 let tls_stream = match tls_acceptor.accept(tcp_stream).await {
110 Ok(stream) => stream,
111 Err(err) => {
112 tracing::error!("TLS handshake failed {:?}", err);
113 return;
114 },
115 };
116
117 let service = {
118 let router = router.clone();
119 service_fn(move |req| {
120 Self::incoming_request(req, router.clone())
121 })
122 };
123
124 let io = TokioIo::new(tls_stream);
125 let protocol = io.inner().get_ref().1.alpn_protocol();
126 match protocol.as_deref() {
127 Some(b"h2") => {
128 if let Err(err) = hyper::server::conn::http2::Builder::new(Executor).serve_connection(io, service).await {
129 tracing::error!("TLS handshake failed {:?}", err);
130 }
131 }
132 _ => {
133 if let Err(err) = hyper::server::conn::http1::Builder::new().keep_alive(false).serve_connection(io, service).await {
134 tracing::error!("{:?}", err);
135 }
136 }
137 }
138 }
139
140 async fn incoming_request(request: Request<Incoming>, router: Arc<Router<RouteCallback>>) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Infallible> {
141 match router.at(&request.uri().path()) {
142 Ok(matched) => {
143 let params: HashMap<String, String> = matched.params.iter().map(|(key, value)| (key.to_string(), value.to_string())).collect();
144 let callback = matched.value;
145 let (parts, body) = request.into_parts();
146 let req = HttpRequest::from_parts_with_params(body.boxed(), parts, params);
147 let callback_fut = callback(req);
148 let result = AssertUnwindSafe(callback_fut).catch_unwind().await;
149 let response = match result {
150 Ok(res) => res,
151 Err(err) => {
152 tracing::error!("{:?}", err);
153 HttpResponse::builder().status(500).body_empty().unwrap()
154 }
155 };
156
157 Ok(Response::from(response))
158 },
159 Err(_) => {
160 let response = HttpResponse::builder().status(404).body_empty().unwrap();
161 Ok(Response::from(response))
162 },
163 }
164 }
165}