sbd_client/
lib.rs

1//! Sbd client library.
2#![deny(missing_docs)]
3
4use std::io::{Error, Result};
5use std::sync::Arc;
6
/// Maximum size of a single websocket message (defined by the sbd spec).
const MAX_MSG_SIZE: usize = 20_000;

/// Length in bytes of an ed25519 public key (defined by the ed25519 spec).
const PK_SIZE: usize = 32;

/// Length in bytes of an ed25519 signature (defined by the ed25519 spec).
const SIG_SIZE: usize = 64;

/// Length of every message header. The sbd spec makes it the same size as
/// an ed25519 public key.
const HDR_SIZE: usize = PK_SIZE;

/// Length of the nonce carried by an `areq` auth request (defined by the
/// sbd spec).
const NONCE_SIZE: usize = 32;

/// Leading header bytes that mark a server command rather than a peer
/// public key (defined by the sbd spec). All zeros; the remaining 4 header
/// bytes hold the command tag.
const CMD_PREFIX: &[u8; HDR_SIZE - 4] = &[0; HDR_SIZE - 4];

/// Command tag `lbrt`: byte-nanos rate limit, 4-byte big-endian i32 payload.
const F_LIMIT_BYTE_NANOS: &[u8] = b"lbrt";
/// Command tag `lidl`: idle limit in milliseconds, 4-byte big-endian i32 payload.
const F_LIMIT_IDLE_MILLIS: &[u8] = b"lidl";
/// Command tag `areq`: auth request carrying a `NONCE_SIZE`-byte nonce.
const F_AUTH_REQ: &[u8] = b"areq";
/// Command tag `srdy`: the server reports the connection as ready.
const F_READY: &[u8] = b"srdy";
29
30#[cfg(feature = "raw_client")]
31pub mod raw_client;
32#[cfg(not(feature = "raw_client"))]
33mod raw_client;
34
35mod send_buf;
36
37/// Crypto to use. Note, the pair should be fresh for each new connection.
38pub trait Crypto {
39    /// The pubkey.
40    fn pub_key(&self) -> &[u8; PK_SIZE];
41
42    /// Sign the nonce.
43    fn sign(&self, nonce: &[u8]) -> Result<[u8; SIG_SIZE]>;
44}
45
#[cfg(feature = "crypto")]
mod default_crypto {
    use super::*;

    /// Default signer. Use a fresh one for every new connection.
    pub struct DefaultCrypto([u8; PK_SIZE], ed25519_dalek::SigningKey);

    impl Default for DefaultCrypto {
        /// Generate a random ed25519 keypair.
        ///
        /// Keys whose first bytes equal the all-zero command prefix are
        /// rejected and regenerated, because such a key could not be told
        /// apart from a server command header.
        fn default() -> Self {
            let mut rng = rand::thread_rng();
            loop {
                let signing_key = ed25519_dalek::SigningKey::generate(&mut rng);
                let pub_key = signing_key.verifying_key().to_bytes();
                if pub_key[..CMD_PREFIX.len()] != CMD_PREFIX[..] {
                    return Self(pub_key, signing_key);
                }
            }
        }
    }

    impl Crypto for DefaultCrypto {
        fn pub_key(&self) -> &[u8; PK_SIZE] {
            let Self(pub_key, _) = self;
            pub_key
        }

        fn sign(&self, nonce: &[u8]) -> Result<[u8; SIG_SIZE]> {
            use ed25519_dalek::Signer;
            let signature = self.1.sign(nonce);
            Ok(signature.to_bytes())
        }
    }
}
#[cfg(feature = "crypto")]
pub use default_crypto::*;
82
83/// Public key.
84#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
85pub struct PubKey(pub Arc<[u8; PK_SIZE]>);
86
87impl std::ops::Deref for PubKey {
88    type Target = [u8; 32];
89
90    fn deref(&self) -> &Self::Target {
91        &self.0
92    }
93}
94
95impl std::fmt::Debug for PubKey {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        use base64::Engine;
98        let out = base64::engine::general_purpose::URL_SAFE_NO_PAD
99            .encode(&self.0[..]);
100        f.write_str(&out)
101    }
102}
103
/// Classification of an incoming frame, as produced by `Msg::parse`.
enum MsgType<'t> {
    /// A message relayed from a peer.
    // The fields are not read by the visible code (the read loop matches
    // `MsgType::Msg { .. }`), hence `dead_code` is allowed on them.
    Msg {
        #[allow(dead_code)]
        pub_key: &'t [u8],
        #[allow(dead_code)]
        message: &'t [u8],
    },
    /// Server command `lbrt` with its big-endian i32 payload.
    LimitByteNanos(i32),
    /// Server command `lidl` with its big-endian i32 payload.
    LimitIdleMillis(i32),
    /// Server command `areq`, borrowing the nonce bytes.
    AuthReq(&'t [u8]),
    /// Server command `srdy`.
    Ready,
    /// A command tag this client does not recognize.
    Unknown,
}
117
118/// A message received from a remote.
119/// This is just a single buffer. The first 32 bytes are the public key
120/// of the sender, or 28 `0`s followed by a 4 byte command. Any remaining bytes are the message. The buffer
121/// contained in this type is guaranteed to be at least 32 bytes long.
122pub struct Msg(pub Vec<u8>);
123
124impl Msg {
125    /// Get a reference to the slice containing the pubkey data.
126    pub fn pub_key_ref(&self) -> &[u8] {
127        &self.0[..PK_SIZE]
128    }
129
130    /// Extract a pubkey from the message.
131    pub fn pub_key(&self) -> PubKey {
132        PubKey(Arc::new(self.0[..PK_SIZE].try_into().unwrap()))
133    }
134
135    /// Get a reference to the slice containing the message data.
136    pub fn message(&self) -> &[u8] {
137        &self.0[PK_SIZE..]
138    }
139
140    // -- private -- //
141
142    fn parse(&self) -> Result<MsgType<'_>> {
143        if self.0.len() < PK_SIZE {
144            return Err(Error::other("invalid message length"));
145        }
146        if &self.0[..28] == CMD_PREFIX {
147            match &self.0[28..HDR_SIZE] {
148                F_LIMIT_BYTE_NANOS => {
149                    if self.0.len() != HDR_SIZE + 4 {
150                        return Err(Error::other("invalid lbrt length"));
151                    }
152                    Ok(MsgType::LimitByteNanos(i32::from_be_bytes(
153                        self.0[PK_SIZE..].try_into().unwrap(),
154                    )))
155                }
156                F_LIMIT_IDLE_MILLIS => {
157                    if self.0.len() != HDR_SIZE + 4 {
158                        return Err(Error::other("invalid lidl length"));
159                    }
160                    Ok(MsgType::LimitIdleMillis(i32::from_be_bytes(
161                        self.0[HDR_SIZE..].try_into().unwrap(),
162                    )))
163                }
164                F_AUTH_REQ => {
165                    if self.0.len() != HDR_SIZE + NONCE_SIZE {
166                        return Err(Error::other("invalid areq length"));
167                    }
168                    Ok(MsgType::AuthReq(&self.0[HDR_SIZE..]))
169                }
170                F_READY => Ok(MsgType::Ready),
171                _ => Ok(MsgType::Unknown),
172            }
173        } else {
174            Ok(MsgType::Msg {
175                pub_key: &self.0[..PK_SIZE],
176                message: &self.0[PK_SIZE..],
177            })
178        }
179    }
180}
181
182/// Handle to receive data from the sbd connection.
183pub struct MsgRecv(tokio::sync::mpsc::Receiver<Msg>);
184
185impl MsgRecv {
186    /// Receive data from the sbd connection.
187    pub async fn recv(&mut self) -> Option<Msg> {
188        self.0.recv().await
189    }
190}
191
192/// Configuration for connecting an SbdClient.
193#[derive(Clone)]
194pub struct SbdClientConfig {
195    /// Outgoing message buffer size.
196    pub out_buffer_size: usize,
197
198    /// Setting this to `true` allows `ws://` scheme.
199    pub allow_plain_text: bool,
200
201    /// Setting this to `true` disables certificate verification on `wss://`
202    /// scheme. WARNING: this is a dangerous configuration and should not
203    /// be used outside of testing (i.e. self-signed tls certificates).
204    pub danger_disable_certificate_check: bool,
205
206    /// Set any custom http headers to send with the websocket connect.
207    pub headers: Vec<(String, String)>,
208
209    /// If you must pass authentication material to the sbd server,
210    /// specify it here.
211    pub auth_material: Option<Vec<u8>>,
212}
213
214impl Default for SbdClientConfig {
215    fn default() -> Self {
216        Self {
217            out_buffer_size: MAX_MSG_SIZE * 8,
218            allow_plain_text: false,
219            danger_disable_certificate_check: false,
220            headers: Vec::new(),
221            auth_material: None,
222        }
223    }
224}
225
226/// SbdClient represents a single connection to a single sbd server
227/// through which we can communicate with any number of peers on that server.
228pub struct SbdClient {
229    url: String,
230    pub_key: PubKey,
231    send_buf: Arc<tokio::sync::Mutex<send_buf::SendBuf>>,
232    read_task: tokio::task::JoinHandle<()>,
233    write_task: tokio::task::JoinHandle<()>,
234}
235
236impl Drop for SbdClient {
237    fn drop(&mut self) {
238        self.read_task.abort();
239        self.write_task.abort();
240    }
241}
242
243impl SbdClient {
244    /// Connect to the remote sbd server.
245    pub async fn connect<C: Crypto>(
246        url: &str,
247        crypto: &C,
248    ) -> Result<(Self, MsgRecv)> {
249        Self::connect_config(url, crypto, SbdClientConfig::default()).await
250    }
251
252    /// Connect to the remote sbd server.
253    pub async fn connect_config<C: Crypto>(
254        url: &str,
255        crypto: &C,
256        config: SbdClientConfig,
257    ) -> Result<(Self, MsgRecv)> {
258        use base64::Engine;
259        let full_url = format!(
260            "{url}/{}",
261            base64::engine::general_purpose::URL_SAFE_NO_PAD
262                .encode(crypto.pub_key())
263        );
264
265        // establish a "raw" low-level websocket connection to the server
266        let (mut send, mut recv) = raw_client::WsRawConnect {
267            full_url: full_url.clone(),
268            max_message_size: MAX_MSG_SIZE,
269            allow_plain_text: config.allow_plain_text,
270            danger_disable_certificate_check: config
271                .danger_disable_certificate_check,
272            headers: config.headers,
273            auth_material: config.auth_material,
274            alter_token_cb: None,
275        }
276        .connect()
277        .await?;
278
279        // performing the initial handshake authenticates us as a client
280        // and returns some server configuration values
281        let raw_client::Handshake {
282            limit_byte_nanos,
283            limit_idle_millis,
284            bytes_sent,
285        } = raw_client::Handshake::handshake(&mut send, &mut recv, crypto)
286            .await?;
287
288        // SendBuf helps us track rate-limiting so we don't ban ourselves
289        let send_buf = send_buf::SendBuf::new(
290            full_url.clone(),
291            send,
292            config.out_buffer_size,
293            (limit_byte_nanos as f64 * 1.1) as u64,
294            std::time::Duration::from_millis((limit_idle_millis / 2) as u64),
295            bytes_sent,
296        );
297        let send_buf = Arc::new(tokio::sync::Mutex::new(send_buf));
298
299        // spawn the read task that reads from the websocket connection
300        let send_buf2 = send_buf.clone();
301        let (recv_send, recv_recv) = tokio::sync::mpsc::channel(4);
302        let read_task = tokio::task::spawn(async move {
303            while let Ok(data) = recv.recv().await {
304                let data = Msg(data);
305
306                match match data.parse() {
307                    Ok(data) => data,
308                    Err(_) => break,
309                } {
310                    MsgType::Msg { .. } => {
311                        // we got a message from someone, forward to user
312                        if recv_send.send(data).await.is_err() {
313                            break;
314                        }
315                    }
316                    MsgType::LimitByteNanos(rate) => {
317                        // the server is reconfiguring the ratelimiting
318                        send_buf2
319                            .lock()
320                            .await
321                            .new_rate_limit((rate as f64 * 1.1) as u64);
322                    }
323                    // idle messages should not be sent at this stage
324                    MsgType::LimitIdleMillis(_) => break,
325                    // authorization requests should not be sent at this stage
326                    MsgType::AuthReq(_) => break,
327                    // we can safely ignore late readys
328                    MsgType::Ready => (),
329                    // ignore all protocol messages we don't understand
330                    MsgType::Unknown => (),
331                }
332            }
333
334            send_buf2.lock().await.close().await;
335        });
336
337        // spawn the write task that sends data respecting rate limits
338        let send_buf2 = send_buf.clone();
339        let write_task = tokio::task::spawn(async move {
340            loop {
341                // wait, if required by ratelimiting
342                if let Some(dur) = send_buf2.lock().await.next_step_dur() {
343                    tokio::time::sleep(dur).await;
344                }
345
346                match send_buf2.lock().await.write_next_queued().await {
347                    Err(_) => break,
348                    // send_buf was able to send something, loop again
349                    Ok(true) => (),
350                    // send_buf failed to do anything, we need a short
351                    // delay before we try looping again to avoid a busy wait
352                    Ok(false) => {
353                        tokio::time::sleep(std::time::Duration::from_millis(
354                            10,
355                        ))
356                        .await;
357                    }
358                }
359            }
360
361            send_buf2.lock().await.close().await;
362        });
363
364        let pub_key = PubKey(Arc::new(*crypto.pub_key()));
365
366        let this = Self {
367            url: full_url,
368            pub_key,
369            send_buf,
370            read_task,
371            write_task,
372        };
373
374        Ok((this, MsgRecv(recv_recv)))
375    }
376
377    /// The full url of this client.
378    pub fn url(&self) -> &str {
379        &self.url
380    }
381
382    /// The pub key of this client.
383    pub fn pub_key(&self) -> &PubKey {
384        &self.pub_key
385    }
386
387    /// Close the connection.
388    pub async fn close(&self) {
389        self.send_buf.lock().await.close().await;
390    }
391
392    /// Send a message to a peer.
393    pub async fn send(&self, peer: &PubKey, data: &[u8]) -> Result<()> {
394        self.send_buf.lock().await.send(peer, data).await
395    }
396}
397
// Unit tests live in the `test` submodule and are only compiled for
// `cargo test`.
#[cfg(test)]
mod test;