sbd_server/
lib.rs

1//! Sbd server library.
2#![deny(missing_docs)]
3
/// Maximum size, in bytes, of a single websocket message.
///
/// The value is defined by the sbd spec. It is applied as the websocket
/// `max_message_size` when a connection is upgraded.
const MAX_MSG_BYTES: i32 = 20_000;
6
7use std::collections::HashMap;
8use std::io::{Error, Result};
9use std::net::{IpAddr, Ipv6Addr};
10use std::sync::{Arc, Mutex};
11
12mod config;
13pub use config::*;
14
15mod maybe_tls;
16pub use maybe_tls::*;
17
18mod ip_deny;
19mod ip_rate;
20pub use ip_rate::*;
21
22mod cslot;
23pub use cslot::*;
24
25mod cmd;
26
27/// Websocket backend abstraction.
28pub mod ws {
29    /// Payload.
30    pub enum Payload {
31        /// Vec.
32        Vec(Vec<u8>),
33
34        /// BytesMut.
35        BytesMut(bytes::BytesMut),
36    }
37
38    impl std::ops::Deref for Payload {
39        type Target = [u8];
40
41        #[inline(always)]
42        fn deref(&self) -> &Self::Target {
43            match self {
44                Payload::Vec(v) => v.as_slice(),
45                Payload::BytesMut(b) => b.as_ref(),
46            }
47        }
48    }
49
50    impl Payload {
51        /// Mutable payload.
52        #[inline(always)]
53        pub fn to_mut(&mut self) -> &mut [u8] {
54            match self {
55                Payload::Vec(ref mut owned) => owned,
56                Payload::BytesMut(b) => b.as_mut(),
57            }
58        }
59    }
60
61    use futures::future::BoxFuture;
62
63    /// Websocket trait.
64    pub trait SbdWebsocket: Send + Sync + 'static {
65        /// Receive from the websocket.
66        fn recv(&self) -> BoxFuture<'static, std::io::Result<Payload>>;
67
68        /// Send to the websocket.
69        fn send(
70            &self,
71            payload: Payload,
72        ) -> BoxFuture<'static, std::io::Result<()>>;
73
74        /// Close the websocket.
75        fn close(&self) -> BoxFuture<'static, ()>;
76    }
77}
78
79pub use ws::{Payload, SbdWebsocket};
80
81/// Public key.
82#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
83pub struct PubKey(pub Arc<[u8; 32]>);
84
85impl PubKey {
86    /// Verify a signature with this pub key.
87    pub fn verify(&self, sig: &[u8; 64], data: &[u8]) -> bool {
88        use ed25519_dalek::Verifier;
89        if let Ok(k) = ed25519_dalek::VerifyingKey::from_bytes(&self.0) {
90            k.verify(data, &ed25519_dalek::Signature::from_bytes(sig))
91                .is_ok()
92        } else {
93            false
94        }
95    }
96}
97
98/// SbdServer.
99pub struct SbdServer {
100    task_list: Vec<tokio::task::JoinHandle<()>>,
101    bind_addrs: Vec<std::net::SocketAddr>,
102    _cslot: cslot::CSlot,
103}
104
105impl Drop for SbdServer {
106    fn drop(&mut self) {
107        for task in self.task_list.iter() {
108            task.abort();
109        }
110    }
111}
112
/// Convert an IP address to an IPv6 address.
///
/// IPv4 addresses become their IPv4-mapped IPv6 form (`::ffff:a.b.c.d`).
/// IPv6 addresses are returned unchanged.
pub fn to_canonical_ip(ip: IpAddr) -> Arc<Ipv6Addr> {
    let canonical = match ip {
        IpAddr::V6(v6) => v6,
        IpAddr::V4(v4) => v4.to_ipv6_mapped(),
    };
    Arc::new(canonical)
}
120
121/// If the check passes, the canonical IP is returned, otherwise None and the connection should be
122/// dropped.
123pub async fn preflight_ip_check(
124    config: &Config,
125    ip_rate: &IpRate,
126    addr: std::net::SocketAddr,
127) -> Option<Arc<Ipv6Addr>> {
128    let raw_ip = to_canonical_ip(addr.ip());
129
130    let use_trusted_ip = config.trusted_ip_header.is_some();
131
132    if !use_trusted_ip {
133        // Do this check BEFORE handshake to avoid extra
134        // server process when capable.
135        // If we *are* behind a reverse proxy, we assume
136        // some amount of DDoS mitigation is happening there
137        // and thus we can accept a little more process overhead
138        if ip_rate.is_blocked(&raw_ip).await {
139            return None;
140        }
141
142        // Also precheck our rate limit, using up one byte
143        if !ip_rate.is_ok(&raw_ip, 1).await {
144            return None;
145        }
146    }
147
148    Some(raw_ip)
149}
150
151/// Handle an upgraded websocket connection.
152pub async fn handle_upgraded(
153    config: Arc<Config>,
154    ip_rate: Arc<IpRate>,
155    weak_cslot: WeakCSlot,
156    ws: Arc<impl SbdWebsocket>,
157    pub_key: PubKey,
158    calc_ip: Arc<Ipv6Addr>,
159    maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
160) {
161    let use_trusted_ip = config.trusted_ip_header.is_some();
162
163    // illegal pub key
164    if &pub_key.0[..28] == cmd::CMD_PREFIX {
165        return;
166    }
167
168    if use_trusted_ip {
169        // if using a trusted ip, check block here.
170        // see note above before the handshakes.
171        if ip_rate.is_blocked(&calc_ip).await {
172            return;
173        }
174
175        // Also precheck our rate limit, using up one byte
176        if !ip_rate.is_ok(&calc_ip, 1).await {
177            return;
178        }
179    }
180
181    if let Some(cslot) = weak_cslot.upgrade() {
182        cslot
183            .insert(&config, calc_ip, pub_key, ws, maybe_auth)
184            .await;
185    }
186}
187
188async fn handle_auth(
189    axum::extract::State(app_state): axum::extract::State<AppState>,
190    body: bytes::Bytes,
191) -> axum::response::Response {
192    use AuthenticateTokenError::*;
193
194    match process_authenticate_token(
195        &app_state.config,
196        &app_state.token_tracker,
197        body,
198    )
199    .await
200    {
201        Ok(token) => axum::response::IntoResponse::into_response(axum::Json(
202            serde_json::json!({
203                "authToken": *token,
204            }),
205        )),
206        Err(Unauthorized) => {
207            tracing::debug!("/authenticate: UNAUTHORIZED");
208            axum::response::IntoResponse::into_response((
209                axum::http::StatusCode::UNAUTHORIZED,
210                "Unauthorized",
211            ))
212        }
213        Err(HookServerError(err)) => {
214            tracing::debug!(?err, "/authenticate: BAD_GATEWAY");
215            axum::response::IntoResponse::into_response((
216                axum::http::StatusCode::BAD_GATEWAY,
217                format!("BAD_GATEWAY: {err:?}"),
218            ))
219        }
220        Err(OtherError(err)) => {
221            tracing::warn!(?err, "/authenticate: INTERNAL_SERVER_ERROR");
222            axum::response::IntoResponse::into_response((
223                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
224                format!("INTERNAL_SERVER_ERROR: {err:?}"),
225            ))
226        }
227    }
228}
229
/// Authenticate token error type.
///
/// Returned by [`process_authenticate_token`]. Implements `Debug` so the
/// returned `Result` can be logged with `{:?}` or used with
/// `unwrap()` / `expect()`.
#[derive(Debug)]
pub enum AuthenticateTokenError {
    /// The token is invalid.
    Unauthorized,
    /// We had an error talking to the hook server.
    HookServerError(std::io::Error),
    /// We had an internal error.
    OtherError(std::io::Error),
}
239
240/// Handle receiving a PUT "/authenticate" rest api request.
241pub async fn process_authenticate_token(
242    config: &Config,
243    token_tracker: &AuthTokenTracker,
244    auth_material: bytes::Bytes,
245) -> std::result::Result<Arc<str>, AuthenticateTokenError> {
246    use AuthenticateTokenError::*;
247
248    let token: Arc<str> = if let Some(url) = &config.authentication_hook_server
249    {
250        let url = url.clone();
251        let token = tokio::task::spawn_blocking(move || {
252            ureq::put(&url)
253                .set("Content-Type", "application/octet-stream")
254                .send(&auth_material[..])
255                .map_err(|err| match err {
256                    ureq::Error::Status(401, _) => Unauthorized,
257                    oth => HookServerError(std::io::Error::other(oth)),
258                })?
259                .into_string()
260                // this is a HookServerError, not an OtherError, because
261                // it is the hook server that either failed to send a full
262                // response, or sent back non-utf8 bytes, etc...
263                .map_err(HookServerError)
264        })
265        .await
266        .map_err(|_| OtherError(std::io::Error::other("tokio task died")))??;
267
268        #[derive(serde::Deserialize)]
269        #[serde(rename_all = "camelCase")]
270        struct Token {
271            auth_token: String,
272        }
273
274        let token: Token = serde_json::from_str(&token)
275            .map_err(|err| OtherError(std::io::Error::other(err)))?;
276
277        token.auth_token
278    } else {
279        // If no backend configured, fallback to gen random token:
280        use base64::prelude::*;
281        use rand::Rng;
282
283        let mut bytes = [0; 32];
284        rand::thread_rng().fill(&mut bytes);
285        BASE64_URL_SAFE_NO_PAD.encode(&bytes[..])
286    }
287    .into();
288
289    token_tracker.register_token(token.clone());
290
291    Ok(token)
292}
293
294#[derive(Clone)]
295struct WebsocketImpl {
296    write: Arc<
297        tokio::sync::Mutex<
298            futures::stream::SplitSink<
299                axum::extract::ws::WebSocket,
300                axum::extract::ws::Message,
301            >,
302        >,
303    >,
304    read: Arc<
305        tokio::sync::Mutex<
306            futures::stream::SplitStream<axum::extract::ws::WebSocket>,
307        >,
308    >,
309}
310
311impl SbdWebsocket for WebsocketImpl {
312    fn recv(&self) -> futures::future::BoxFuture<'static, Result<Payload>> {
313        let this = self.clone();
314        Box::pin(async move {
315            let mut read = this.read.lock().await;
316            use futures::stream::StreamExt;
317            loop {
318                match read.next().await {
319                    None => return Err(Error::other("closed")),
320                    Some(r) => {
321                        let msg = r.map_err(Error::other)?;
322                        match msg {
323                            axum::extract::ws::Message::Text(s) => {
324                                return Ok(Payload::Vec(s.as_bytes().to_vec()))
325                            }
326                            axum::extract::ws::Message::Binary(v) => {
327                                return Ok(Payload::Vec(v[..].to_vec()))
328                            }
329                            axum::extract::ws::Message::Ping(_)
330                            | axum::extract::ws::Message::Pong(_) => (),
331                            axum::extract::ws::Message::Close(_) => {
332                                return Err(Error::other("closed"))
333                            }
334                        }
335                    }
336                }
337            }
338        })
339    }
340
341    fn send(
342        &self,
343        payload: Payload,
344    ) -> futures::future::BoxFuture<'static, Result<()>> {
345        use futures::SinkExt;
346        let this = self.clone();
347        Box::pin(async move {
348            let mut write = this.write.lock().await;
349            let v = match payload {
350                Payload::Vec(v) => v,
351                Payload::BytesMut(b) => b.to_vec(),
352            };
353            write
354                .send(axum::extract::ws::Message::Binary(
355                    bytes::Bytes::copy_from_slice(&v),
356                ))
357                .await
358                .map_err(Error::other)?;
359            write.flush().await.map_err(Error::other)?;
360            Ok(())
361        })
362    }
363
364    fn close(&self) -> futures::future::BoxFuture<'static, ()> {
365        use futures::SinkExt;
366        let this = self.clone();
367        Box::pin(async move {
368            let _ = this.write.lock().await.close().await;
369        })
370    }
371}
372
373impl WebsocketImpl {
374    fn new(ws: axum::extract::ws::WebSocket) -> Self {
375        use futures::StreamExt;
376        let (tx, rx) = ws.split();
377        Self {
378            write: Arc::new(tokio::sync::Mutex::new(tx)),
379            read: Arc::new(tokio::sync::Mutex::new(rx)),
380        }
381    }
382}
383
384async fn handle_ws(
385    axum::extract::Path(pub_key): axum::extract::Path<String>,
386    headers: axum::http::HeaderMap,
387    ws: axum::extract::WebSocketUpgrade,
388    axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo<
389        std::net::SocketAddr,
390    >,
391    axum::extract::State(app_state): axum::extract::State<AppState>,
392) -> impl axum::response::IntoResponse {
393    use axum::response::IntoResponse;
394    use base64::Engine;
395
396    let token: Option<Arc<str>> = headers
397        .get("Authorization")
398        .and_then(|t| t.to_str().ok().map(<Arc<str>>::from));
399
400    let maybe_auth = Some((token.clone(), app_state.token_tracker.clone()));
401
402    if !app_state
403        .token_tracker
404        .check_is_token_valid(&app_state.config, token)
405    {
406        return axum::response::IntoResponse::into_response((
407            axum::http::StatusCode::UNAUTHORIZED,
408            "Unauthorized",
409        ));
410    }
411
412    let pk = match base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(pub_key) {
413        Ok(pk) if pk.len() == 32 => {
414            let mut sized_pk = [0; 32];
415            sized_pk.copy_from_slice(&pk);
416            PubKey(Arc::new(sized_pk))
417        }
418        _ => return axum::http::StatusCode::BAD_REQUEST.into_response(),
419    };
420
421    let mut calc_ip = to_canonical_ip(addr.ip());
422
423    if let Some(trusted_ip_header) = &app_state.config.trusted_ip_header {
424        if let Some(header) =
425            headers.get(trusted_ip_header).and_then(|h| h.to_str().ok())
426        {
427            if let Ok(ip) = header.parse::<IpAddr>() {
428                calc_ip = to_canonical_ip(ip);
429            }
430        }
431    }
432
433    ws.max_message_size(MAX_MSG_BYTES as usize).on_upgrade(
434        move |socket| async move {
435            handle_upgraded(
436                app_state.config.clone(),
437                app_state.ip_rate.clone(),
438                app_state.cslot.clone(),
439                Arc::new(WebsocketImpl::new(socket)),
440                pk,
441                calc_ip,
442                maybe_auth,
443            )
444            .await;
445        },
446    )
447}
448
449/// Utility for managing auth tokens.
450#[derive(Clone, Default)]
451pub struct AuthTokenTracker {
452    token_map: Arc<Mutex<HashMap<Arc<str>, std::time::Instant>>>,
453}
454
455impl AuthTokenTracker {
456    /// Register a token as valid.
457    pub fn register_token(&self, token: Arc<str>) {
458        self.token_map
459            .lock()
460            .unwrap()
461            .insert(token, std::time::Instant::now());
462    }
463
464    /// Check that a token is valid.
465    /// If so, mark it as recently used so it doesn't time out.
466    /// The "token" parameter should be direct from the http header
467    /// i.e. with the "Barer" include, like "Bearer base64".
468    /// This should be called with None as the token if no Authenticate
469    /// header was specified.
470    pub fn check_is_token_valid(
471        &self,
472        config: &Config,
473        token: Option<Arc<str>>,
474    ) -> bool {
475        let token: Arc<str> = if let Some(token) = token {
476            // If the client supplied a token, always validate it,
477            // even if no hook server was specified in the config.
478            if !token.starts_with("Bearer ") {
479                return false;
480            }
481            token.trim_start_matches("Bearer ").into()
482        } else if config.authentication_hook_server.is_none() {
483            // If the client did not supply a token, and we have no
484            // hook server configured, allow the request.
485            return true;
486        } else {
487            // We have no token, but one is required. Unauthorized.
488            return false;
489        };
490
491        let mut lock = self.token_map.lock().unwrap();
492
493        let idle_dur = config.idle_dur();
494
495        lock.retain(|_t, e| e.elapsed() < idle_dur);
496
497        if let std::collections::hash_map::Entry::Occupied(mut e) =
498            lock.entry(token)
499        {
500            e.insert(std::time::Instant::now());
501            true
502        } else {
503            false
504        }
505    }
506}
507
508#[derive(Clone)]
509struct AppState {
510    config: Arc<Config>,
511    token_tracker: AuthTokenTracker,
512    ip_rate: Arc<IpRate>,
513    cslot: WeakCSlot,
514}
515
516impl AppState {
517    pub fn new(
518        config: Arc<Config>,
519        ip_rate: Arc<IpRate>,
520        cslot: WeakCSlot,
521    ) -> Self {
522        Self {
523            config,
524            token_tracker: AuthTokenTracker::default(),
525            ip_rate,
526            cslot,
527        }
528    }
529}
530
531impl SbdServer {
532    /// Construct a new running sbd server with the provided config.
533    pub async fn new(config: Arc<Config>) -> Result<Self> {
534        let tls_config = if let (Some(cert), Some(pk)) =
535            (&config.cert_pem_file, &config.priv_key_pem_file)
536        {
537            Some(Arc::new(maybe_tls::TlsConfig::new(cert, pk).await?))
538        } else {
539            None
540        };
541
542        let mut task_list = Vec::new();
543        let mut bind_addrs = Vec::new();
544
545        let ip_rate = Arc::new(IpRate::new(config.clone()));
546        task_list.push(spawn_prune_task(ip_rate.clone()));
547
548        let cslot = CSlot::new(config.clone(), ip_rate.clone());
549        let weak_cslot = cslot.weak();
550
551        let app: axum::Router<()> = axum::Router::new()
552            .route("/authenticate", axum::routing::put(handle_auth))
553            .route("/{pub_key}", axum::routing::any(handle_ws))
554            .layer(axum::extract::DefaultBodyLimit::max(1024))
555            .with_state(AppState::new(
556                config.clone(),
557                ip_rate.clone(),
558                weak_cslot.clone(),
559            ));
560
561        let app =
562            app.into_make_service_with_connect_info::<std::net::SocketAddr>();
563
564        let mut found_port_zero: Option<u16> = None;
565
566        for bind in config.bind.iter() {
567            let mut a: std::net::SocketAddr =
568                bind.parse().map_err(Error::other)?;
569            if let Some(found_port_zero) = &found_port_zero {
570                if a.port() == 0 {
571                    a.set_port(*found_port_zero);
572                }
573            }
574
575            let h = axum_server::Handle::new();
576
577            if let Some(tls_config) = &tls_config {
578                let tls_config =
579                    axum_server::tls_rustls::RustlsConfig::from_config(
580                        tls_config.config(),
581                    );
582                let server = axum_server::bind_rustls(a, tls_config)
583                    .handle(h.clone())
584                    .serve(app.clone());
585                task_list.push(tokio::task::spawn(async move {
586                    if let Err(err) = server.await {
587                        tracing::error!(?err);
588                    }
589                }));
590            } else {
591                let server =
592                    axum_server::bind(a).handle(h.clone()).serve(app.clone());
593                task_list.push(tokio::task::spawn(async move {
594                    if let Err(err) = server.await {
595                        tracing::error!(?err);
596                    }
597                }));
598            }
599
600            if let Some(addr) = h.listening().await {
601                if found_port_zero.is_none() && a.port() == 0 {
602                    found_port_zero = Some(addr.port());
603                }
604                bind_addrs.push(addr);
605            }
606        }
607
608        Ok(Self {
609            task_list,
610            bind_addrs,
611            _cslot: cslot,
612        })
613    }
614
615    /// Get the list of addresses bound locally.
616    pub fn bind_addrs(&self) -> &[std::net::SocketAddr] {
617        self.bind_addrs.as_slice()
618    }
619}
620
621#[cfg(test)]
622mod test;