v_exchanges_api_generics/
ws.rs1use std::{
2 collections::HashSet,
3 time::{Duration, SystemTime},
4 vec,
5};
6
7use eyre::{Result, bail};
8use futures_util::{SinkExt as _, StreamExt as _};
9use jiff::Timestamp;
10use reqwest::Url;
11use tokio::net::TcpStream;
12use tokio_tungstenite::{
13 MaybeTlsStream, WebSocketStream,
14 tungstenite::{self, Bytes},
15};
16use tracing::instrument;
17
18use crate::{AuthError, UrlError};
19
20type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
21
22pub trait WsHandler: std::fmt::Debug {
24 fn config(&self) -> Result<WsConfig, UrlError> {
26 Ok(WsConfig::default())
27 }
28
29 #[allow(unused_variables)]
34 fn handle_auth(&mut self) -> Result<Vec<tungstenite::Message>, WsError> {
35 Ok(vec![])
36 }
37
38 #[allow(unused_variables)]
60 fn handle_subscribe(&mut self, topics: HashSet<Topic>) -> Result<Vec<tungstenite::Message>, WsError>;
61
62 #[allow(unused_variables)]
64 fn handle_jrpc(&mut self, jrpc: serde_json::Value) -> Result<ResponseOrContent, WsError>;
65 }
77
78#[derive(Clone, Debug)]
79pub enum ResponseOrContent {
80 Response(Vec<tungstenite::Message>),
82 Content(ContentEvent),
84}
85#[derive(Clone, Debug)]
86pub struct ContentEvent {
87 pub data: serde_json::Value,
88 pub topic: String,
89 pub time: Timestamp,
90 pub event_type: String,
91}
92
93#[derive(Clone, Debug, Eq)]
94pub struct TopicInterpreter<T> {
95 pub event_name: String,
97 pub interpret: fn(&serde_json::Value) -> Result<T, WsError>,
99}
100impl<T> std::hash::Hash for TopicInterpreter<T> {
101 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
102 self.event_name.hash(state);
103 }
104}
105impl<T> PartialEq for TopicInterpreter<T> {
106 fn eq(&self, other: &Self) -> bool {
107 self.event_name == other.event_name
108 }
109}
110
111#[derive(Debug)]
113pub struct WsConnection<H: WsHandler> {
114 url: Url,
115 config: WsConfig,
116 handler: H,
117 stream: Option<WsConnectionStream>,
118 last_reconnect_attempt: SystemTime, }
120#[derive(Debug, derive_more::Deref, derive_more::DerefMut)]
121struct WsConnectionStream {
122 #[deref_mut]
123 #[deref]
124 stream: WsStream,
125 connected_since: SystemTime,
126 last_unanswered_communication: Option<SystemTime>,
127}
128impl WsConnectionStream {
129 fn new(stream: WsStream, connected_since: SystemTime) -> Self {
130 Self {
131 stream,
132 connected_since,
133 last_unanswered_communication: None,
134 }
135 }
136}
137impl<H: WsHandler> WsConnection<H> {
138 #[allow(missing_docs)]
139 pub fn try_new(url_suffix: &str, handler: H) -> Result<Self, UrlError> {
140 let config = handler.config()?;
141 let url = match &config.base_url {
142 Some(base_url) => base_url.join(url_suffix)?,
143 None => Url::parse(url_suffix)?,
144 };
145
146 Ok(Self {
147 url,
148 config,
149 handler,
150 stream: None,
151 last_reconnect_attempt: SystemTime::UNIX_EPOCH,
152 })
153 }
154
155 pub async fn next(&mut self) -> Result<ContentEvent, WsError> {
157 if let Some(inner) = &self.stream
158 && inner.connected_since + self.config.refresh_after < SystemTime::now()
159 {
160 tracing::info!("Refreshing connection, as `refresh_after` specified in WsConfig has elapsed ({:?})", self.config.refresh_after);
161 self.reconnect().await?;
162 }
163 if self.stream.is_none() {
164 self.connect().await?;
165 }
166 let json_rpc_value = loop {
170 let resp = {
172 let timeout = match self.stream.as_ref().unwrap().last_unanswered_communication {
173 Some(last_unanswered) => {
174 let now = SystemTime::now();
175 match last_unanswered + self.config.response_timeout > now {
176 true => self.config.response_timeout,
177 false => {
178 tracing::error!(
179 "Timeout for last unanswered communication ended before `.next()` was called. This likely indicates an implementation error on the clientside."
180 );
181 self.reconnect().await?;
182 continue;
183 }
184 }
185 }
186 None => self.config.message_timeout,
187 };
188
189 let timeout_handle = tokio::time::timeout(timeout, {
190 let stream = self.stream.as_mut().unwrap();
191 stream.next()
192 });
193 match timeout_handle.await {
194 Ok(Some(resp)) => {
195 self.stream.as_mut().unwrap().last_unanswered_communication = None;
196 resp
197 }
198 Ok(None) => {
199 tracing::warn!("tungstenite couldn't read from the stream. Restarting.");
200 self.reconnect().await?;
201 continue;
202 }
203 Err(timeout_error) => {
204 tracing::warn!("Message reception timed out after {:?} seconds. // {timeout_error}", timeout);
205 {
206 let stream = self.stream.as_mut().unwrap();
207 match stream.last_unanswered_communication.is_some() {
208 true => self.reconnect().await?,
209 false => {
210 self.send(tungstenite::Message::Ping(Bytes::default())).await?;
212 continue;
213 }
214 }
215 }
216 continue;
217 }
218 }
219 };
220
221 match resp {
223 Ok(succ_resp) => match succ_resp {
224 tungstenite::Message::Text(text) => {
225 let value: serde_json::Value =
226 serde_json::from_str(&text).expect("API sent invalid JSON, which is completely unexpected. Disappointment is immeasurable and the day is ruined.");
227 tracing::trace!("{value:#?}"); break match { self.handler.handle_jrpc(value)? } {
229 ResponseOrContent::Response(messages) => {
230 self.send_all(messages).await?;
231 continue; }
233 ResponseOrContent::Content(content) => content,
234 };
235 }
236 tungstenite::Message::Binary(_) => {
237 panic!("Received binary. But exchanges are not smart enough to send this, what is happening");
238 }
239 tungstenite::Message::Ping(bytes) => {
240 self.send(tungstenite::Message::Pong(bytes)).await?; tracing::debug!("ponged");
242 continue;
243 }
244 tungstenite::Message::Pong(_) => {
246 tracing::info!("Received pong");
247 continue;
248 }
249 tungstenite::Message::Close(maybe_reason) => {
250 match maybe_reason {
251 Some(close_frame) => {
252 tracing::info!("Server closed connection; reason: {close_frame:?}");
254 }
255 None => {
256 tracing::info!("Server closed connection; no reason specified.");
257 }
258 }
259 self.stream = None;
260 self.reconnect().await?;
261 continue;
262 }
263 tungstenite::Message::Frame(_) => {
264 unreachable!("Can't get from reading");
265 }
266 },
267 Err(err) => match err {
268 tungstenite::Error::ConnectionClosed => {
269 tracing::error!("received `tungstenite::Error::ConnectionClosed` on polling. Will reconnect");
270 self.stream = None;
271 continue;
272 }
273 tungstenite::Error::AlreadyClosed => {
274 tracing::error!("received `tungstenite::Error::AlreadyClosed` from polling. Will reconnect");
275 self.stream = None;
276 continue;
277 }
278 tungstenite::Error::Io(e) => {
279 tracing::warn!("received `tungstenite::Error::Io` from polling: {e:?}. Likely indicates connection issues. Skipping.");
280 continue;
281 }
282 tungstenite::Error::Tls(_tls_error) => todo!(),
283 tungstenite::Error::Capacity(capacity_error) => {
284 tracing::warn!("received `tungstenite::Error::Capacity` from polling: {capacity_error:?}. Skipping.");
285 continue;
286 }
287 tungstenite::Error::Protocol(protocol_error) => {
288 tracing::warn!("received `tungstenite::Error::Protocol` from polling: {protocol_error:?}. Will reconnect");
289 self.stream = None;
290 continue;
291 }
292 tungstenite::Error::WriteBufferFull(_) => unreachable!("can only get from writing"),
293 tungstenite::Error::Utf8(e) => panic!("received `tungstenite::Error::Utf8` from polling: {e:?}. Exchange is going crazy, aborting"),
294 tungstenite::Error::AttackAttempt => {
295 tracing::warn!("received `tungstenite::Error::AttackAttempt` from polling. Don't have a reason to trust detection 100%, so just reconnecting.");
296 self.stream = None;
297 continue;
298 }
299 tungstenite::Error::Url(_url_error) => todo!(),
300 tungstenite::Error::Http(_response) => todo!(),
301 tungstenite::Error::HttpFormat(_error) => todo!(),
302 },
303 }
304 };
305 Ok(json_rpc_value)
306 }
307
308 #[instrument(skip_all)]
309 async fn send_all(&mut self, messages: Vec<tungstenite::Message>) -> Result<(), tungstenite::Error> {
310 if let Some(inner) = &mut self.stream {
311 match messages.len() {
312 0 => return Ok(()),
313 1 => {
314 tracing::debug!("sending to server: {:#?}", &messages[0]);
315 inner.send(messages.into_iter().next().unwrap()).await?;
316 inner.last_unanswered_communication = Some(SystemTime::now());
317 }
318 _ => {
319 tracing::debug!("sending to server: {messages:#?}");
320 let mut message_stream = futures_util::stream::iter(messages).map(Ok);
321 inner.send_all(&mut message_stream).await?;
322 inner.last_unanswered_communication = Some(SystemTime::now());
323 }
324 };
325 Ok(())
326 } else {
327 Err(tungstenite::Error::ConnectionClosed)
328 }
329 }
330
331 async fn send(&mut self, message: tungstenite::Message) -> Result<(), tungstenite::Error> {
332 self.send_all(vec![message]).await }
334
335 async fn connect(&mut self) -> Result<(), WsError> {
336 tracing::info!("Connecting to {}...", self.url);
337 {
338 let now = SystemTime::now();
339 let timeout = self.config.reconnect_cooldown;
340 if self.last_reconnect_attempt + timeout > now {
341 tracing::warn!("Reconnect cooldown is triggered. Likely indicative of a bad connection.");
342 let duration = (self.last_reconnect_attempt + timeout).duration_since(now).unwrap();
343 tokio::time::sleep(duration).await;
344 }
345 }
346 self.last_reconnect_attempt = SystemTime::now();
347
348 let (stream, http_resp) = tokio_tungstenite::connect_async(self.url.as_str()).await?;
349 tracing::debug!("Ws handshake with server: {http_resp:#?}");
350
351 let now = SystemTime::now();
352 self.stream = Some(WsConnectionStream::new(stream, now));
353
354 let auth_messages = self.handler.handle_auth()?;
355 Ok(self.send_all(auth_messages).await?)
356 }
357
358 pub async fn reconnect(&mut self) -> Result<(), WsError> {
362 if self.stream.is_some() {
363 tracing::info!("Dropping old connection before reconnecting...");
364 {
365 let stream = self.stream.as_mut().unwrap();
366 stream.send(tungstenite::Message::Close(None)).await?;
367 self.stream = None;
368 }
369 }
370 self.connect().await
371 }
372}
373
374#[derive(Clone, Debug, Default, Eq, PartialEq)]
378pub struct WsConfig {
379 pub auth: bool,
381 pub base_url: Option<Url>,
385 reconnect_cooldown: Duration = Duration::from_secs(3),
390 refresh_after: Duration = Duration::from_hours(12),
392 message_timeout: Duration = Duration::from_mins(16), response_timeout: Duration = Duration::from_mins(2),
398 pub topics: HashSet<String>,
400}
401
402impl WsConfig {
403 pub fn set_reconnect_cooldown(&mut self, reconnect_cooldown: Duration) -> Result<()> {
404 if reconnect_cooldown.is_zero() {
405 bail!("connect_cooldown must be greater than 0");
406 }
407 self.reconnect_cooldown = reconnect_cooldown;
408 Ok(())
409 }
410
411 pub fn set_refresh_after(&mut self, refresh_after: Duration) -> Result<()> {
412 if refresh_after.is_zero() {
413 bail!("refresh_after must be greater than 0");
414 }
415 self.refresh_after = refresh_after;
416 Ok(())
417 }
418
419 pub fn set_message_timeout(&mut self, message_timeout: Duration) -> Result<()> {
420 if message_timeout.is_zero() {
421 bail!("message_timeout must be greater than 0");
422 }
423 self.message_timeout = message_timeout;
424 Ok(())
425 }
426
427 pub fn set_response_timout(&mut self, response_timeout: Duration) -> Result<()> {
428 if response_timeout.is_zero() {
429 bail!("response_timeout must be greater than 0");
430 }
431 self.response_timeout = response_timeout;
432 Ok(())
433 }
434}
435
436#[derive(Debug, derive_more::Display, thiserror::Error, derive_more::From)]
437pub enum WsError {
438 Definition(WsDefinitionError),
439 Tungstenite(tungstenite::Error),
440 Auth(AuthError),
441 Parse(serde_json::Error),
442 Subscription(String),
443 NetworkConnection,
444 Url(UrlError),
445 UnexpectedEvent(serde_json::Value),
446 Other(eyre::Report),
447}
448#[derive(Debug, derive_more::Display, thiserror::Error)]
449pub enum WsDefinitionError {
450 MissingUrl,
451}
452
453#[derive(Clone, Debug, derive_more::Display, Eq, Hash, PartialEq, serde::Serialize)]
468pub enum Topic {
469 String(String),
470 Order(serde_json::Value),
471}