sbd_client/
raw_client.rs

1//! `feature = "raw_client"` Raw websocket interaction types.
2
3use crate::*;
4
/// Signature of the token alteration hook.
///
/// The callback is handed the token as an `Arc<str>` and returns the token
/// that will be used in its place. It is stored behind an `Arc` and must be
/// `Send + Sync + 'static`.
pub type AlterTokenCb =
    Arc<dyn Fn(Arc<str>) -> Arc<str> + Send + Sync + 'static>;
8
9/// Connection info for creating a raw websocket connection.
10pub struct WsRawConnect {
11    /// The full url including the pubkey path parameter.
12    pub full_url: String,
13
14    /// The maximum message size. If a message is larger than this
15    /// the connection will be closed.
16    pub max_message_size: usize,
17
18    /// Setting this to `true` allows `ws://` scheme.
19    pub allow_plain_text: bool,
20
21    /// Setting this to `true` disables certificate verification on `wss://`
22    /// scheme. WARNING: this is a dangerous configuration and should not
23    /// be used outside of testing (i.e. self-signed tls certificates).
24    pub danger_disable_certificate_check: bool,
25
26    /// Set any custom http headers to send with the websocket connect.
27    pub headers: Vec<(String, String)>,
28
29    /// If you must pass authentication material to the sbd server,
30    /// specify it here.
31    pub auth_material: Option<Vec<u8>>,
32
33    /// This is mostly a test api, but since we need to use it outside
34    /// this crate, it is available for anyone using the "raw_client" feature.
35    /// Allows altering the token post-receive so we can send bad ones.
36    pub alter_token_cb: Option<AlterTokenCb>,
37}
38
39impl WsRawConnect {
40    /// Establish the websocket connection.
41    pub async fn connect(self) -> Result<(WsRawSend, WsRawRecv)> {
42        let Self {
43            full_url,
44            max_message_size,
45            allow_plain_text,
46            danger_disable_certificate_check,
47            headers,
48            auth_material,
49            alter_token_cb,
50        } = self;
51
52        // convert the url into a request
53        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
54        let mut request =
55            IntoClientRequest::into_client_request(full_url.clone())
56                .map_err(Error::other)?;
57
58        // set any headers we are configured with
59        for (k, v) in headers {
60            use tokio_tungstenite::tungstenite::http::header::*;
61            let k =
62                HeaderName::from_bytes(k.as_bytes()).map_err(Error::other)?;
63            let v =
64                HeaderValue::from_bytes(v.as_bytes()).map_err(Error::other)?;
65            request.headers_mut().insert(k, v);
66        }
67
68        // if we have auth_material, we need to authenticate
69        if let Some(auth_material) = auth_material {
70            // figure out the authenticate endpoint url
71            let mut auth_url =
72                url::Url::parse(&full_url).map_err(Error::other)?;
73            auth_url.set_path("/authenticate");
74            match auth_url.scheme() {
75                "ws" => {
76                    let _ = auth_url.set_scheme("http");
77                }
78                "wss" => {
79                    let _ = auth_url.set_scheme("https");
80                }
81                _ => (),
82            }
83
84            // request a token from the /authenticate endpoint
85            let token = tokio::task::spawn_blocking(move || {
86                ureq::put(auth_url.as_str())
87                    .send(&auth_material[..])
88                    .map_err(Error::other)?
89                    .into_body()
90                    .read_to_string()
91                    .map_err(Error::other)
92            })
93            .await??;
94
95            // parse out the token
96            #[derive(serde::Deserialize)]
97            #[serde(rename_all = "camelCase")]
98            struct Token {
99                auth_token: Arc<str>,
100            }
101
102            let token: Token =
103                serde_json::from_str(&token).map_err(Error::other)?;
104            let token = token.auth_token;
105
106            let token = if let Some(cb) = alter_token_cb {
107                // hook to allow token alterations
108                cb(token)
109            } else {
110                token
111            };
112
113            // finally add our token to the request headers
114            use tokio_tungstenite::tungstenite::http::header::*;
115            let v =
116                HeaderValue::from_bytes(format!("Bearer {token}").as_bytes())
117                    .map_err(Error::other)?;
118            request.headers_mut().insert("Authorization", v);
119        };
120
121        let scheme_ws = request.uri().scheme_str() == Some("ws");
122        let scheme_wss = request.uri().scheme_str() == Some("wss");
123
124        if !scheme_ws && !scheme_wss {
125            return Err(Error::other("scheme must be ws:// or wss://"));
126        }
127
128        if !allow_plain_text && scheme_ws {
129            return Err(Error::other("plain text scheme not allowed"));
130        }
131
132        let host = match request.uri().host() {
133            Some(host) => host.to_string(),
134            None => return Err(Error::other("invalid url")),
135        };
136        let port = request.uri().port_u16().unwrap_or({
137            if scheme_ws {
138                80
139            } else {
140                443
141            }
142        });
143
144        // open the tcp connection
145        let tcp =
146            tokio::net::TcpStream::connect(format!("{host}:{port}")).await?;
147
148        // optionally layer on TLS
149        let maybe_tls = if scheme_ws {
150            tokio_tungstenite::MaybeTlsStream::Plain(tcp)
151        } else {
152            let tls = priv_system_tls(danger_disable_certificate_check);
153            let name = host
154                .try_into()
155                .unwrap_or_else(|_| "sbd".try_into().unwrap());
156            let tls = tokio_rustls::TlsConnector::from(tls)
157                .connect(name, tcp)
158                .await?;
159
160            tokio_tungstenite::MaybeTlsStream::Rustls(tls)
161        };
162
163        // set some default websocket config
164        let config =
165            tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default(
166            )
167            .max_message_size(Some(max_message_size));
168
169        // establish the connection
170        let (ws, _res) = tokio_tungstenite::client_async_with_config(
171            request,
172            maybe_tls,
173            Some(config),
174        )
175        .await
176        .map_err(Error::other)?;
177
178        // split for parallel send and recv
179        let (send, recv) = futures::stream::StreamExt::split(ws);
180
181        Ok((WsRawSend { send }, WsRawRecv { recv }))
182    }
183}
184
185use tokio_tungstenite::tungstenite::protocol::Message;
186type MaybeTls = tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>;
187type Ws = tokio_tungstenite::WebSocketStream<MaybeTls>;
188type WsSend = futures::stream::SplitSink<Ws, Message>;
189type WsRecv = futures::stream::SplitStream<Ws>;
190
191/// The send half of the websocket connection.
192pub struct WsRawSend {
193    send: WsSend,
194}
195
196impl WsRawSend {
197    /// Send data over the websocket.
198    pub async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
199        use futures::sink::SinkExt;
200        self.send
201            .send(Message::binary(msg))
202            .await
203            .map_err(Error::other)?;
204        self.send.flush().await.map_err(Error::other)?;
205        Ok(())
206    }
207
208    /// Close the connection.
209    pub async fn close(&mut self) {
210        use futures::sink::SinkExt;
211        let _ = self.send.close().await;
212    }
213}
214
215/// The receive half of the websocket connection.
216pub struct WsRawRecv {
217    recv: WsRecv,
218}
219
220impl WsRawRecv {
221    /// Receive from the websocket.
222    pub async fn recv(&mut self) -> Result<Vec<u8>> {
223        use futures::stream::StreamExt;
224        use tokio_tungstenite::tungstenite::protocol::Message::*;
225        loop {
226            match self.recv.next().await {
227                None => return Err(Error::other("closed")),
228                Some(r) => {
229                    let msg = r.map_err(Error::other)?;
230                    match msg {
231                        // convert text into binary
232                        Text(s) => return Ok(s.as_bytes().to_vec()),
233                        // use binary directly
234                        Binary(v) => return Ok(v.to_vec()),
235                        // ignoring server ping/pong for now
236                        Ping(_) | Pong(_) => (),
237                        Close(_) => return Err(Error::other("closed")),
238                        // we are not configured to receive raw frames
239                        Frame(_) => unreachable!(),
240                    }
241                }
242            }
243        }
244    }
245}
246
247/// Process the standard sbd handshake from the client side.
248pub struct Handshake {
249    /// limit_byte_nanos.
250    pub limit_byte_nanos: i32,
251
252    /// limit_idle_millis.
253    pub limit_idle_millis: i32,
254
255    /// bytes sent.
256    pub bytes_sent: usize,
257}
258
259impl Handshake {
260    /// Process the standard sbd handshake from the client side.
261    pub async fn handshake<C: Crypto>(
262        send: &mut WsRawSend,
263        recv: &mut WsRawRecv,
264        crypto: &C,
265    ) -> Result<Self> {
266        let mut limit_byte_nanos = 8000;
267        let mut limit_idle_millis = 10_000;
268        let mut bytes_sent = 0;
269
270        loop {
271            match Msg(recv.recv().await?).parse()? {
272                MsgType::Msg { .. } => {
273                    // we are not authenticated yet, we should not get msgs
274                    return Err(Error::other("invalid handshake"));
275                }
276                // receive server rate limit
277                MsgType::LimitByteNanos(l) => limit_byte_nanos = l,
278                // receive server idle timeout
279                MsgType::LimitIdleMillis(l) => limit_idle_millis = l,
280                // process the authorization request
281                MsgType::AuthReq(nonce) => {
282                    let sig = crypto.sign(nonce)?;
283                    let mut auth_res = Vec::with_capacity(HDR_SIZE + SIG_SIZE);
284                    auth_res.extend_from_slice(CMD_PREFIX);
285                    auth_res.extend_from_slice(b"ares");
286                    auth_res.extend_from_slice(&sig);
287                    send.send(auth_res).await?;
288                    bytes_sent += HDR_SIZE + SIG_SIZE;
289                }
290                // hey! handshake is successful
291                MsgType::Ready => break,
292                MsgType::Unknown => (),
293            }
294        }
295
296        Ok(Self {
297            limit_byte_nanos,
298            limit_idle_millis,
299            bytes_sent,
300        })
301    }
302}
303
304fn priv_system_tls(
305    danger_disable_certificate_check: bool,
306) -> Arc<rustls::ClientConfig> {
307    let mut roots = rustls::RootCertStore::empty();
308
309    #[cfg(any(
310        feature = "force_webpki_roots",
311        not(any(
312            target_os = "windows",
313            target_os = "linux",
314            target_os = "macos",
315        )),
316    ))]
317    {
318        roots.roots = webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect();
319    }
320
321    #[cfg(all(
322        not(feature = "force_webpki_roots"),
323        any(target_os = "windows", target_os = "linux", target_os = "macos",),
324    ))]
325    roots.add_parsable_certificates(
326        rustls_native_certs::load_native_certs().certs,
327    );
328
329    if danger_disable_certificate_check {
330        let v = rustls::client::WebPkiServerVerifier::builder(Arc::new(roots))
331            .build()
332            .unwrap();
333
334        Arc::new(
335            rustls::ClientConfig::builder()
336                .dangerous()
337                .with_custom_certificate_verifier(Arc::new(V(v)))
338                .with_no_client_auth(),
339        )
340    } else {
341        Arc::new(
342            rustls::ClientConfig::builder()
343                .with_root_certificates(roots)
344                .with_no_client_auth(),
345        )
346    }
347}
348
349#[derive(Debug)]
350struct V(Arc<rustls::client::WebPkiServerVerifier>);
351
352impl rustls::client::danger::ServerCertVerifier for V {
353    fn verify_server_cert(
354        &self,
355        _end_entity: &rustls::pki_types::CertificateDer<'_>,
356        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
357        _server_name: &rustls::pki_types::ServerName<'_>,
358        _ocsp_response: &[u8],
359        _now: rustls::pki_types::UnixTime,
360    ) -> std::result::Result<
361        rustls::client::danger::ServerCertVerified,
362        rustls::Error,
363    > {
364        Ok(rustls::client::danger::ServerCertVerified::assertion())
365    }
366    fn verify_tls12_signature(
367        &self,
368        _message: &[u8],
369        _cert: &rustls::pki_types::CertificateDer<'_>,
370        _dss: &rustls::DigitallySignedStruct,
371    ) -> std::result::Result<
372        rustls::client::danger::HandshakeSignatureValid,
373        rustls::Error,
374    > {
375        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
376    }
377    fn verify_tls13_signature(
378        &self,
379        _message: &[u8],
380        _cert: &rustls::pki_types::CertificateDer<'_>,
381        _dss: &rustls::DigitallySignedStruct,
382    ) -> std::result::Result<
383        rustls::client::danger::HandshakeSignatureValid,
384        rustls::Error,
385    > {
386        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
387    }
388    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
389        self.0.supported_verify_schemes()
390    }
391}