sbd_server/
lib.rs

1//! Sbd server library.
2#![deny(missing_docs)]
3
/// Maximum message size in bytes; the value is fixed by the sbd spec.
const MAX_MSG_BYTES: i32 = 20000;
6
7use std::io::{Error, Result};
8use std::net::{IpAddr, Ipv6Addr};
9use std::sync::Arc;
10
11mod config;
12pub use config::*;
13
14mod maybe_tls;
15pub use maybe_tls::*;
16
17mod ip_deny;
18mod ip_rate;
19pub use ip_rate::*;
20
21mod cslot;
22pub use cslot::*;
23
24mod cmd;
25
26/// Websocket backend abstraction.
27pub mod ws {
28    /// Payload.
29    pub enum Payload {
30        /// Vec.
31        Vec(Vec<u8>),
32
33        /// BytesMut.
34        BytesMut(bytes::BytesMut),
35    }
36
37    impl std::ops::Deref for Payload {
38        type Target = [u8];
39
40        #[inline(always)]
41        fn deref(&self) -> &Self::Target {
42            match self {
43                Payload::Vec(v) => v.as_slice(),
44                Payload::BytesMut(b) => b.as_ref(),
45            }
46        }
47    }
48
49    impl Payload {
50        /// Mutable payload.
51        #[inline(always)]
52        pub fn to_mut(&mut self) -> &mut [u8] {
53            match self {
54                Payload::Vec(ref mut owned) => owned,
55                Payload::BytesMut(b) => b.as_mut(),
56            }
57        }
58    }
59
60    #[cfg(feature = "tungstenite")]
61    mod ws_tungstenite;
62
63    use futures::future::BoxFuture;
64    #[cfg(feature = "tungstenite")]
65    pub use ws_tungstenite::*;
66
67    #[cfg(all(not(feature = "tungstenite"), feature = "fastwebsockets"))]
68    mod ws_fastwebsockets;
69    #[cfg(all(not(feature = "tungstenite"), feature = "fastwebsockets"))]
70    pub use ws_fastwebsockets::*;
71
72    /// Websocket trait.
73    pub trait SbdWebsocket: Send + Sync + 'static {
74        /// Receive from the websocket.
75        fn recv(&self) -> BoxFuture<'static, std::io::Result<Payload>>;
76
77        /// Send to the websocket.
78        fn send(
79            &self,
80            payload: Payload,
81        ) -> BoxFuture<'static, std::io::Result<()>>;
82
83        /// Close the websocket.
84        fn close(&self) -> BoxFuture<'static, ()>;
85    }
86}
87
88pub use ws::{Payload, SbdWebsocket};
89
90/// Public key.
91#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
92pub struct PubKey(pub Arc<[u8; 32]>);
93
94impl PubKey {
95    /// Verify a signature with this pub key.
96    pub fn verify(&self, sig: &[u8; 64], data: &[u8]) -> bool {
97        use ed25519_dalek::Verifier;
98        if let Ok(k) = ed25519_dalek::VerifyingKey::from_bytes(&self.0) {
99            k.verify(data, &ed25519_dalek::Signature::from_bytes(sig))
100                .is_ok()
101        } else {
102            false
103        }
104    }
105}
106
107/// SbdServer.
108pub struct SbdServer {
109    task_list: Vec<tokio::task::JoinHandle<()>>,
110    bind_addrs: Vec<std::net::SocketAddr>,
111    _cslot: cslot::CSlot,
112}
113
114impl Drop for SbdServer {
115    fn drop(&mut self) {
116        for task in self.task_list.iter() {
117            task.abort();
118        }
119    }
120}
121
/// Convert an IP address to an IPv6 address.
///
/// IPv4 input becomes the IPv4-mapped IPv6 form (`::ffff:a.b.c.d`); IPv6
/// input is returned unchanged. The result is wrapped in an `Arc`.
pub fn to_canonical_ip(ip: IpAddr) -> Arc<Ipv6Addr> {
    let v6: Ipv6Addr = match ip {
        IpAddr::V6(already_v6) => already_v6,
        IpAddr::V4(v4) => v4.to_ipv6_mapped(),
    };
    Arc::new(v6)
}
129
130/// If the check passes, the canonical IP is returned, otherwise None and the connection should be
131/// dropped.
132pub async fn preflight_ip_check(
133    config: &Config,
134    ip_rate: &IpRate,
135    addr: std::net::SocketAddr,
136) -> Option<Arc<Ipv6Addr>> {
137    let raw_ip = to_canonical_ip(addr.ip());
138
139    let use_trusted_ip = config.trusted_ip_header.is_some();
140
141    if !use_trusted_ip {
142        // Do this check BEFORE handshake to avoid extra
143        // server process when capable.
144        // If we *are* behind a reverse proxy, we assume
145        // some amount of DDoS mitigation is happening there
146        // and thus we can accept a little more process overhead
147        if ip_rate.is_blocked(&raw_ip).await {
148            return None;
149        }
150
151        // Also precheck our rate limit, using up one byte
152        if !ip_rate.is_ok(&raw_ip, 1).await {
153            return None;
154        }
155    }
156
157    Some(raw_ip)
158}
159
160/// Handle an upgraded websocket connection.
161pub async fn handle_upgraded(
162    config: Arc<Config>,
163    ip_rate: Arc<IpRate>,
164    weak_cslot: WeakCSlot,
165    ws: Arc<impl SbdWebsocket>,
166    pub_key: PubKey,
167    calc_ip: Arc<Ipv6Addr>,
168) {
169    let use_trusted_ip = config.trusted_ip_header.is_some();
170
171    // illegal pub key
172    if &pub_key.0[..28] == cmd::CMD_PREFIX {
173        return;
174    }
175
176    let ws = Arc::new(ws);
177
178    if use_trusted_ip {
179        // if using a trusted ip, check block here.
180        // see note above before the handshakes.
181        if ip_rate.is_blocked(&calc_ip).await {
182            return;
183        }
184
185        // Also precheck our rate limit, using up one byte
186        if !ip_rate.is_ok(&calc_ip, 1).await {
187            return;
188        }
189    }
190
191    if let Some(cslot) = weak_cslot.upgrade() {
192        cslot.insert(&config, calc_ip, pub_key, ws).await;
193    }
194}
195
196async fn check_accept_connection(
197    _connect_permit: tokio::sync::OwnedSemaphorePermit,
198    config: Arc<Config>,
199    tls_config: Option<Arc<TlsConfig>>,
200    ip_rate: Arc<IpRate>,
201    tcp: tokio::net::TcpStream,
202    addr: std::net::SocketAddr,
203    weak_cslot: WeakCSlot,
204) {
205    let _ = tokio::time::timeout(config.idle_dur(), async {
206        let Some(mut calc_ip) =
207            preflight_ip_check(&config, &ip_rate, addr).await
208        else {
209            return;
210        };
211
212        let socket = if let Some(tls_config) = &tls_config {
213            match MaybeTlsStream::tls(tls_config, tcp).await {
214                Err(_) => return,
215                Ok(tls) => tls,
216            }
217        } else {
218            MaybeTlsStream::Tcp(tcp)
219        };
220
221        let (ws, pub_key, ip) =
222            match ws::WebSocket::upgrade(config.clone(), socket).await {
223                Ok(r) => r,
224                Err(_) => return,
225            };
226
227        if let Some(ip) = ip {
228            calc_ip = Arc::new(ip);
229        }
230
231        handle_upgraded(
232            config,
233            ip_rate,
234            weak_cslot,
235            Arc::new(ws),
236            pub_key,
237            calc_ip,
238        )
239        .await;
240    })
241    .await;
242}
243
244async fn bind_all<I: IntoIterator<Item = std::net::SocketAddr>>(
245    i: I,
246) -> Vec<tokio::net::TcpListener> {
247    let mut listeners = Vec::new();
248    for a in i.into_iter() {
249        if let Ok(tcp) = tokio::net::TcpListener::bind(a).await {
250            listeners.push(tcp);
251        }
252    }
253    listeners
254}
255
256impl SbdServer {
257    /// Construct a new running sbd server with the provided config.
258    pub async fn new(config: Arc<Config>) -> Result<Self> {
259        let tls_config = if let (Some(cert), Some(pk)) =
260            (&config.cert_pem_file, &config.priv_key_pem_file)
261        {
262            Some(Arc::new(maybe_tls::TlsConfig::new(cert, pk).await?))
263        } else {
264            None
265        };
266
267        let mut task_list = Vec::new();
268        let mut bind_addrs = Vec::new();
269
270        let ip_rate = Arc::new(IpRate::new(config.clone()));
271        task_list.push(spawn_prune_task(ip_rate.clone()));
272
273        let cslot = CSlot::new(config.clone(), ip_rate.clone());
274
275        // limit the number of connections that can be "connecting" at a time.
276        // MAYBE make this configurable.
277        // Read this as a prioritization of existing connections over incoming
278        let connect_limit = Arc::new(tokio::sync::Semaphore::new(1024));
279
280        let mut bind_port_zero = Vec::new();
281        let mut bind_explicit_port = Vec::new();
282
283        for bind in config.bind.iter() {
284            let a: std::net::SocketAddr = bind.parse().map_err(Error::other)?;
285
286            if a.port() == 0 {
287                bind_port_zero.push(a);
288            } else {
289                bind_explicit_port.push(a);
290            }
291        }
292
293        let (mut listeners, mut l2) = tokio::join!(
294            async {
295                // bail if there are no zero port bindings
296                if bind_port_zero.is_empty() {
297                    return Vec::new();
298                }
299
300                // try twice to re-use port
301                'top: for _ in 0..2 {
302                    let mut listeners = Vec::new();
303
304                    let mut a_iter = bind_port_zero.iter();
305
306                    let a = a_iter.next().unwrap();
307                    if let Ok(tcp) = tokio::net::TcpListener::bind(a).await {
308                        let port = tcp.local_addr().unwrap().port();
309                        listeners.push(tcp);
310
311                        for a in a_iter {
312                            let mut a = *a;
313                            a.set_port(port);
314                            match tokio::net::TcpListener::bind(a).await {
315                                Err(_) => continue 'top,
316                                Ok(tcp) => listeners.push(tcp),
317                            }
318                        }
319
320                        return listeners;
321                    }
322                }
323
324                // just use whatever we can get
325                bind_all(bind_port_zero).await
326            },
327            async { bind_all(bind_explicit_port).await },
328        );
329
330        listeners.append(&mut l2);
331
332        if listeners.is_empty() {
333            return Err(Error::other("failed to bind any listeners"));
334        }
335
336        let weak_cslot = cslot.weak();
337        for tcp in listeners {
338            bind_addrs.push(tcp.local_addr()?);
339
340            let tls_config = tls_config.clone();
341            let connect_limit = connect_limit.clone();
342            let config = config.clone();
343            let weak_cslot = weak_cslot.clone();
344            let ip_rate = ip_rate.clone();
345            task_list.push(tokio::task::spawn(async move {
346                loop {
347                    if let Ok((tcp, addr)) = tcp.accept().await {
348                        // Drop connections as fast as possible
349                        // if we are overloaded on accepting connections.
350                        let connect_permit =
351                            match connect_limit.clone().try_acquire_owned() {
352                                Ok(permit) => permit,
353                                _ => continue,
354                            };
355
356                        // just let this task die on its own time
357                        // MAYBE preallocate these tasks like cslot
358                        tokio::task::spawn(check_accept_connection(
359                            connect_permit,
360                            config.clone(),
361                            tls_config.clone(),
362                            ip_rate.clone(),
363                            tcp,
364                            addr,
365                            weak_cslot.clone(),
366                        ));
367                    }
368                }
369            }));
370        }
371
372        Ok(Self {
373            task_list,
374            bind_addrs,
375            _cslot: cslot,
376        })
377    }
378
379    /// Get the list of addresses bound locally.
380    pub fn bind_addrs(&self) -> &[std::net::SocketAddr] {
381        self.bind_addrs.as_slice()
382    }
383}
384
385#[cfg(test)]
386mod test;