v_exchanges_api_generics/
ws.rs1use std::{
2 collections::HashSet,
3 time::{Duration, SystemTime},
4 vec,
5};
6
7use eyre::{Result, bail};
8use futures_util::{SinkExt as _, StreamExt as _};
9use jiff::Timestamp;
10use reqwest::Url;
11use tokio::net::TcpStream;
12use tokio_tungstenite::{
13 MaybeTlsStream, WebSocketStream,
14 tungstenite::{self, Bytes},
15};
16use tracing::instrument;
17
18use crate::{AuthError, UrlError};
19
20type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
21
22pub trait WsHandler: std::fmt::Debug {
24 fn config(&self) -> Result<WsConfig, UrlError> {
26 Ok(WsConfig::default())
27 }
28
29 #[allow(unused_variables)]
34 fn handle_auth(&mut self) -> Result<Vec<tungstenite::Message>, WsError> {
35 Ok(vec![])
36 }
37
38 #[allow(unused_variables)]
60 fn handle_subscribe(&mut self, topics: HashSet<Topic>) -> Result<Vec<tungstenite::Message>, WsError>;
61
62 #[allow(unused_variables)]
64 fn handle_jrpc(&mut self, jrpc: serde_json::Value) -> Result<ResponseOrContent, WsError>;
65 }
77
78#[derive(Clone, Debug)]
79pub enum ResponseOrContent {
80 Response(Vec<tungstenite::Message>),
82 Content(ContentEvent),
84}
85#[derive(Clone, Debug)]
86pub struct ContentEvent {
87 pub data: serde_json::Value,
88 pub topic: String,
89 pub time: Timestamp,
90 pub event_type: String,
91}
92
93#[derive(Clone, Debug, Eq)]
94pub struct TopicInterpreter<T> {
95 pub event_name: String,
97 pub interpret: fn(&serde_json::Value) -> Result<T, WsError>,
99}
100impl<T> std::hash::Hash for TopicInterpreter<T> {
101 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
102 self.event_name.hash(state);
103 }
104}
105impl<T> PartialEq for TopicInterpreter<T> {
106 fn eq(&self, other: &Self) -> bool {
107 self.event_name == other.event_name
108 }
109}
110
111#[derive(Debug)]
113pub struct WsConnection<H: WsHandler> {
114 url: Url,
115 config: WsConfig,
116 handler: H,
117 stream: Option<WsConnectionStream>,
118 last_reconnect_attempt: SystemTime, }
120#[derive(Debug, derive_more::Deref, derive_more::DerefMut)]
121struct WsConnectionStream {
122 #[deref_mut]
123 #[deref]
124 stream: WsStream,
125 connected_since: SystemTime,
126 last_unanswered_communication: Option<SystemTime>,
127}
128impl WsConnectionStream {
129 fn new(stream: WsStream, connected_since: SystemTime) -> Self {
130 Self {
131 stream,
132 connected_since,
133 last_unanswered_communication: None,
134 }
135 }
136}
137impl<H: WsHandler> WsConnection<H> {
138 #[allow(missing_docs)]
139 pub fn try_new(url_suffix: &str, handler: H) -> Result<Self, UrlError> {
140 let config = handler.config()?;
141 let url = match &config.base_url {
142 Some(base_url) => base_url.join(url_suffix)?,
143 None => Url::parse(url_suffix)?,
144 };
145
146 Ok(Self {
147 url,
148 config,
149 handler,
150 stream: None,
151 last_reconnect_attempt: SystemTime::UNIX_EPOCH,
152 })
153 }
154
155 pub async fn next(&mut self) -> Result<ContentEvent, WsError> {
157 if let Some(inner) = &self.stream
158 && inner.connected_since + self.config.refresh_after < SystemTime::now()
159 {
160 tracing::info!("Refreshing connection, as `refresh_after` specified in WsConfig has elapsed ({:?})", self.config.refresh_after);
161 self.reconnect().await?;
162 }
163 if self.stream.is_none() {
164 self.connect().await?;
165 }
166 let json_rpc_value = loop {
170 let resp = {
172 let timeout = match self.stream.as_ref() {
173 Some(stream) => match stream.last_unanswered_communication {
174 Some(last_unanswered) => {
175 let now = SystemTime::now();
176 match last_unanswered + self.config.response_timeout > now {
177 true => self.config.response_timeout,
178 false => {
179 tracing::error!(
180 "Timeout for last unanswered communication ended before `.next()` was called. This likely indicates an implementation error on the clientside."
181 );
182 self.reconnect().await?;
183 continue;
184 }
185 }
186 }
187 None => self.config.message_timeout,
188 },
189 None => {
190 tracing::error!(
191 "UNEXPECTED: Stream is None at ws.rs:172 despite guard at line 163. \
192 Possible causes: (1) system hibernation/sleep caused stale state, \
193 (2) memory corruption, (3) logic bug in reconnection flow, \
194 (4) async cancellation. \
195 Last reconnect attempt: {:?} ago. Attempting to reconnect...",
196 SystemTime::now().duration_since(self.last_reconnect_attempt).unwrap_or_default()
197 );
198 self.connect().await?;
199 continue;
200 }
201 };
202
203 let timeout_handle = tokio::time::timeout(timeout, {
204 let stream = self.stream.as_mut().unwrap();
205 stream.next()
206 });
207 match timeout_handle.await {
208 Ok(Some(resp)) => {
209 self.stream.as_mut().unwrap().last_unanswered_communication = None;
210 resp
211 }
212 Ok(None) => {
213 tracing::warn!("tungstenite couldn't read from the stream. Restarting.");
214 self.reconnect().await?;
215 continue;
216 }
217 Err(timeout_error) => {
218 tracing::warn!("Message reception timed out after {:?} seconds. // {timeout_error}", timeout);
219 {
220 let stream = self.stream.as_mut().unwrap();
221 match stream.last_unanswered_communication.is_some() {
222 true => self.reconnect().await?,
223 false => {
224 self.send(tungstenite::Message::Ping(Bytes::default())).await?;
226 continue;
227 }
228 }
229 }
230 continue;
231 }
232 }
233 };
234
235 match resp {
237 Ok(succ_resp) => match succ_resp {
238 tungstenite::Message::Text(text) => {
239 let value: serde_json::Value =
240 serde_json::from_str(&text).expect("API sent invalid JSON, which is completely unexpected. Disappointment is immeasurable and the day is ruined.");
241 tracing::trace!("{value:#?}"); break match { self.handler.handle_jrpc(value)? } {
243 ResponseOrContent::Response(messages) => {
244 self.send_all(messages).await?;
245 continue; }
247 ResponseOrContent::Content(content) => content,
248 };
249 }
250 tungstenite::Message::Binary(_) => {
251 panic!("Received binary. But exchanges are not smart enough to send this, what is happening");
252 }
253 tungstenite::Message::Ping(bytes) => {
254 self.send(tungstenite::Message::Pong(bytes)).await?; tracing::debug!("ponged");
256 continue;
257 }
258 tungstenite::Message::Pong(_) => {
260 tracing::info!("Received pong");
261 continue;
262 }
263 tungstenite::Message::Close(maybe_reason) => {
264 match maybe_reason {
265 Some(close_frame) => {
266 tracing::info!("Server closed connection; reason: {close_frame:?}");
268 }
269 None => {
270 tracing::info!("Server closed connection; no reason specified.");
271 }
272 }
273 self.stream = None;
274 self.reconnect().await?;
275 continue;
276 }
277 tungstenite::Message::Frame(_) => {
278 unreachable!("Can't get from reading");
279 }
280 },
281 Err(err) => match err {
282 tungstenite::Error::ConnectionClosed => {
283 tracing::error!("received `tungstenite::Error::ConnectionClosed` on polling. Will reconnect");
284 self.stream = None;
285 continue;
286 }
287 tungstenite::Error::AlreadyClosed => {
288 tracing::error!("received `tungstenite::Error::AlreadyClosed` from polling. Will reconnect");
289 self.stream = None;
290 continue;
291 }
292 tungstenite::Error::Io(e) => {
293 tracing::warn!("received `tungstenite::Error::Io` from polling: {e:?}. Likely indicates connection issues. Skipping.");
294 continue;
295 }
296 tungstenite::Error::Tls(_tls_error) => todo!(),
297 tungstenite::Error::Capacity(capacity_error) => {
298 tracing::warn!("received `tungstenite::Error::Capacity` from polling: {capacity_error:?}. Skipping.");
299 continue;
300 }
301 tungstenite::Error::Protocol(protocol_error) => {
302 tracing::warn!("received `tungstenite::Error::Protocol` from polling: {protocol_error:?}. Will reconnect");
303 self.stream = None;
304 continue;
305 }
306 tungstenite::Error::WriteBufferFull(_) => unreachable!("can only get from writing"),
307 tungstenite::Error::Utf8(e) => panic!("received `tungstenite::Error::Utf8` from polling: {e:?}. Exchange is going crazy, aborting"),
308 tungstenite::Error::AttackAttempt => {
309 tracing::warn!("received `tungstenite::Error::AttackAttempt` from polling. Don't have a reason to trust detection 100%, so just reconnecting.");
310 self.stream = None;
311 continue;
312 }
313 tungstenite::Error::Url(_url_error) => todo!(),
314 tungstenite::Error::Http(_response) => todo!(),
315 tungstenite::Error::HttpFormat(_error) => todo!(),
316 },
317 }
318 };
319 Ok(json_rpc_value)
320 }
321
322 #[instrument(skip_all)]
323 async fn send_all(&mut self, messages: Vec<tungstenite::Message>) -> Result<(), tungstenite::Error> {
324 if let Some(inner) = &mut self.stream {
325 match messages.len() {
326 0 => return Ok(()),
327 1 => {
328 tracing::debug!("sending to server: {:#?}", &messages[0]);
329 inner.send(messages.into_iter().next().unwrap()).await?;
330 inner.last_unanswered_communication = Some(SystemTime::now());
331 }
332 _ => {
333 tracing::debug!("sending to server: {messages:#?}");
334 let mut message_stream = futures_util::stream::iter(messages).map(Ok);
335 inner.send_all(&mut message_stream).await?;
336 inner.last_unanswered_communication = Some(SystemTime::now());
337 }
338 };
339 Ok(())
340 } else {
341 Err(tungstenite::Error::ConnectionClosed)
342 }
343 }
344
345 async fn send(&mut self, message: tungstenite::Message) -> Result<(), tungstenite::Error> {
346 self.send_all(vec![message]).await }
348
349 async fn connect(&mut self) -> Result<(), WsError> {
350 tracing::info!("Connecting to {}...", self.url);
351 {
352 let now = SystemTime::now();
353 let timeout = self.config.reconnect_cooldown;
354 if self.last_reconnect_attempt + timeout > now {
355 tracing::warn!("Reconnect cooldown is triggered. Likely indicative of a bad connection.");
356 let duration = (self.last_reconnect_attempt + timeout).duration_since(now).unwrap();
357 tokio::time::sleep(duration).await;
358 }
359 }
360 self.last_reconnect_attempt = SystemTime::now();
361
362 let (stream, http_resp) = tokio_tungstenite::connect_async(self.url.as_str()).await?;
363 tracing::debug!("Ws handshake with server: {http_resp:#?}");
364
365 let now = SystemTime::now();
366 self.stream = Some(WsConnectionStream::new(stream, now));
367
368 let auth_messages = self.handler.handle_auth()?;
369 Ok(self.send_all(auth_messages).await?)
370 }
371
372 pub async fn reconnect(&mut self) -> Result<(), WsError> {
376 if let Some(stream) = self.stream.as_mut() {
377 tracing::info!("Dropping old connection before reconnecting...");
378 if let Err(e) = stream.send(tungstenite::Message::Close(None)).await {
380 tracing::debug!("Failed to send Close frame (connection likely already dead): {e}");
381 }
382 self.stream = None;
383 }
384 self.connect().await
385 }
386}
387
388#[derive(Clone, Debug, Default, Eq, PartialEq)]
392pub struct WsConfig {
393 pub auth: bool,
395 pub base_url: Option<Url>,
399 reconnect_cooldown: Duration = Duration::from_secs(3),
404 refresh_after: Duration = Duration::from_hours(12),
406 message_timeout: Duration = Duration::from_mins(16), response_timeout: Duration = Duration::from_mins(2),
412 pub topics: HashSet<String>,
414}
415
416impl WsConfig {
417 pub fn set_reconnect_cooldown(&mut self, reconnect_cooldown: Duration) -> Result<()> {
418 if reconnect_cooldown.is_zero() {
419 bail!("connect_cooldown must be greater than 0");
420 }
421 self.reconnect_cooldown = reconnect_cooldown;
422 Ok(())
423 }
424
425 pub fn set_refresh_after(&mut self, refresh_after: Duration) -> Result<()> {
426 if refresh_after.is_zero() {
427 bail!("refresh_after must be greater than 0");
428 }
429 self.refresh_after = refresh_after;
430 Ok(())
431 }
432
433 pub fn set_message_timeout(&mut self, message_timeout: Duration) -> Result<()> {
434 if message_timeout.is_zero() {
435 bail!("message_timeout must be greater than 0");
436 }
437 self.message_timeout = message_timeout;
438 Ok(())
439 }
440
441 pub fn set_response_timout(&mut self, response_timeout: Duration) -> Result<()> {
442 if response_timeout.is_zero() {
443 bail!("response_timeout must be greater than 0");
444 }
445 self.response_timeout = response_timeout;
446 Ok(())
447 }
448}
449
450#[derive(Debug, derive_more::Display, thiserror::Error, derive_more::From)]
451pub enum WsError {
452 Definition(WsDefinitionError),
453 Tungstenite(tungstenite::Error),
454 Auth(AuthError),
455 Parse(serde_json::Error),
456 Subscription(String),
457 NetworkConnection,
458 Url(UrlError),
459 UnexpectedEvent(serde_json::Value),
460 Other(eyre::Report),
461}
462#[derive(Debug, derive_more::Display, thiserror::Error)]
463pub enum WsDefinitionError {
464 MissingUrl,
465}
466
467#[derive(Clone, Debug, derive_more::Display, Eq, Hash, PartialEq, serde::Serialize)]
482pub enum Topic {
483 String(String),
484 Order(serde_json::Value),
485}