twinstar/
lib.rs

1#[macro_use] extern crate log;
2
3use std::{
4    panic::AssertUnwindSafe,
5    convert::TryFrom,
6    io::BufReader,
7    sync::Arc,
8    path::PathBuf,
9    time::Duration,
10};
11use futures_core::future::BoxFuture;
12use tokio::{
13    prelude::*,
14    io::{self, BufStream},
15    net::{TcpStream, ToSocketAddrs},
16    time::timeout,
17};
18use tokio::net::TcpListener;
19use rustls::ClientCertVerifier;
20use rustls::internal::msgs::handshake::DigitallySignedStruct;
21use tokio_rustls::{rustls, TlsAcceptor};
22use rustls::*;
23use anyhow::*;
24use lazy_static::lazy_static;
25use crate::util::opt_timeout;
26use routing::RoutingNode;
27
28pub mod types;
29pub mod util;
30pub mod routing;
31
32pub use mime;
33pub use uriparse as uri;
34pub use types::*;
35
/// Largest accepted request URI, in bytes, not counting the terminating CRLF
pub const REQUEST_URI_MAX_LEN: usize = 1_024;
/// Default TCP port of the Gemini protocol
pub const GEMINI_PORT: u16 = 1965;
38
39type Handler = Arc<dyn Fn(Request) -> HandlerResponse + Send + Sync>;
40pub (crate) type HandlerResponse = BoxFuture<'static, Result<Response>>;
41
42#[derive(Clone)]
43pub struct Server {
44    tls_acceptor: TlsAcceptor,
45    listener: Arc<TcpListener>,
46    routes: Arc<RoutingNode<Handler>>,
47    timeout: Duration,
48    complex_timeout: Option<Duration>,
49}
50
51impl Server {
52    pub fn bind<A: ToSocketAddrs>(addr: A) -> Builder<A> {
53        Builder::bind(addr)
54    }
55
56    async fn serve(self) -> Result<()> {
57        loop {
58            let (stream, _addr) = self.listener.accept().await
59                .context("Failed to accept client")?;
60            let this = self.clone();
61
62            tokio::spawn(async move {
63                if let Err(err) = this.serve_client(stream).await {
64                    error!("{:?}", err);
65                }
66            });
67        }
68    }
69
70    async fn serve_client(self, stream: TcpStream) -> Result<()> {
71        let fut_accept_request = async {
72            let stream = self.tls_acceptor.accept(stream).await
73                .context("Failed to establish TLS session")?;
74            let mut stream = BufStream::new(stream);
75
76            let request = receive_request(&mut stream).await
77                .context("Failed to receive request")?;
78
79            Result::<_, anyhow::Error>::Ok((request, stream))
80        };
81
82        // Use a timeout for interacting with the client
83        let fut_accept_request = timeout(self.timeout, fut_accept_request);
84        let (mut request, mut stream) = fut_accept_request.await
85            .context("Client timed out while waiting for response")??;
86
87        debug!("Client requested: {}", request.uri());
88
89        // Identify the client certificate from the tls stream.  This is the first
90        // certificate in the certificate chain.
91        let client_cert = stream.get_ref()
92            .get_ref()
93            .1
94            .get_peer_certificates()
95            .and_then(|mut v| if v.is_empty() {None} else {Some(v.remove(0))});
96
97        request.set_cert(client_cert);
98
99        let response = if let Some((trailing, handler)) = self.routes.match_request(&request) {
100
101            request.set_trailing(trailing);
102
103            let handler = (handler)(request);
104            let handler = AssertUnwindSafe(handler);
105
106            util::HandlerCatchUnwind::new(handler).await
107                .unwrap_or_else(|_| Response::server_error(""))
108                .or_else(|err| {
109                    error!("Handler failed: {:?}", err);
110                    Response::server_error("")
111                })
112                .context("Request handler failed")?
113        } else {
114            Response::not_found()
115        };
116
117        self.send_response(response, &mut stream).await
118            .context("Failed to send response")?;
119
120        Ok(())
121    }
122
123    async fn send_response(&self, mut response: Response, stream: &mut (impl AsyncWrite + Unpin)) -> Result<()> {
124        let maybe_body = response.take_body();
125        let header = response.header();
126
127        let use_complex_timeout =
128            header.status.is_success() &&
129            maybe_body.is_some() &&
130            header.meta.as_str() != "text/plain" &&
131            header.meta.as_str() != "text/gemini" &&
132            self.complex_timeout.is_some();
133
134        let send_general_timeout;
135        let send_header_timeout;
136        let send_body_timeout;
137
138        if use_complex_timeout {
139            send_general_timeout = None;
140            send_header_timeout = Some(self.timeout);
141            send_body_timeout = self.complex_timeout;
142        } else {
143            send_general_timeout = Some(self.timeout);
144            send_header_timeout = None;
145            send_body_timeout = None;
146        }
147
148        opt_timeout(send_general_timeout, async {
149            // Send the header
150            opt_timeout(send_header_timeout, send_response_header(response.header(), stream))
151                .await
152                .context("Timed out while sending response header")?
153                .context("Failed to write response header")?;
154
155            // Send the body
156            opt_timeout(send_body_timeout, maybe_send_response_body(maybe_body, stream))
157                .await
158                .context("Timed out while sending response body")?
159                .context("Failed to write response body")?;
160
161            Ok::<_,Error>(())
162        })
163        .await
164        .context("Timed out while sending response data")??;
165
166        Ok(())
167    }
168}
169
170pub struct Builder<A> {
171    addr: A,
172    cert_path: PathBuf,
173    key_path: PathBuf,
174    timeout: Duration,
175    complex_body_timeout_override: Option<Duration>,
176    routes: RoutingNode<Handler>,
177}
178
179impl<A: ToSocketAddrs> Builder<A> {
180    fn bind(addr: A) -> Self {
181        Self {
182            addr,
183            timeout: Duration::from_secs(1),
184            complex_body_timeout_override: Some(Duration::from_secs(30)),
185            cert_path: PathBuf::from("cert/cert.pem"),
186            key_path: PathBuf::from("cert/key.pem"),
187            routes: RoutingNode::default(),
188        }
189    }
190
191    /// Sets the directory that twinstar should look for TLS certs and keys into
192    ///
193    /// Northstar will look for files called `cert.pem` and `key.pem` in the provided
194    /// directory.
195    ///
196    /// This does not need to be set if both [`set_cert()`](Self::set_cert()) and
197    /// [`set_key()`](Self::set_key()) have been called.
198    ///
199    /// If not set, the default is `cert/`
200    pub fn set_tls_dir(self, dir: impl Into<PathBuf>) -> Self {
201        let dir = dir.into();
202        self.set_cert(dir.join("cert.pem"))
203            .set_key(dir.join("key.pem"))
204    }
205
206    /// Set the path to the TLS certificate twinstar will use
207    ///
208    /// This defaults to `cert/cert.pem`.
209    ///
210    /// This does not need to be called it [`set_tls_dir()`](Self::set_tls_dir()) has been
211    /// called.
212    pub fn set_cert(mut self, cert_path: impl Into<PathBuf>) -> Self {
213        self.cert_path = cert_path.into();
214        self
215    }
216
217    /// Set the path to the ertificate key twinstar will use
218    ///
219    /// This defaults to `cert/key.pem`.
220    ///
221    /// This does not need to be called it [`set_tls_dir()`](Self::set_tls_dir()) has been
222    /// called.
223    ///
224    /// This should of course correspond to the key set in
225    /// [`set_cert()`](Self::set_cert())
226    pub fn set_key(mut self, key_path: impl Into<PathBuf>) -> Self {
227        self.key_path = key_path.into();
228        self
229    }
230
231    /// Set the timeout on incoming requests
232    ///
233    /// Note that this timeout is applied twice, once for the delivery of the request, and
234    /// once for sending the client's response.  This means that for a 1 second timeout,
235    /// the client will have 1 second to complete the TLS handshake and deliver a request
236    /// header, then your API will have as much time as it needs to handle the request,
237    /// before the client has another second to receive the response.
238    ///
239    /// If you would like a timeout for your code itself, please use
240    /// [`tokio::time::Timeout`] to implement it internally.
241    ///
242    /// **The default timeout is 1 second.**  As somewhat of a workaround for
243    /// shortcomings of the specification, this timeout, and any timeout set using this
244    /// method, is overridden in special cases, specifically for MIME types outside of
245    /// `text/plain` and `text/gemini`, to be 30 seconds.  If you would like to change or
246    /// prevent this, please see
247    /// [`override_complex_body_timeout`](Self::override_complex_body_timeout()).
248    pub fn set_timeout(mut self, timeout: Duration) -> Self {
249        self.timeout = timeout;
250        self
251    }
252
253    /// Override the timeout for complex body types
254    ///
255    /// Many clients choose to handle body types which cannot be displayed by prompting
256    /// the user if they would like to download or open the request body.  However, since
257    /// this prompt occurs in the middle of receiving a request, often the connection
258    /// times out before the end user is able to respond to the prompt.
259    ///
260    /// As a workaround, it is possible to set an override on the request timeout in
261    /// specific conditions:
262    ///
263    /// 1. **Only override the timeout for receiving the body of the request.**  This will
264    ///    not override the timeout on sending the request header, nor on receiving the
265    ///    response header.
266    /// 2. **Only override the timeout for successful responses.**  The only bodies which
267    ///    have bodies are successful ones.  In all other cases, there's no body to
268    ///    timeout for
269    /// 3. **Only override the timeout for complex body types.**  Almost all clients are
270    ///    able to display `text/plain` and `text/gemini` responses, and will not prompt
271    ///    the user for these response types.  This means that there is no reason to
272    ///    expect a client to have a human-length response time for these MIME types.
273    ///    Because of this, responses of this type will not be overridden.
274    ///
275    /// This method is used to override the timeout for responses meeting these specific
276    /// criteria.  All other stages of the connection will use the timeout specified in
277    /// [`set_timeout()`](Self::set_timeout()).
278    ///
279    /// If this is set to [`None`], then the client will have the default amount of time
280    /// to both receive the header and the body.  If this is set to [`Some`], the client
281    /// will have the default amount of time to recieve the header, and an *additional*
282    /// alotment of time to recieve the body.
283    ///
284    /// The default timeout for this is 30 seconds.
285    pub fn override_complex_body_timeout(mut self, timeout: Option<Duration>) -> Self {
286        self.complex_body_timeout_override = timeout;
287        self
288    }
289
290    /// Add a handler for a route
291    ///
292    /// A route must be an absolute path, for example "/endpoint" or "/", but not
293    /// "endpoint".  Entering a relative or malformed path will result in a panic.
294    ///
295    /// For more information about routing mechanics, see the docs for [`RoutingNode`].
296    pub fn add_route<H>(mut self, path: &'static str, handler: H) -> Self
297    where
298        H: Fn(Request) -> HandlerResponse + Send + Sync + 'static,
299    {
300        self.routes.add_route(path, Arc::new(handler));
301        self
302    }
303
304    pub async fn serve(mut self) -> Result<()> {
305        let config = tls_config(&self.cert_path, &self.key_path)
306            .context("Failed to create TLS config")?;
307
308        let listener = TcpListener::bind(self.addr).await
309            .context("Failed to create socket")?;
310
311        self.routes.shrink();
312
313        let server = Server {
314            tls_acceptor: TlsAcceptor::from(config),
315            listener: Arc::new(listener),
316            routes: Arc::new(self.routes),
317            timeout: self.timeout,
318            complex_timeout: self.complex_body_timeout_override,
319        };
320
321        server.serve().await
322    }
323}
324
325async fn receive_request(stream: &mut (impl AsyncBufRead + Unpin)) -> Result<Request> {
326    let limit = REQUEST_URI_MAX_LEN + "\r\n".len();
327    let mut stream = stream.take(limit as u64);
328    let mut uri = Vec::new();
329
330    stream.read_until(b'\n', &mut uri).await?;
331
332    if !uri.ends_with(b"\r\n") {
333        if uri.len() < REQUEST_URI_MAX_LEN {
334            bail!("Request header not terminated with CRLF")
335        } else {
336            bail!("Request URI too long")
337        }
338    }
339
340    // Strip CRLF
341    uri.pop();
342    uri.pop();
343
344    let uri = URIReference::try_from(&*uri)
345        .context("Request URI is invalid")?
346        .into_owned();
347    let request = Request::from_uri(uri)
348        .context("Failed to create request from URI")?;
349
350    Ok(request)
351}
352
353async fn send_response_header(header: &ResponseHeader, stream: &mut (impl AsyncWrite + Unpin)) -> Result<()> {
354    let header = format!(
355        "{status} {meta}\r\n",
356        status = header.status.code(),
357        meta = header.meta.as_str(),
358    );
359
360    stream.write_all(header.as_bytes()).await?;
361    stream.flush().await?;
362
363    Ok(())
364}
365
366async fn maybe_send_response_body(maybe_body: Option<Body>, stream: &mut (impl AsyncWrite + Unpin)) -> Result<()> {
367    if let Some(body) = maybe_body {
368        send_response_body(body, stream).await?;
369    }
370
371    Ok(())
372}
373
374async fn send_response_body(body: Body, stream: &mut (impl AsyncWrite + Unpin)) -> Result<()> {
375    match body {
376        Body::Bytes(bytes) => stream.write_all(&bytes).await?,
377        Body::Reader(mut reader) => { io::copy(&mut reader, stream).await?; },
378    }
379
380    stream.flush().await?;
381
382    Ok(())
383}
384
385fn tls_config(cert_path: &PathBuf, key_path: &PathBuf) -> Result<Arc<ServerConfig>> {
386    let mut config = ServerConfig::new(AllowAnonOrSelfsignedClient::new());
387
388    let cert_chain = load_cert_chain(cert_path)
389        .context("Failed to load TLS certificate")?;
390    let key = load_key(key_path)
391        .context("Failed to load TLS key")?;
392    config.set_single_cert(cert_chain, key)
393        .context("Failed to use loaded TLS certificate")?;
394
395    Ok(config.into())
396}
397
398fn load_cert_chain(cert_path: &PathBuf) -> Result<Vec<Certificate>> {
399    let certs = std::fs::File::open(cert_path)
400        .with_context(|| format!("Failed to open `{:?}`", cert_path))?;
401    let mut certs = BufReader::new(certs);
402    let certs = rustls::internal::pemfile::certs(&mut certs)
403        .map_err(|_| anyhow!("failed to load certs `{:?}`", cert_path))?;
404
405    Ok(certs)
406}
407
408fn load_key(key_path: &PathBuf) -> Result<PrivateKey> {
409    let keys = std::fs::File::open(key_path)
410        .with_context(|| format!("Failed to open `{:?}`", key_path))?;
411    let mut keys = BufReader::new(keys);
412    let mut keys = rustls::internal::pemfile::pkcs8_private_keys(&mut keys)
413        .map_err(|_| anyhow!("failed to load key `{:?}`", key_path))?;
414
415    ensure!(!keys.is_empty(), "no key found");
416
417    let key = keys.swap_remove(0);
418
419    Ok(key)
420}
421
422/// Mime for Gemini documents
423pub const GEMINI_MIME_STR: &str = "text/gemini";
424
425lazy_static! {
426    /// Mime for Gemini documents ("text/gemini")
427    pub static ref GEMINI_MIME: Mime = GEMINI_MIME_STR.parse().expect("twinstar BUG");
428}
429
430#[deprecated(note = "Use `GEMINI_MIME` instead", since = "0.3.0")]
431pub fn gemini_mime() -> Result<Mime> {
432    Ok(GEMINI_MIME.clone())
433}
434
/// Client certificate verifier that lets every client through
///
/// Clients may connect anonymously or with any certificate, self-signed ones
/// included.  rustls does not ship a `ClientCertVerifier` that accepts self-signed
/// certificates, so twinstar provides its own.
struct AllowAnonOrSelfsignedClient {}

impl AllowAnonOrSelfsignedClient {
    /// Build a verifier, already wrapped in the `Arc` that `ServerConfig::new` takes
    fn new() -> Arc<Self> {
        let verifier = AllowAnonOrSelfsignedClient {};
        Arc::new(verifier)
    }
}
448
/// Verifier policy: client certificates are optional, and any presented chain is
/// accepted without validation.
impl ClientCertVerifier for AllowAnonOrSelfsignedClient {

    /// Advertise an empty list of acceptable root subjects, so no particular CA is
    /// requested from clients.
    fn client_auth_root_subjects(
        &self,
        _: Option<&webpki::DNSName>
    ) -> Option<DistinguishedNames> {
        Some(Vec::new())
    }

    /// Client authentication is optional: anonymous clients are allowed.
    fn client_auth_mandatory(&self, _sni: Option<&webpki::DNSName>) -> Option<bool> {
        Some(false)
    }

    // the below methods are a hack until webpki doesn't break with certain certs

    /// Accept any presented certificate chain without inspecting it (no CA, expiry,
    /// or other checks are performed here).
    fn verify_client_cert(
        &self,
        _: &[Certificate],
        _: Option<&webpki::DNSName>
    ) -> Result<ClientCertVerified, TLSError> {
        Ok(ClientCertVerified::assertion())
    }

    /// Report every TLS 1.2 handshake signature as valid without checking it.
    ///
    /// NOTE(security): because `_message`, `_cert` and `_dss` are ignored, the server
    /// never confirms that the client holds the private key of the certificate it
    /// presented.  Anyone who obtains another user's (public) certificate could present
    /// it and be identified as that user via `Request` certificate data.  Presumably
    /// restoring rustls's own signature check (the trait's default method) would close
    /// this, but per the comment above webpki rejects some certificates — TODO confirm
    /// which ones before changing.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &Certificate,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, TLSError> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Report every TLS 1.3 handshake signature as valid without checking it.
    ///
    /// NOTE(security): same caveat as `verify_tls12_signature` — proof of possession of
    /// the client's private key is skipped entirely.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &Certificate,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, TLSError> {
        Ok(HandshakeSignatureValid::assertion())
    }
}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// Dereferencing the lazy static must not trip the `expect` in its initializer.
    #[test]
    fn gemini_mime_parses() {
        let _mime: &Mime = &*GEMINI_MIME;
    }
}