sbd_server/
lib.rs

1//! Sbd server library.
2//!
3//! ## Metrics
4//!
5//! The server exports the following OpenTelemetry metrics if an OTLP endpoint is configured.
6//!
7//! Prometheus example: `sbd-serverd --otlp-endpoint http://localhost:9090/api/v1/otlp/v1/metrics`
8//!
9//! | Full Metric Name | Type | Unit (optional) | Description | Attributes |
10//! | ---------------- | ---- | --------------- | ----------- | ---------- |
11//! | `sbd.server.open_connections` | `f64_up_down_counter` | `count` | The current number of open connections | |
12//! | `sbd.server.ip_rate_limited` | `u64_counter` | `count` | The number of connections that have been closed because of an IP rate limiting violation | - `pub_key`: The base64 encoded public key declared by the offending connection.<br />- `kind`: Has two possible values. It will be "auth" for violations during authentication and "msg" for violations while sending messages. |
//! | `sbd.server.bytes_send` | `u64_counter` | `bytes` | The number of bytes sent per public key. Resets when a new connection is opened. | - `pub_key`: The base64 encoded public key declared by the connection. |
//! | `sbd.server.bytes_recv` | `u64_counter` | `bytes` | The number of bytes received per public key. Resets when a new connection is opened. | - `pub_key`: The base64 encoded public key declared by the connection. |
15//! | `sbd.server.auth_failures` | `u64_counter` | `count` | The number of failed authentication attempts. | - `pub_key`: The base64 encoded public key declared by the offending connection. This is only present if an invalid token is used with a specific public key. |
16
17#![deny(missing_docs)]
18
/// Maximum size of a single message in bytes, as defined by the sbd spec.
///
/// Stored as `i32`; it is cast to `usize` where it is applied as the
/// websocket `max_message_size` limit.
const MAX_MSG_BYTES: i32 = 20_000;
21
22use base64::Engine;
23use opentelemetry::global;
24use std::collections::HashMap;
25use std::io::{Error, Result};
26use std::net::{IpAddr, Ipv6Addr};
27use std::sync::{Arc, Mutex};
28
29mod config;
30pub use config::*;
31
32mod maybe_tls;
33pub use maybe_tls::*;
34
35mod ip_deny;
36mod ip_rate;
37pub use ip_rate::*;
38
39mod cslot;
40pub use cslot::*;
41
42mod cmd;
43
44mod metrics;
45pub use metrics::*;
46
47/// Websocket backend abstraction.
48pub mod ws {
49    /// Payload.
50    pub enum Payload {
51        /// Vec.
52        Vec(Vec<u8>),
53
54        /// BytesMut.
55        BytesMut(bytes::BytesMut),
56    }
57
58    impl std::ops::Deref for Payload {
59        type Target = [u8];
60
61        #[inline(always)]
62        fn deref(&self) -> &Self::Target {
63            match self {
64                Payload::Vec(v) => v.as_slice(),
65                Payload::BytesMut(b) => b.as_ref(),
66            }
67        }
68    }
69
70    impl Payload {
71        /// Mutable payload.
72        #[inline(always)]
73        pub fn to_mut(&mut self) -> &mut [u8] {
74            match self {
75                Payload::Vec(ref mut owned) => owned,
76                Payload::BytesMut(b) => b.as_mut(),
77            }
78        }
79    }
80
81    use futures::future::BoxFuture;
82
83    /// Websocket trait.
84    pub trait SbdWebsocket: Send + Sync + 'static {
85        /// Receive from the websocket.
86        fn recv(&self) -> BoxFuture<'static, std::io::Result<Payload>>;
87
88        /// Send to the websocket.
89        fn send(
90            &self,
91            payload: Payload,
92        ) -> BoxFuture<'static, std::io::Result<()>>;
93
94        /// Close the websocket.
95        fn close(&self) -> BoxFuture<'static, ()>;
96    }
97}
98
99pub use ws::{Payload, SbdWebsocket};
100
101/// Public key.
102#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
103pub struct PubKey(pub Arc<[u8; 32]>);
104
105impl PubKey {
106    /// Verify a signature with this pub key.
107    pub fn verify(&self, sig: &[u8; 64], data: &[u8]) -> bool {
108        use ed25519_dalek::Verifier;
109        if let Ok(k) = ed25519_dalek::VerifyingKey::from_bytes(&self.0) {
110            k.verify(data, &ed25519_dalek::Signature::from_bytes(sig))
111                .is_ok()
112        } else {
113            false
114        }
115    }
116}
117
118/// SbdServer.
119pub struct SbdServer {
120    task_list: Vec<tokio::task::JoinHandle<()>>,
121    bind_addrs: Vec<std::net::SocketAddr>,
122
123    // this should be the only non-weak cslot so the others are dropped
124    // if this top-level server instance is ever dropped.
125    _cslot: CSlot,
126}
127
128impl Drop for SbdServer {
129    fn drop(&mut self) {
130        for task in self.task_list.iter() {
131            task.abort();
132        }
133    }
134}
135
/// Convert an IP address to an IPv6 address.
///
/// IPv4 addresses are converted to their IPv4-mapped IPv6 form
/// (`::ffff:a.b.c.d`); IPv6 addresses are returned unchanged.
pub fn to_canonical_ip(ip: IpAddr) -> Arc<Ipv6Addr> {
    let v6 = match ip {
        IpAddr::V6(v6) => v6,
        IpAddr::V4(v4) => v4.to_ipv6_mapped(),
    };
    Arc::new(v6)
}
143
144/// If the check passes, the canonical IP is returned, otherwise None and the connection should be
145/// dropped.
146pub async fn preflight_ip_check(
147    config: &Config,
148    ip_rate: &IpRate,
149    addr: std::net::SocketAddr,
150) -> Option<Arc<Ipv6Addr>> {
151    let raw_ip = to_canonical_ip(addr.ip());
152
153    let use_trusted_ip = config.trusted_ip_header.is_some();
154
155    if !use_trusted_ip {
156        // Do this check BEFORE handshake to avoid extra
157        // server process when capable.
158        // If we *are* behind a reverse proxy, we assume
159        // some amount of DDoS mitigation is happening there
160        // and thus we can accept a little more process overhead
161        if ip_rate.is_blocked(&raw_ip).await {
162            return None;
163        }
164
165        // Also precheck our rate limit, using up one byte
166        if !ip_rate.is_ok(&raw_ip, 1).await {
167            return None;
168        }
169    }
170
171    Some(raw_ip)
172}
173
174/// Handle an upgraded websocket connection.
175pub async fn handle_upgraded(
176    config: Arc<Config>,
177    ip_rate: Arc<IpRate>,
178    weak_cslot: WeakCSlot,
179    ws: Arc<impl SbdWebsocket>,
180    pub_key: PubKey,
181    calc_ip: Arc<Ipv6Addr>,
182    maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
183) {
184    let use_trusted_ip = config.trusted_ip_header.is_some();
185
186    // illegal pub key
187    if &pub_key.0[..28] == cmd::CMD_PREFIX {
188        return;
189    }
190
191    if use_trusted_ip {
192        // if using a trusted ip, check block here.
193        // see note above before the handshakes.
194        if ip_rate.is_blocked(&calc_ip).await {
195            return;
196        }
197
198        // Also precheck our rate limit, using up one byte
199        if !ip_rate.is_ok(&calc_ip, 1).await {
200            return;
201        }
202    }
203
204    if let Some(cslot) = weak_cslot.upgrade() {
205        cslot
206            .insert(&config, calc_ip, pub_key, ws, maybe_auth)
207            .await;
208    }
209}
210
211/// Handle the /authenticate request for access token
212async fn handle_auth(
213    axum::extract::State(app_state): axum::extract::State<AppState>,
214    body: bytes::Bytes,
215) -> axum::response::Response {
216    use AuthenticateTokenError::*;
217
218    // process the actual authentication
219    match process_authenticate_token(
220        &app_state.config,
221        &app_state.token_tracker,
222        app_state.auth_failures,
223        body,
224    )
225    .await
226    {
227        Ok(token) => axum::response::IntoResponse::into_response(axum::Json(
228            serde_json::json!({
229                "authToken": *token,
230            }),
231        )),
232        Err(Unauthorized) => {
233            tracing::debug!("/authenticate: UNAUTHORIZED");
234            axum::response::IntoResponse::into_response((
235                axum::http::StatusCode::UNAUTHORIZED,
236                "Unauthorized",
237            ))
238        }
239        Err(HookServerError(err)) => {
240            tracing::debug!(?err, "/authenticate: BAD_GATEWAY");
241            axum::response::IntoResponse::into_response((
242                axum::http::StatusCode::BAD_GATEWAY,
243                format!("BAD_GATEWAY: {err:?}"),
244            ))
245        }
246        Err(OtherError(err)) => {
247            tracing::warn!(?err, "/authenticate: INTERNAL_SERVER_ERROR");
248            axum::response::IntoResponse::into_response((
249                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
250                format!("INTERNAL_SERVER_ERROR: {err:?}"),
251            ))
252        }
253    }
254}
255
/// Authenticate token error type.
///
/// Implements [`Debug`](std::fmt::Debug) so results carrying this error
/// can be logged with `{:?}` and used with `unwrap` / `expect`.
#[derive(Debug)]
pub enum AuthenticateTokenError {
    /// The token is invalid.
    Unauthorized,
    /// We had an error talking to the hook server.
    HookServerError(Error),
    /// We had an internal error.
    OtherError(Error),
}
265
266/// Handle receiving a PUT "/authenticate" rest api request.
267pub async fn process_authenticate_token(
268    config: &Config,
269    token_tracker: &AuthTokenTracker,
270    auth_failures: opentelemetry::metrics::Counter<u64>,
271    auth_material: bytes::Bytes,
272) -> std::result::Result<Arc<str>, AuthenticateTokenError> {
273    use AuthenticateTokenError::*;
274
275    let token: Arc<str> = if let Some(url) = &config.authentication_hook_server
276    {
277        // if a hook server is configured, forward the call to it
278
279        let url = url.clone();
280        let token = tokio::task::spawn_blocking(move || {
281            ureq::put(&url)
282                .header("Content-Type", "application/octet-stream")
283                .send(&auth_material[..])
284                .map_err(|err| {
285                    auth_failures.add(1, &[]);
286
287                    match err {
288                        ureq::Error::StatusCode(401) => Unauthorized,
289                        oth => HookServerError(Error::other(oth)),
290                    }
291                })?
292                .into_body()
293                .read_to_string()
294                .map_err(Error::other)
295                // this is a HookServerError, not an OtherError, because
296                // it is the hook server that either failed to send a full
297                // response, or sent back non-utf8 bytes, etc...
298                .map_err(HookServerError)
299        })
300        .await
301        .map_err(|_| OtherError(Error::other("tokio task died")))??;
302
303        #[derive(serde::Deserialize)]
304        #[serde(rename_all = "camelCase")]
305        struct Token {
306            auth_token: String,
307        }
308
309        let token: Token = serde_json::from_str(&token)
310            .map_err(|err| OtherError(Error::other(err)))?;
311
312        token.auth_token
313    } else {
314        // If no hook server is configured, fallback to gen random token
315
316        use base64::prelude::*;
317        use rand::Rng;
318
319        let mut bytes = [0; 32];
320        rand::thread_rng().fill(&mut bytes);
321        BASE64_URL_SAFE_NO_PAD.encode(&bytes[..])
322    }
323    .into();
324
325    // register the token with our authentication token tracker
326    token_tracker.register_token(token.clone());
327
328    Ok(token)
329}
330
/// Implement the ability to use axum websockets as our websocket backend.
///
/// All fields are cheap to clone; the read and write halves sit behind
/// `Arc`s, so clones of this struct share the same underlying socket.
#[derive(Clone)]
struct WebsocketImpl {
    // sink (sending) half of the split axum websocket,
    // guarded by an async mutex.
    write: Arc<
        tokio::sync::Mutex<
            futures::stream::SplitSink<
                axum::extract::ws::WebSocket,
                axum::extract::ws::Message,
            >,
        >,
    >,
    // stream (receiving) half of the split axum websocket,
    // guarded by an async mutex.
    read: Arc<
        tokio::sync::Mutex<
            futures::stream::SplitStream<axum::extract::ws::WebSocket>,
        >,
    >,
    // attributes attached to every byte counter update
    // (the base64 encoded `pub_key` of this connection).
    attr: Vec<opentelemetry::KeyValue>,
    // `sbd.server.bytes_send` counter.
    bytes_send: opentelemetry::metrics::Counter<u64>,
    // `sbd.server.bytes_recv` counter.
    bytes_recv: opentelemetry::metrics::Counter<u64>,
}
351
352impl SbdWebsocket for WebsocketImpl {
353    fn recv(&self) -> futures::future::BoxFuture<'static, Result<Payload>> {
354        let this = self.clone();
355        Box::pin(async move {
356            let mut read = this.read.lock().await;
357            use futures::stream::StreamExt;
358            loop {
359                match read.next().await {
360                    None => return Err(Error::other("closed")),
361                    Some(r) => {
362                        let msg = r.map_err(Error::other)?;
363                        match msg {
364                            axum::extract::ws::Message::Text(s) => {
365                                this.bytes_recv.add(s.len() as u64, &this.attr);
366                                return Ok(Payload::Vec(s.as_bytes().to_vec()));
367                            }
368                            axum::extract::ws::Message::Binary(v) => {
369                                this.bytes_recv.add(v.len() as u64, &this.attr);
370                                return Ok(Payload::Vec(v[..].to_vec()));
371                            }
372                            axum::extract::ws::Message::Ping(_)
373                            | axum::extract::ws::Message::Pong(_) => (),
374                            axum::extract::ws::Message::Close(_) => {
375                                return Err(Error::other("closed"))
376                            }
377                        }
378                    }
379                }
380            }
381        })
382    }
383
384    fn send(
385        &self,
386        payload: Payload,
387    ) -> futures::future::BoxFuture<'static, Result<()>> {
388        use futures::SinkExt;
389        let this = self.clone();
390        Box::pin(async move {
391            let mut write = this.write.lock().await;
392            let v = match payload {
393                Payload::Vec(v) => v,
394                Payload::BytesMut(b) => b.to_vec(),
395            };
396            this.bytes_send.add(v.len() as u64, &this.attr);
397            write
398                .send(axum::extract::ws::Message::Binary(
399                    bytes::Bytes::copy_from_slice(&v),
400                ))
401                .await
402                .map_err(Error::other)?;
403            write.flush().await.map_err(Error::other)?;
404            Ok(())
405        })
406    }
407
408    fn close(&self) -> futures::future::BoxFuture<'static, ()> {
409        use futures::SinkExt;
410        let this = self.clone();
411        Box::pin(async move {
412            let _ = this.write.lock().await.close().await;
413        })
414    }
415}
416
417impl WebsocketImpl {
418    fn new(
419        ws: axum::extract::ws::WebSocket,
420        pk: PubKey,
421        meter: &opentelemetry::metrics::Meter,
422    ) -> Self {
423        use futures::StreamExt;
424
425        let bytes_send = meter
426            .u64_counter("sbd.server.bytes_send")
427            .with_description("Number of bytes sent to client")
428            .with_unit("bytes")
429            .build();
430        let bytes_recv = meter
431            .u64_counter("sbd.server.bytes_recv")
432            .with_description("Number of bytes received from client")
433            .with_unit("bytes")
434            .build();
435
436        let (tx, rx) = ws.split();
437        Self {
438            write: Arc::new(tokio::sync::Mutex::new(tx)),
439            read: Arc::new(tokio::sync::Mutex::new(rx)),
440            attr: vec![opentelemetry::KeyValue::new(
441                "pub_key",
442                base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(*pk.0),
443            )],
444            bytes_send,
445            bytes_recv,
446        }
447    }
448}
449
450/// Handle the http upgrade request for a websocket connection.
451async fn handle_ws(
452    axum::extract::Path(pub_key): axum::extract::Path<String>,
453    headers: axum::http::HeaderMap,
454    ws: axum::extract::WebSocketUpgrade,
455    axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo<
456        std::net::SocketAddr,
457    >,
458    axum::extract::State(app_state): axum::extract::State<AppState>,
459) -> impl axum::response::IntoResponse {
460    use axum::response::IntoResponse;
461    use base64::Engine;
462
463    // first check for auth tokens
464    let token: Option<Arc<str>> = headers
465        .get("Authorization")
466        .and_then(|t| t.to_str().ok().map(<Arc<str>>::from));
467
468    let maybe_auth = Some((token.clone(), app_state.token_tracker.clone()));
469
470    // compare any passed tokens with our token authentication mechanism
471    if !app_state
472        .token_tracker
473        .check_is_token_valid(&app_state.config, token)
474    {
475        // Might be useful to record the IP address here, but it's special category data. To keep
476        // the server privacy-friendly, avoid exporting the IP address to the metrics even though
477        // this user did not authenticate properly.
478        app_state
479            .auth_failures
480            .add(1, &[opentelemetry::KeyValue::new("pub_key", pub_key)]);
481
482        return axum::response::IntoResponse::into_response((
483            axum::http::StatusCode::UNAUTHORIZED,
484            "Unauthorized",
485        ));
486    }
487
488    // get the primary key this user is claiming
489    let pk = match base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(pub_key) {
490        Ok(pk) if pk.len() == 32 => {
491            let mut sized_pk = [0; 32];
492            sized_pk.copy_from_slice(&pk);
493            PubKey(Arc::new(sized_pk))
494        }
495        _ => return axum::http::StatusCode::BAD_REQUEST.into_response(),
496    };
497
498    let mut calc_ip = to_canonical_ip(addr.ip());
499
500    // if we're using a trusted ip, parse that out of the header
501    if let Some(trusted_ip_header) = &app_state.config.trusted_ip_header {
502        if let Some(header) =
503            headers.get(trusted_ip_header).and_then(|h| h.to_str().ok())
504        {
505            if let Ok(ip) = header.parse::<IpAddr>() {
506                calc_ip = to_canonical_ip(ip);
507            }
508        }
509    }
510
511    // do the actual websocket upgrade
512    ws.max_message_size(MAX_MSG_BYTES as usize).on_upgrade(
513        move |socket| async move {
514            handle_upgraded(
515                app_state.config.clone(),
516                app_state.ip_rate.clone(),
517                app_state.cslot.clone(),
518                Arc::new(WebsocketImpl::new(
519                    socket,
520                    pk.clone(),
521                    &app_state.meter,
522                )),
523                pk,
524                calc_ip,
525                maybe_auth,
526            )
527            .await;
528        },
529    )
530}
531
532/// Utility for managing auth tokens.
533#[derive(Clone, Default)]
534pub struct AuthTokenTracker {
535    token_map: Arc<Mutex<HashMap<Arc<str>, std::time::Instant>>>,
536}
537
538impl AuthTokenTracker {
539    /// Register a token as valid.
540    pub fn register_token(&self, token: Arc<str>) {
541        self.token_map
542            .lock()
543            .unwrap()
544            .insert(token, std::time::Instant::now());
545    }
546
547    /// Check that a token is valid.
548    /// If so, mark it as recently used so it doesn't time out.
549    /// The "token" parameter should be direct from the http header
550    /// i.e. with the "Barer" include, like "Bearer base64".
551    /// This should be called with None as the token if no Authenticate
552    /// header was specified.
553    pub fn check_is_token_valid(
554        &self,
555        config: &Config,
556        token: Option<Arc<str>>,
557    ) -> bool {
558        let token: Arc<str> = if let Some(token) = token {
559            // If the client supplied a token, always validate it,
560            // even if no hook server was specified in the config.
561            if !token.starts_with("Bearer ") {
562                return false;
563            }
564            token.trim_start_matches("Bearer ").into()
565        } else if config.authentication_hook_server.is_none() {
566            // If the client did not supply a token, and we have no
567            // hook server configured, allow the request.
568            return true;
569        } else {
570            // We have no token, but one is required. Unauthorized.
571            return false;
572        };
573
574        let mut lock = self.token_map.lock().unwrap();
575
576        let idle_dur = config.idle_dur();
577
578        lock.retain(|_t, e| e.elapsed() < idle_dur);
579
580        if let std::collections::hash_map::Entry::Occupied(mut e) =
581            lock.entry(token)
582        {
583            e.insert(std::time::Instant::now());
584            true
585        } else {
586            false
587        }
588    }
589}
590
591#[derive(Clone)]
592struct AppState {
593    config: Arc<Config>,
594    token_tracker: AuthTokenTracker,
595    ip_rate: Arc<IpRate>,
596    cslot: WeakCSlot,
597    auth_failures: opentelemetry::metrics::Counter<u64>,
598    meter: opentelemetry::metrics::Meter,
599}
600
601impl AppState {
602    pub fn new(
603        config: Arc<Config>,
604        ip_rate: Arc<IpRate>,
605        cslot: WeakCSlot,
606        meter: opentelemetry::metrics::Meter,
607    ) -> Self {
608        Self {
609            config,
610            token_tracker: AuthTokenTracker::default(),
611            ip_rate,
612            cslot,
613            auth_failures: meter
614                .u64_counter("sbd.server.auth_failures")
615                .with_description("Number of failed authentication attempts")
616                .with_unit("count")
617                .build(),
618            meter,
619        }
620    }
621}
622
623impl SbdServer {
624    /// Construct a new running sbd server with the provided config.
625    pub async fn new(config: Arc<Config>) -> Result<Self> {
626        let tls_config = if let (Some(cert), Some(pk)) =
627            (&config.cert_pem_file, &config.priv_key_pem_file)
628        {
629            Some(Arc::new(TlsConfig::new(cert, pk).await?))
630        } else {
631            None
632        };
633
634        let sbd_server_meter = global::meter("sbd-server");
635
636        let mut task_list = Vec::new();
637        let mut bind_addrs = Vec::new();
638
639        let ip_rate = Arc::new(IpRate::new(config.clone()));
640        task_list.push(spawn_prune_task(ip_rate.clone()));
641
642        let cslot = CSlot::new(
643            config.clone(),
644            ip_rate.clone(),
645            sbd_server_meter.clone(),
646        );
647        let weak_cslot = cslot.weak();
648
649        // setup the axum router
650        let app: axum::Router<()> = axum::Router::new()
651            .route("/authenticate", axum::routing::put(handle_auth))
652            .route("/{pub_key}", axum::routing::any(handle_ws))
653            .layer(axum::extract::DefaultBodyLimit::max(1024))
654            .with_state(AppState::new(
655                config.clone(),
656                ip_rate.clone(),
657                weak_cslot.clone(),
658                sbd_server_meter,
659            ));
660
661        let app =
662            app.into_make_service_with_connect_info::<std::net::SocketAddr>();
663
664        let mut found_port_zero: Option<u16> = None;
665
666        // bind to configured bindings
667        for bind in config.bind.iter() {
668            let mut a: std::net::SocketAddr =
669                bind.parse().map_err(Error::other)?;
670            if let Some(found_port_zero) = &found_port_zero {
671                if a.port() == 0 {
672                    a.set_port(*found_port_zero);
673                }
674            }
675
676            let h = axum_server::Handle::new();
677
678            if let Some(tls_config) = &tls_config {
679                let tls_config =
680                    axum_server::tls_rustls::RustlsConfig::from_config(
681                        tls_config.config(),
682                    );
683                let server = axum_server::bind_rustls(a, tls_config)
684                    .handle(h.clone())
685                    .serve(app.clone());
686                task_list.push(tokio::task::spawn(async move {
687                    if let Err(err) = server.await {
688                        tracing::error!(?err);
689                    }
690                }));
691            } else {
692                let server =
693                    axum_server::bind(a).handle(h.clone()).serve(app.clone());
694                task_list.push(tokio::task::spawn(async move {
695                    if let Err(err) = server.await {
696                        tracing::error!(?err);
697                    }
698                }));
699            }
700
701            if let Some(addr) = h.listening().await {
702                if found_port_zero.is_none() && a.port() == 0 {
703                    found_port_zero = Some(addr.port());
704                }
705                bind_addrs.push(addr);
706            }
707        }
708
709        Ok(Self {
710            task_list,
711            bind_addrs,
712            _cslot: cslot,
713        })
714    }
715
716    /// Get the list of addresses bound locally.
717    pub fn bind_addrs(&self) -> &[std::net::SocketAddr] {
718        self.bind_addrs.as_slice()
719    }
720}
721
722#[cfg(test)]
723mod test;