1use crate::*;
4
/// Hook that receives the auth token obtained from the `/authenticate`
/// endpoint and returns the token that will actually be sent as the
/// `Authorization: Bearer …` header.
pub type AlterTokenCb =
    Arc<dyn Fn(Arc<str>) -> Arc<str> + Send + Sync + 'static>;
8
9pub struct WsRawConnect {
11 pub full_url: String,
13
14 pub max_message_size: usize,
17
18 pub allow_plain_text: bool,
20
21 pub danger_disable_certificate_check: bool,
25
26 pub headers: Vec<(String, String)>,
28
29 pub auth_material: Option<Vec<u8>>,
32
33 pub alter_token_cb: Option<AlterTokenCb>,
37}
38
39impl WsRawConnect {
40 pub async fn connect(self) -> Result<(WsRawSend, WsRawRecv)> {
42 let Self {
43 full_url,
44 max_message_size,
45 allow_plain_text,
46 danger_disable_certificate_check,
47 headers,
48 auth_material,
49 alter_token_cb,
50 } = self;
51
52 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
54 let mut request =
55 IntoClientRequest::into_client_request(full_url.clone())
56 .map_err(Error::other)?;
57
58 for (k, v) in headers {
60 use tokio_tungstenite::tungstenite::http::header::*;
61 let k =
62 HeaderName::from_bytes(k.as_bytes()).map_err(Error::other)?;
63 let v =
64 HeaderValue::from_bytes(v.as_bytes()).map_err(Error::other)?;
65 request.headers_mut().insert(k, v);
66 }
67
68 if let Some(auth_material) = auth_material {
70 let mut auth_url =
72 url::Url::parse(&full_url).map_err(Error::other)?;
73 auth_url.set_path("/authenticate");
74 match auth_url.scheme() {
75 "ws" => {
76 let _ = auth_url.set_scheme("http");
77 }
78 "wss" => {
79 let _ = auth_url.set_scheme("https");
80 }
81 _ => (),
82 }
83
84 let token = tokio::task::spawn_blocking(move || {
86 ureq::put(auth_url.as_str())
87 .send(&auth_material[..])
88 .map_err(Error::other)?
89 .into_body()
90 .read_to_string()
91 .map_err(Error::other)
92 })
93 .await??;
94
95 #[derive(serde::Deserialize)]
97 #[serde(rename_all = "camelCase")]
98 struct Token {
99 auth_token: Arc<str>,
100 }
101
102 let token: Token =
103 serde_json::from_str(&token).map_err(Error::other)?;
104 let token = token.auth_token;
105
106 let token = if let Some(cb) = alter_token_cb {
107 cb(token)
109 } else {
110 token
111 };
112
113 use tokio_tungstenite::tungstenite::http::header::*;
115 let v =
116 HeaderValue::from_bytes(format!("Bearer {token}").as_bytes())
117 .map_err(Error::other)?;
118 request.headers_mut().insert("Authorization", v);
119 };
120
121 let scheme_ws = request.uri().scheme_str() == Some("ws");
122 let scheme_wss = request.uri().scheme_str() == Some("wss");
123
124 if !scheme_ws && !scheme_wss {
125 return Err(Error::other("scheme must be ws:// or wss://"));
126 }
127
128 if !allow_plain_text && scheme_ws {
129 return Err(Error::other("plain text scheme not allowed"));
130 }
131
132 let host = match request.uri().host() {
133 Some(host) => host.to_string(),
134 None => return Err(Error::other("invalid url")),
135 };
136 let port = request.uri().port_u16().unwrap_or({
137 if scheme_ws {
138 80
139 } else {
140 443
141 }
142 });
143
144 let tcp =
146 tokio::net::TcpStream::connect(format!("{host}:{port}")).await?;
147
148 let maybe_tls = if scheme_ws {
150 tokio_tungstenite::MaybeTlsStream::Plain(tcp)
151 } else {
152 let tls = priv_system_tls(danger_disable_certificate_check);
153 let name = host
154 .try_into()
155 .unwrap_or_else(|_| "sbd".try_into().unwrap());
156 let tls = tokio_rustls::TlsConnector::from(tls)
157 .connect(name, tcp)
158 .await?;
159
160 tokio_tungstenite::MaybeTlsStream::Rustls(tls)
161 };
162
163 let config =
165 tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default(
166 )
167 .max_message_size(Some(max_message_size));
168
169 let (ws, _res) = tokio_tungstenite::client_async_with_config(
171 request,
172 maybe_tls,
173 Some(config),
174 )
175 .await
176 .map_err(Error::other)?;
177
178 let (send, recv) = futures::stream::StreamExt::split(ws);
180
181 Ok((WsRawSend { send }, WsRawRecv { recv }))
182 }
183}
184
185use tokio_tungstenite::tungstenite::protocol::Message;
186type MaybeTls = tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>;
187type Ws = tokio_tungstenite::WebSocketStream<MaybeTls>;
188type WsSend = futures::stream::SplitSink<Ws, Message>;
189type WsRecv = futures::stream::SplitStream<Ws>;
190
191pub struct WsRawSend {
193 send: WsSend,
194}
195
196impl WsRawSend {
197 pub async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
199 use futures::sink::SinkExt;
200 self.send
201 .send(Message::binary(msg))
202 .await
203 .map_err(Error::other)?;
204 self.send.flush().await.map_err(Error::other)?;
205 Ok(())
206 }
207
208 pub async fn close(&mut self) {
210 use futures::sink::SinkExt;
211 let _ = self.send.close().await;
212 }
213}
214
215pub struct WsRawRecv {
217 recv: WsRecv,
218}
219
220impl WsRawRecv {
221 pub async fn recv(&mut self) -> Result<Vec<u8>> {
223 use futures::stream::StreamExt;
224 use tokio_tungstenite::tungstenite::protocol::Message::*;
225 loop {
226 match self.recv.next().await {
227 None => return Err(Error::other("closed")),
228 Some(r) => {
229 let msg = r.map_err(Error::other)?;
230 match msg {
231 Text(s) => return Ok(s.as_bytes().to_vec()),
233 Binary(v) => return Ok(v.to_vec()),
235 Ping(_) | Pong(_) => (),
237 Close(_) => return Err(Error::other("closed")),
238 Frame(_) => unreachable!(),
240 }
241 }
242 }
243 }
244 }
245}
246
/// Result of a completed [`Handshake::handshake`].
pub struct Handshake {
    /// Most recent `LimitByteNanos` value received (8000 if none arrived).
    pub limit_byte_nanos: i32,

    /// Most recent `LimitIdleMillis` value received (10_000 if none arrived).
    pub limit_idle_millis: i32,

    /// Bytes counted as sent while answering auth requests.
    pub bytes_sent: usize,
}
258
259impl Handshake {
260 pub async fn handshake<C: Crypto>(
262 send: &mut WsRawSend,
263 recv: &mut WsRawRecv,
264 crypto: &C,
265 ) -> Result<Self> {
266 let mut limit_byte_nanos = 8000;
267 let mut limit_idle_millis = 10_000;
268 let mut bytes_sent = 0;
269
270 loop {
271 match Msg(recv.recv().await?).parse()? {
272 MsgType::Msg { .. } => {
273 return Err(Error::other("invalid handshake"));
275 }
276 MsgType::LimitByteNanos(l) => limit_byte_nanos = l,
278 MsgType::LimitIdleMillis(l) => limit_idle_millis = l,
280 MsgType::AuthReq(nonce) => {
282 let sig = crypto.sign(nonce)?;
283 let mut auth_res = Vec::with_capacity(HDR_SIZE + SIG_SIZE);
284 auth_res.extend_from_slice(CMD_PREFIX);
285 auth_res.extend_from_slice(b"ares");
286 auth_res.extend_from_slice(&sig);
287 send.send(auth_res).await?;
288 bytes_sent += HDR_SIZE + SIG_SIZE;
289 }
290 MsgType::Ready => break,
292 MsgType::Unknown => (),
293 }
294 }
295
296 Ok(Self {
297 limit_byte_nanos,
298 limit_idle_millis,
299 bytes_sent,
300 })
301 }
302}
303
304fn priv_system_tls(
305 danger_disable_certificate_check: bool,
306) -> Arc<rustls::ClientConfig> {
307 let mut roots = rustls::RootCertStore::empty();
308
309 #[cfg(any(
310 feature = "force_webpki_roots",
311 not(any(
312 target_os = "windows",
313 target_os = "linux",
314 target_os = "macos",
315 )),
316 ))]
317 {
318 roots.roots = webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect();
319 }
320
321 #[cfg(all(
322 not(feature = "force_webpki_roots"),
323 any(target_os = "windows", target_os = "linux", target_os = "macos",),
324 ))]
325 roots.add_parsable_certificates(
326 rustls_native_certs::load_native_certs().certs,
327 );
328
329 if danger_disable_certificate_check {
330 let v = rustls::client::WebPkiServerVerifier::builder(Arc::new(roots))
331 .build()
332 .unwrap();
333
334 Arc::new(
335 rustls::ClientConfig::builder()
336 .dangerous()
337 .with_custom_certificate_verifier(Arc::new(V(v)))
338 .with_no_client_auth(),
339 )
340 } else {
341 Arc::new(
342 rustls::ClientConfig::builder()
343 .with_root_certificates(roots)
344 .with_no_client_auth(),
345 )
346 }
347}
348
349#[derive(Debug)]
350struct V(Arc<rustls::client::WebPkiServerVerifier>);
351
352impl rustls::client::danger::ServerCertVerifier for V {
353 fn verify_server_cert(
354 &self,
355 _end_entity: &rustls::pki_types::CertificateDer<'_>,
356 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
357 _server_name: &rustls::pki_types::ServerName<'_>,
358 _ocsp_response: &[u8],
359 _now: rustls::pki_types::UnixTime,
360 ) -> std::result::Result<
361 rustls::client::danger::ServerCertVerified,
362 rustls::Error,
363 > {
364 Ok(rustls::client::danger::ServerCertVerified::assertion())
365 }
366 fn verify_tls12_signature(
367 &self,
368 _message: &[u8],
369 _cert: &rustls::pki_types::CertificateDer<'_>,
370 _dss: &rustls::DigitallySignedStruct,
371 ) -> std::result::Result<
372 rustls::client::danger::HandshakeSignatureValid,
373 rustls::Error,
374 > {
375 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
376 }
377 fn verify_tls13_signature(
378 &self,
379 _message: &[u8],
380 _cert: &rustls::pki_types::CertificateDer<'_>,
381 _dss: &rustls::DigitallySignedStruct,
382 ) -> std::result::Result<
383 rustls::client::danger::HandshakeSignatureValid,
384 rustls::Error,
385 > {
386 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
387 }
388 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
389 self.0.supported_verify_schemes()
390 }
391}