sbd_server/
lib.rs

1//! Sbd server library.
2#![deny(missing_docs)]
3
/// Maximum size, in bytes, of a single websocket message.
/// This limit is defined by the sbd spec.
const MAX_MSG_BYTES: i32 = 20 * 1_000;
6
7use std::collections::HashMap;
8use std::io::{Error, Result};
9use std::net::{IpAddr, Ipv6Addr};
10use std::sync::{Arc, Mutex};
11
12mod config;
13pub use config::*;
14
15mod maybe_tls;
16pub use maybe_tls::*;
17
18mod ip_deny;
19mod ip_rate;
20pub use ip_rate::*;
21
22mod cslot;
23pub use cslot::*;
24
25mod cmd;
26
27/// Websocket backend abstraction.
28pub mod ws {
29    /// Payload.
30    pub enum Payload {
31        /// Vec.
32        Vec(Vec<u8>),
33
34        /// BytesMut.
35        BytesMut(bytes::BytesMut),
36    }
37
38    impl std::ops::Deref for Payload {
39        type Target = [u8];
40
41        #[inline(always)]
42        fn deref(&self) -> &Self::Target {
43            match self {
44                Payload::Vec(v) => v.as_slice(),
45                Payload::BytesMut(b) => b.as_ref(),
46            }
47        }
48    }
49
50    impl Payload {
51        /// Mutable payload.
52        #[inline(always)]
53        pub fn to_mut(&mut self) -> &mut [u8] {
54            match self {
55                Payload::Vec(ref mut owned) => owned,
56                Payload::BytesMut(b) => b.as_mut(),
57            }
58        }
59    }
60
61    use futures::future::BoxFuture;
62
63    /// Websocket trait.
64    pub trait SbdWebsocket: Send + Sync + 'static {
65        /// Receive from the websocket.
66        fn recv(&self) -> BoxFuture<'static, std::io::Result<Payload>>;
67
68        /// Send to the websocket.
69        fn send(
70            &self,
71            payload: Payload,
72        ) -> BoxFuture<'static, std::io::Result<()>>;
73
74        /// Close the websocket.
75        fn close(&self) -> BoxFuture<'static, ()>;
76    }
77}
78
79pub use ws::{Payload, SbdWebsocket};
80
81/// Public key.
82#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
83pub struct PubKey(pub Arc<[u8; 32]>);
84
85impl PubKey {
86    /// Verify a signature with this pub key.
87    pub fn verify(&self, sig: &[u8; 64], data: &[u8]) -> bool {
88        use ed25519_dalek::Verifier;
89        if let Ok(k) = ed25519_dalek::VerifyingKey::from_bytes(&self.0) {
90            k.verify(data, &ed25519_dalek::Signature::from_bytes(sig))
91                .is_ok()
92        } else {
93            false
94        }
95    }
96}
97
98/// SbdServer.
99pub struct SbdServer {
100    task_list: Vec<tokio::task::JoinHandle<()>>,
101    bind_addrs: Vec<std::net::SocketAddr>,
102    _cslot: cslot::CSlot,
103}
104
105impl Drop for SbdServer {
106    fn drop(&mut self) {
107        for task in self.task_list.iter() {
108            task.abort();
109        }
110    }
111}
112
/// Convert an IP address to an IPv6 address.
///
/// IPv4 addresses become IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`).
/// IPv6 addresses are returned unchanged. This gives every peer a single
/// key type for rate limiting and blocking.
pub fn to_canonical_ip(ip: IpAddr) -> Arc<Ipv6Addr> {
    let v6 = match ip {
        IpAddr::V6(v6) => v6,
        IpAddr::V4(v4) => v4.to_ipv6_mapped(),
    };
    Arc::new(v6)
}
120
121/// If the check passes, the canonical IP is returned, otherwise None and the connection should be
122/// dropped.
123pub async fn preflight_ip_check(
124    config: &Config,
125    ip_rate: &IpRate,
126    addr: std::net::SocketAddr,
127) -> Option<Arc<Ipv6Addr>> {
128    let raw_ip = to_canonical_ip(addr.ip());
129
130    let use_trusted_ip = config.trusted_ip_header.is_some();
131
132    if !use_trusted_ip {
133        // Do this check BEFORE handshake to avoid extra
134        // server process when capable.
135        // If we *are* behind a reverse proxy, we assume
136        // some amount of DDoS mitigation is happening there
137        // and thus we can accept a little more process overhead
138        if ip_rate.is_blocked(&raw_ip).await {
139            return None;
140        }
141
142        // Also precheck our rate limit, using up one byte
143        if !ip_rate.is_ok(&raw_ip, 1).await {
144            return None;
145        }
146    }
147
148    Some(raw_ip)
149}
150
151/// Handle an upgraded websocket connection.
152pub async fn handle_upgraded(
153    config: Arc<Config>,
154    ip_rate: Arc<IpRate>,
155    weak_cslot: WeakCSlot,
156    ws: Arc<impl SbdWebsocket>,
157    pub_key: PubKey,
158    calc_ip: Arc<Ipv6Addr>,
159    maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
160) {
161    let use_trusted_ip = config.trusted_ip_header.is_some();
162
163    // illegal pub key
164    if &pub_key.0[..28] == cmd::CMD_PREFIX {
165        return;
166    }
167
168    if use_trusted_ip {
169        // if using a trusted ip, check block here.
170        // see note above before the handshakes.
171        if ip_rate.is_blocked(&calc_ip).await {
172            return;
173        }
174
175        // Also precheck our rate limit, using up one byte
176        if !ip_rate.is_ok(&calc_ip, 1).await {
177            return;
178        }
179    }
180
181    if let Some(cslot) = weak_cslot.upgrade() {
182        cslot
183            .insert(&config, calc_ip, pub_key, ws, maybe_auth)
184            .await;
185    }
186}
187
188async fn handle_auth(
189    axum::extract::State(app_state): axum::extract::State<AppState>,
190    body: bytes::Bytes,
191) -> axum::response::Response {
192    use AuthenticateTokenError::*;
193
194    match process_authenticate_token(
195        &app_state.config,
196        &app_state.token_tracker,
197        body,
198    )
199    .await
200    {
201        Ok(token) => axum::response::IntoResponse::into_response(axum::Json(
202            serde_json::json!({
203                "authToken": *token,
204            }),
205        )),
206        Err(Unauthorized) => {
207            tracing::debug!("/authenticate: UNAUTHORIZED");
208            axum::response::IntoResponse::into_response((
209                axum::http::StatusCode::UNAUTHORIZED,
210                "Unauthorized",
211            ))
212        }
213        Err(HookServerError(err)) => {
214            tracing::debug!(?err, "/authenticate: BAD_GATEWAY");
215            axum::response::IntoResponse::into_response((
216                axum::http::StatusCode::BAD_GATEWAY,
217                format!("BAD_GATEWAY: {err:?}"),
218            ))
219        }
220        Err(OtherError(err)) => {
221            tracing::warn!(?err, "/authenticate: INTERNAL_SERVER_ERROR");
222            axum::response::IntoResponse::into_response((
223                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
224                format!("INTERNAL_SERVER_ERROR: {err:?}"),
225            ))
226        }
227    }
228}
229
/// Authenticate token error type.
///
/// Derives `Debug` so callers can log it, `unwrap` results carrying it, and
/// use it in `{:?}` formatting. Every payload is a `std::io::Error`, which
/// is itself `Debug`.
#[derive(Debug)]
pub enum AuthenticateTokenError {
    /// Authentication was rejected: the caller is not authorized.
    Unauthorized,
    /// We had an error talking to the hook server.
    HookServerError(std::io::Error),
    /// We had an internal error.
    OtherError(std::io::Error),
}
239
240/// Handle receiving a PUT "/authenticate" rest api request.
241pub async fn process_authenticate_token(
242    config: &Config,
243    token_tracker: &AuthTokenTracker,
244    auth_material: bytes::Bytes,
245) -> std::result::Result<Arc<str>, AuthenticateTokenError> {
246    use AuthenticateTokenError::*;
247
248    let token: Arc<str> = if let Some(url) = &config.authentication_hook_server
249    {
250        let url = url.clone();
251        tokio::task::spawn_blocking(move || {
252            ureq::put(&url)
253                .set("Content-Type", "application/octet-stream")
254                .send(&auth_material[..])
255                .map_err(|err| HookServerError(std::io::Error::other(err)))?
256                .into_string()
257                // this is a HookServerError, not an OtherError,
258                // because it is the hook server that either failed
259                // to send a full response, or sent back non-utf8 bytes, etc...
260                .map_err(HookServerError)
261        })
262        .await
263        .map_err(|_| OtherError(std::io::Error::other("tokio task died")))??
264    } else {
265        // If no backend configured, fallback to gen random token:
266        use base64::prelude::*;
267        use rand::Rng;
268
269        let mut bytes = [0; 32];
270        rand::thread_rng().fill(&mut bytes);
271        BASE64_URL_SAFE_NO_PAD.encode(&bytes[..])
272    }
273    .into();
274
275    token_tracker.register_token(token.clone());
276
277    Ok(token)
278}
279
280#[derive(Clone)]
281struct WebsocketImpl {
282    write: Arc<
283        tokio::sync::Mutex<
284            futures::stream::SplitSink<
285                axum::extract::ws::WebSocket,
286                axum::extract::ws::Message,
287            >,
288        >,
289    >,
290    read: Arc<
291        tokio::sync::Mutex<
292            futures::stream::SplitStream<axum::extract::ws::WebSocket>,
293        >,
294    >,
295}
296
297impl SbdWebsocket for WebsocketImpl {
298    fn recv(&self) -> futures::future::BoxFuture<'static, Result<Payload>> {
299        let this = self.clone();
300        Box::pin(async move {
301            let mut read = this.read.lock().await;
302            use futures::stream::StreamExt;
303            loop {
304                match read.next().await {
305                    None => return Err(Error::other("closed")),
306                    Some(r) => {
307                        let msg = r.map_err(Error::other)?;
308                        match msg {
309                            axum::extract::ws::Message::Text(s) => {
310                                return Ok(Payload::Vec(s.as_bytes().to_vec()))
311                            }
312                            axum::extract::ws::Message::Binary(v) => {
313                                return Ok(Payload::Vec(v[..].to_vec()))
314                            }
315                            axum::extract::ws::Message::Ping(_)
316                            | axum::extract::ws::Message::Pong(_) => (),
317                            axum::extract::ws::Message::Close(_) => {
318                                return Err(Error::other("closed"))
319                            }
320                        }
321                    }
322                }
323            }
324        })
325    }
326
327    fn send(
328        &self,
329        payload: Payload,
330    ) -> futures::future::BoxFuture<'static, Result<()>> {
331        use futures::SinkExt;
332        let this = self.clone();
333        Box::pin(async move {
334            let mut write = this.write.lock().await;
335            let v = match payload {
336                Payload::Vec(v) => v,
337                Payload::BytesMut(b) => b.to_vec(),
338            };
339            write
340                .send(axum::extract::ws::Message::Binary(
341                    bytes::Bytes::copy_from_slice(&v),
342                ))
343                .await
344                .map_err(Error::other)?;
345            write.flush().await.map_err(Error::other)?;
346            Ok(())
347        })
348    }
349
350    fn close(&self) -> futures::future::BoxFuture<'static, ()> {
351        use futures::SinkExt;
352        let this = self.clone();
353        Box::pin(async move {
354            let _ = this.write.lock().await.close().await;
355        })
356    }
357}
358
359impl WebsocketImpl {
360    fn new(ws: axum::extract::ws::WebSocket) -> Self {
361        use futures::StreamExt;
362        let (tx, rx) = ws.split();
363        Self {
364            write: Arc::new(tokio::sync::Mutex::new(tx)),
365            read: Arc::new(tokio::sync::Mutex::new(rx)),
366        }
367    }
368}
369
370async fn handle_ws(
371    axum::extract::Path(pub_key): axum::extract::Path<String>,
372    headers: axum::http::HeaderMap,
373    ws: axum::extract::WebSocketUpgrade,
374    axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo<
375        std::net::SocketAddr,
376    >,
377    axum::extract::State(app_state): axum::extract::State<AppState>,
378) -> impl axum::response::IntoResponse {
379    use axum::response::IntoResponse;
380    use base64::Engine;
381
382    let token: Option<Arc<str>> = headers
383        .get("Authorization")
384        .and_then(|t| t.to_str().ok().map(<Arc<str>>::from));
385
386    let maybe_auth = Some((token.clone(), app_state.token_tracker.clone()));
387
388    if !app_state
389        .token_tracker
390        .check_is_token_valid(&app_state.config, token)
391    {
392        return axum::response::IntoResponse::into_response((
393            axum::http::StatusCode::UNAUTHORIZED,
394            "Unauthorized",
395        ));
396    }
397
398    let pk = match base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(pub_key) {
399        Ok(pk) if pk.len() == 32 => {
400            let mut sized_pk = [0; 32];
401            sized_pk.copy_from_slice(&pk);
402            PubKey(Arc::new(sized_pk))
403        }
404        _ => return axum::http::StatusCode::BAD_REQUEST.into_response(),
405    };
406
407    let mut calc_ip = to_canonical_ip(addr.ip());
408
409    if let Some(trusted_ip_header) = &app_state.config.trusted_ip_header {
410        if let Some(header) =
411            headers.get(trusted_ip_header).and_then(|h| h.to_str().ok())
412        {
413            if let Ok(ip) = header.parse::<IpAddr>() {
414                calc_ip = to_canonical_ip(ip);
415            }
416        }
417    }
418
419    ws.max_message_size(MAX_MSG_BYTES as usize).on_upgrade(
420        move |socket| async move {
421            handle_upgraded(
422                app_state.config.clone(),
423                app_state.ip_rate.clone(),
424                app_state.cslot.clone(),
425                Arc::new(WebsocketImpl::new(socket)),
426                pk,
427                calc_ip,
428                maybe_auth,
429            )
430            .await;
431        },
432    )
433}
434
435/// Utility for managing auth tokens.
436#[derive(Clone, Default)]
437pub struct AuthTokenTracker {
438    token_map: Arc<Mutex<HashMap<Arc<str>, std::time::Instant>>>,
439}
440
441impl AuthTokenTracker {
442    /// Register a token as valid.
443    pub fn register_token(&self, token: Arc<str>) {
444        self.token_map
445            .lock()
446            .unwrap()
447            .insert(token, std::time::Instant::now());
448    }
449
450    /// Check that a token is valid.
451    /// If so, mark it as recently used so it doesn't time out.
452    /// The "token" parameter should be direct from the http header
453    /// i.e. with the "Barer" include, like "Bearer base64".
454    /// This should be called with None as the token if no Authenticate
455    /// header was specified.
456    pub fn check_is_token_valid(
457        &self,
458        config: &Config,
459        token: Option<Arc<str>>,
460    ) -> bool {
461        let token: Arc<str> = if let Some(token) = token {
462            // If the client supplied a token, always validate it,
463            // even if no hook server was specified in the config.
464            if !token.starts_with("Bearer ") {
465                return false;
466            }
467            token.trim_start_matches("Bearer ").into()
468        } else if config.authentication_hook_server.is_none() {
469            // If the client did not supply a token, and we have no
470            // hook server configured, allow the request.
471            return true;
472        } else {
473            // We have no token, but one is required. Unauthorized.
474            return false;
475        };
476
477        let mut lock = self.token_map.lock().unwrap();
478
479        let idle_dur = config.idle_dur();
480
481        lock.retain(|_t, e| e.elapsed() < idle_dur);
482
483        if let std::collections::hash_map::Entry::Occupied(mut e) =
484            lock.entry(token)
485        {
486            e.insert(std::time::Instant::now());
487            true
488        } else {
489            false
490        }
491    }
492}
493
494#[derive(Clone)]
495struct AppState {
496    config: Arc<Config>,
497    token_tracker: AuthTokenTracker,
498    ip_rate: Arc<IpRate>,
499    cslot: WeakCSlot,
500}
501
502impl AppState {
503    pub fn new(
504        config: Arc<Config>,
505        ip_rate: Arc<IpRate>,
506        cslot: WeakCSlot,
507    ) -> Self {
508        Self {
509            config,
510            token_tracker: AuthTokenTracker::default(),
511            ip_rate,
512            cslot,
513        }
514    }
515}
516
517impl SbdServer {
518    /// Construct a new running sbd server with the provided config.
519    pub async fn new(config: Arc<Config>) -> Result<Self> {
520        let tls_config = if let (Some(cert), Some(pk)) =
521            (&config.cert_pem_file, &config.priv_key_pem_file)
522        {
523            Some(Arc::new(maybe_tls::TlsConfig::new(cert, pk).await?))
524        } else {
525            None
526        };
527
528        let mut task_list = Vec::new();
529        let mut bind_addrs = Vec::new();
530
531        let ip_rate = Arc::new(IpRate::new(config.clone()));
532        task_list.push(spawn_prune_task(ip_rate.clone()));
533
534        let cslot = CSlot::new(config.clone(), ip_rate.clone());
535        let weak_cslot = cslot.weak();
536
537        let app: axum::Router<()> = axum::Router::new()
538            .route("/authenticate", axum::routing::put(handle_auth))
539            .route("/{pub_key}", axum::routing::any(handle_ws))
540            .layer(axum::extract::DefaultBodyLimit::max(1024))
541            .with_state(AppState::new(
542                config.clone(),
543                ip_rate.clone(),
544                weak_cslot.clone(),
545            ));
546
547        let app =
548            app.into_make_service_with_connect_info::<std::net::SocketAddr>();
549
550        let mut found_port_zero: Option<u16> = None;
551
552        for bind in config.bind.iter() {
553            let mut a: std::net::SocketAddr =
554                bind.parse().map_err(Error::other)?;
555            if let Some(found_port_zero) = &found_port_zero {
556                if a.port() == 0 {
557                    a.set_port(*found_port_zero);
558                }
559            }
560
561            let h = axum_server::Handle::new();
562
563            if let Some(tls_config) = &tls_config {
564                let tls_config =
565                    axum_server::tls_rustls::RustlsConfig::from_config(
566                        tls_config.config(),
567                    );
568                let server = axum_server::bind_rustls(a, tls_config)
569                    .handle(h.clone())
570                    .serve(app.clone());
571                task_list.push(tokio::task::spawn(async move {
572                    if let Err(err) = server.await {
573                        tracing::error!(?err);
574                    }
575                }));
576            } else {
577                let server =
578                    axum_server::bind(a).handle(h.clone()).serve(app.clone());
579                task_list.push(tokio::task::spawn(async move {
580                    if let Err(err) = server.await {
581                        tracing::error!(?err);
582                    }
583                }));
584            }
585
586            if let Some(addr) = h.listening().await {
587                if found_port_zero.is_none() && a.port() == 0 {
588                    found_port_zero = Some(addr.port());
589                }
590                bind_addrs.push(addr);
591            }
592        }
593
594        Ok(Self {
595            task_list,
596            bind_addrs,
597            _cslot: cslot,
598        })
599    }
600
601    /// Get the list of addresses bound locally.
602    pub fn bind_addrs(&self) -> &[std::net::SocketAddr] {
603        self.bind_addrs.as_slice()
604    }
605}
606
607#[cfg(test)]
608mod test;