Skip to main content

rustio_admin/
server.rs

//! The HTTP server. Binds a TCP listener, runs each connection on its
//! own Tokio task, and shuts down gracefully on Ctrl-C (or SIGTERM on Unix).
3
4use std::collections::HashMap;
5use std::net::SocketAddr;
6use std::sync::Arc;
7
8use bytes::Bytes;
9use http_body_util::{BodyExt, Full};
10use hyper::body::Incoming;
11use hyper::service::service_fn;
12use hyper::StatusCode;
13use hyper_util::rt::TokioIo;
14use tokio::net::TcpListener;
15
16use crate::error::Result;
17use crate::http::{Request, Response};
18use crate::router::Router;
19
20// public:
21pub struct Server {
22    router: Arc<Router>,
23    addr: SocketAddr,
24}
25
26impl Server {
27    // public:
28    pub fn new(router: Router, addr: SocketAddr) -> Self {
29        Self {
30            router: Arc::new(router),
31            addr,
32        }
33    }
34
35    // public:
36    /// Run until Ctrl-C / SIGTERM. Active connections get a brief grace
37    /// period to drain before the runtime drops them.
38    pub async fn run(self) -> Result<()> {
39        let listener = TcpListener::bind(self.addr).await?;
40        log::info!("rustio listening on http://{}", self.addr);
41
42        let shutdown = shutdown_signal();
43        tokio::pin!(shutdown);
44
45        loop {
46            tokio::select! {
47                accept = listener.accept() => {
48                    let (stream, peer) = accept?;
49                    let io = TokioIo::new(stream);
50                    let router = self.router.clone();
51                    tokio::spawn(async move {
52                        let svc = service_fn(move |req: hyper::Request<Incoming>| {
53                            let router = router.clone();
54                            async move { handle(router, req, peer).await }
55                        });
56                        let conn = hyper::server::conn::http1::Builder::new()
57                            .keep_alive(true)
58                            .serve_connection(io, svc);
59                        if let Err(e) = conn.await {
60                            // Normal client disconnects produce noisy errors;
61                            // only log at debug level.
62                            log::debug!("connection error: {e}");
63                        }
64                    });
65                }
66                _ = &mut shutdown => {
67                    log::info!("shutdown signal received, stopping accept loop");
68                    break;
69                }
70            }
71        }
72
73        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
74        Ok(())
75    }
76}
77
78async fn handle(
79    router: Arc<Router>,
80    hyper_req: hyper::Request<Incoming>,
81    _peer: SocketAddr,
82) -> std::result::Result<hyper::Response<Full<Bytes>>, hyper::Error> {
83    let method = hyper_req.method().clone();
84    let uri = hyper_req.uri().clone();
85    let path = uri.path().to_string();
86    let query = uri.query().unwrap_or("").to_string();
87
88    let mut headers = HashMap::new();
89    for (name, value) in hyper_req.headers() {
90        if let Ok(v) = value.to_str() {
91            headers.insert(name.as_str().to_ascii_lowercase(), v.to_string());
92        }
93    }
94
95    let body = match hyper_req.into_body().collect().await {
96        Ok(b) => b.to_bytes(),
97        Err(_) => {
98            return Ok(simple_response(
99                StatusCode::BAD_REQUEST,
100                "could not read body",
101            ));
102        }
103    };
104
105    let our_req = Request::new(method, path, query, headers, body);
106    let our_resp = router.dispatch(our_req).await;
107    Ok(to_hyper(our_resp))
108}
109
110fn to_hyper(resp: Response) -> hyper::Response<Full<Bytes>> {
111    let mut builder = hyper::Response::builder().status(resp.status);
112    for (name, value) in resp.headers {
113        builder = builder.header(name, value);
114    }
115    builder.body(Full::new(resp.body)).unwrap_or_else(|_| {
116        hyper::Response::builder()
117            .status(StatusCode::INTERNAL_SERVER_ERROR)
118            .body(Full::new(Bytes::from("internal error")))
119            .unwrap()
120    })
121}
122
123fn simple_response(status: StatusCode, body: &str) -> hyper::Response<Full<Bytes>> {
124    hyper::Response::builder()
125        .status(status)
126        .header("content-type", "text/plain; charset=utf-8")
127        .body(Full::new(Bytes::from(body.to_string())))
128        .unwrap()
129}
130
131async fn shutdown_signal() {
132    let ctrl_c = async {
133        tokio::signal::ctrl_c().await.ok();
134    };
135
136    #[cfg(unix)]
137    let terminate = async {
138        if let Ok(mut sig) =
139            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
140        {
141            sig.recv().await;
142        }
143    };
144
145    #[cfg(not(unix))]
146    let terminate = std::future::pending::<()>();
147
148    tokio::select! {
149        _ = ctrl_c => {}
150        _ = terminate => {}
151    }
152}
153
154// public:
155/// Serve a static file from disk. Strips path separators and rejects
156/// `..` traversal.
157pub async fn serve_static(root: std::path::PathBuf, name: &str) -> Result<Response> {
158    let safe: String = name
159        .chars()
160        .filter(|c| *c != '/' && *c != '\\' && *c != '\0')
161        .collect();
162    if safe.contains("..") {
163        return Err(crate::error::Error::BadRequest("invalid path".into()));
164    }
165    let path = root.join(&safe);
166    if !path.is_file() {
167        return Err(crate::error::Error::NotFound(safe));
168    }
169    let bytes = tokio::fs::read(&path).await?;
170    Ok(Response::new(StatusCode::OK, Bytes::from(bytes))
171        .with_header("content-type", guess_content_type(&safe)))
172}
173
/// Map a file name's extension to a `Content-Type` header value.
///
/// The extension is whatever follows the last `.` and is compared
/// case-insensitively, so `LOGO.PNG` and `logo.png` get the same type.
/// Text-based types carry an explicit UTF-8 charset. Names without a `.`,
/// or with an unrecognised extension, fall back to
/// `application/octet-stream`.
fn guess_content_type(name: &str) -> &'static str {
    let ext = name
        .rsplit_once('.')
        .map(|(_, ext)| ext.to_ascii_lowercase())
        .unwrap_or_default();
    match ext.as_str() {
        "css" => "text/css; charset=utf-8",
        "js" | "mjs" => "application/javascript; charset=utf-8",
        "html" | "htm" => "text/html; charset=utf-8",
        "txt" => "text/plain; charset=utf-8",
        "json" => "application/json; charset=utf-8",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "svg" => "image/svg+xml",
        "ico" => "image/x-icon",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "wasm" => "application/wasm",
        _ => "application/octet-stream",
    }
}