v_exchanges_api_generics/
ws.rs

1use std::{
2	collections::HashSet,
3	time::{Duration, SystemTime},
4	vec,
5};
6
7use eyre::{Result, bail};
8use futures_util::{SinkExt as _, StreamExt as _};
9use jiff::Timestamp;
10use reqwest::Url;
11use tokio::net::TcpStream;
12use tokio_tungstenite::{
13	MaybeTlsStream, WebSocketStream,
14	tungstenite::{self, Bytes},
15};
16use tracing::instrument;
17
18use crate::{AuthError, UrlError};
19
20type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
21
22/// handle exchange-level events on the [WsConnection].
23pub trait WsHandler: std::fmt::Debug {
24	/// Returns a [WsConfig] that will be applied for all WebSocket connections handled by this handler.
25	fn config(&self) -> Result<WsConfig, UrlError> {
26		Ok(WsConfig::default())
27	}
28
29	/// Called when the [WsConnection] is created and on reconnection. Returned messages will be sent back to the server as-is.
30	///
31	/// Handling of `listen-key`s or any other authentication methods exchange demands should be done here. Although oftentimes handling the auth will spread into the [handle_message](Self::handle_message) too.
32	/// Can be ran multiple times (on every reconnect). Thus this inherently cannot be used to initiate connectionions based on a change of state (ie order creation).
33	#[allow(unused_variables)]
34	fn handle_auth(&mut self) -> Result<Vec<tungstenite::Message>, WsError> {
35		Ok(vec![])
36	}
37
38	//Q: problem: can be either {String, serde_json::Value} //? other things?
39	/*
40	  "position"
41	  ||
42	  json!{
43	"id": "56374a46-3061-486b-a311-99ee972eb648",
44	"method": "order.place",
45	"params": {
46	  "symbol": "BTCUSDT",
47	  "side": "SELL",
48	  "type": "LIMIT",
49	  "timeInForce": "GTC",
50	  "price": "23416.10000000",
51	  "quantity": "0.00847000",
52	  "apiKey": "vmPUZE6mv9SD5VNHk4HlWFsOr6aKE2zvsw0MuIgwCIPy6utIco14y7Ju91duEh8A",
53	  "signature": "15af09e41c36f3cc61378c2fbe2c33719a03dd5eba8d0f9206fbda44de717c88",
54	  "timestamp": 1660801715431
55	  }
56	  }
57	  - and then the latter could be requiring signing
58	  */
59	#[allow(unused_variables)]
60	fn handle_subscribe(&mut self, topics: HashSet<Topic>) -> Result<Vec<tungstenite::Message>, WsError>;
61
62	/// Called when the [WsConnection] received a JSON-RPC value, returns messages to be sent to the server or the content with parsed event name. If not the desired content and no respose is to be sent (like after a confirmation for a subscription), return a Response with an empty Vec.
63	#[allow(unused_variables)]
64	fn handle_jrpc(&mut self, jrpc: serde_json::Value) -> Result<ResponseOrContent, WsError>;
65	//A: use this iff spot&&perp binance accept listen-key refresh through stream
66	///// Additional POST communication with the exchange, not conditional on received messages, can be handled here.
67	///// Really this is just for damn Binance with their stupid `listn-key` standard.
68	//fn handle_post(&mut self) -> Result<Option<Vec<tungstenite::Message>>, WsError> {
69	//	Ok(None)
70	//}
71
72	//#[allow(unused_variables)]
73	//fn handle_jrpc(&mut self, jrpc: &serde_json::Value) -> Result<Option<Vec<tungstenite::Message>>, WsError> {
74	//	Ok(None)
75	//}
76}
77
/// Outcome of [WsHandler::handle_jrpc] for a single received JSON value.
#[derive(Clone, Debug)]
pub enum ResponseOrContent {
	/// Messages to send to the server in reaction to what was received (may be empty). [WsConnection::next] sends
	/// them and keeps waiting for content.
	Response(Vec<tungstenite::Message>),
	/// Content received from the server; this is what [WsConnection::next] returns to its caller.
	Content(ContentEvent),
}
/// A content message produced by a [WsHandler] and returned from [WsConnection::next].
///
/// All fields are filled in by the handler; this file never constructs or inspects them.
#[derive(Clone, Debug)]
pub struct ContentEvent {
	/// JSON payload of the event.
	pub data: serde_json::Value,
	/// Topic the event belongs to. NOTE(review): presumably the subscription topic it arrived on — confirm against handler implementations.
	pub topic: String,
	/// Timestamp associated with the event. NOTE(review): whether this is exchange time or local receive time is decided by the handler — confirm.
	pub time: Timestamp,
	/// Name of the event type, as parsed by the handler.
	pub event_type: String,
}
92
/// Pairs an event name with the function able to interpret payloads of that event into `T`.
///
/// Identity (`Hash` / `PartialEq`) is defined over `event_name` alone, so a `HashSet` of interpreters holds at most
/// one interpreter per event name.
#[derive(Clone, Debug, Eq)]
pub struct TopicInterpreter<T> {
	/// Only one interpreter for this name is allowed to exist // enforced through `Hash` impl defined over `event_name` only
	pub event_name: String,
	/// When name matches, interpretation should succeed.
	pub interpret: fn(&serde_json::Value) -> Result<T, WsError>,
}
100impl<T> std::hash::Hash for TopicInterpreter<T> {
101	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
102		self.event_name.hash(state);
103	}
104}
105impl<T> PartialEq for TopicInterpreter<T> {
106	fn eq(&self, other: &Self) -> bool {
107		self.event_name == other.event_name
108	}
109}
110
/// Main way to interact with the WebSocket APIs.
///
/// Created with [WsConnection::try_new]; the socket itself is opened lazily on the first call to
/// [WsConnection::next] and transparently re-opened afterwards.
#[derive(Debug)]
pub struct WsConnection<H: WsHandler> {
	/// Full endpoint URL, resolved once in `try_new` from the config's `base_url` and the url suffix.
	url: Url,
	/// Configuration obtained from the handler in `try_new`.
	config: WsConfig,
	/// Exchange-specific logic: auth messages and interpretation of received JSON.
	handler: H,
	/// The live connection, `None` while disconnected (initially, and after a connection was dropped).
	stream: Option<WsConnectionStream>,
	/// When `connect` was last attempted; starts at `UNIX_EPOCH` so the first attempt is never delayed by the cooldown.
	last_reconnect_attempt: SystemTime, // not Tz-aware, as it will not escape the application boundary
}
/// An open WebSocket stream together with the bookkeeping [WsConnection] needs for refreshes and timeouts.
///
/// Dereferences (mutably too) to the inner [WsStream], so stream/sink methods can be called on it directly.
#[derive(Debug, derive_more::Deref, derive_more::DerefMut)]
struct WsConnectionStream {
	/// The underlying tungstenite stream.
	#[deref_mut]
	#[deref]
	stream: WsStream,
	/// When this connection was established; compared against `WsConfig::refresh_after`.
	connected_since: SystemTime,
	/// Set when we sent something to the server, cleared as soon as anything is received from it.
	/// While set, the shorter `response_timeout` applies instead of `message_timeout`.
	last_unanswered_communication: Option<SystemTime>,
}
128impl WsConnectionStream {
129	fn new(stream: WsStream, connected_since: SystemTime) -> Self {
130		Self {
131			stream,
132			connected_since,
133			last_unanswered_communication: None,
134		}
135	}
136}
137impl<H: WsHandler> WsConnection<H> {
138	#[allow(missing_docs)]
139	pub fn try_new(url_suffix: &str, handler: H) -> Result<Self, UrlError> {
140		let config = handler.config()?;
141		let url = match &config.base_url {
142			Some(base_url) => base_url.join(url_suffix)?,
143			None => Url::parse(url_suffix)?,
144		};
145
146		Ok(Self {
147			url,
148			config,
149			handler,
150			stream: None,
151			last_reconnect_attempt: SystemTime::UNIX_EPOCH,
152		})
153	}
154
155	/// The main interface. All ws operations are hidden, only thing getting through are the content messages or the lack thereof.
156	pub async fn next(&mut self) -> Result<ContentEvent, WsError> {
157		if let Some(inner) = &self.stream
158			&& inner.connected_since + self.config.refresh_after < SystemTime::now()
159		{
160			tracing::info!("Refreshing connection, as `refresh_after` specified in WsConfig has elapsed ({:?})", self.config.refresh_after);
161			self.reconnect().await?;
162		}
163		if self.stream.is_none() {
164			self.connect().await?;
165		}
166		//- at this point self.inner is Some
167
168		// loop until we get actual content
169		let json_rpc_value = loop {
170			// force a response out of the server.
171			let resp = {
172				let timeout = match self.stream.as_ref() {
173					Some(stream) => match stream.last_unanswered_communication {
174						Some(last_unanswered) => {
175							let now = SystemTime::now();
176							match last_unanswered + self.config.response_timeout > now {
177								true => self.config.response_timeout,
178								false => {
179									tracing::error!(
180										"Timeout for last unanswered communication ended before `.next()` was called. This likely indicates an implementation error on the clientside."
181									);
182									self.reconnect().await?;
183									continue;
184								}
185							}
186						}
187						None => self.config.message_timeout,
188					},
189					None => {
190						tracing::error!(
191							"UNEXPECTED: Stream is None at ws.rs:172 despite guard at line 163. \
192							Possible causes: (1) system hibernation/sleep caused stale state, \
193							(2) memory corruption, (3) logic bug in reconnection flow, \
194							(4) async cancellation. \
195							Last reconnect attempt: {:?} ago. Attempting to reconnect...",
196							SystemTime::now().duration_since(self.last_reconnect_attempt).unwrap_or_default()
197						);
198						self.connect().await?;
199						continue;
200					}
201				};
202
203				let timeout_handle = tokio::time::timeout(timeout, {
204					let stream = self.stream.as_mut().unwrap();
205					stream.next()
206				});
207				match timeout_handle.await {
208					Ok(Some(resp)) => {
209						self.stream.as_mut().unwrap().last_unanswered_communication = None;
210						resp
211					}
212					Ok(None) => {
213						tracing::warn!("tungstenite couldn't read from the stream. Restarting.");
214						self.reconnect().await?;
215						continue;
216					}
217					Err(timeout_error) => {
218						tracing::warn!("Message reception timed out after {:?} seconds. // {timeout_error}", timeout);
219						{
220							let stream = self.stream.as_mut().unwrap();
221							match stream.last_unanswered_communication.is_some() {
222								true => self.reconnect().await?,
223								false => {
224									// Reached standard message_timeout (one for messages sent when we're not forcing communication). So let's force it.
225									self.send(tungstenite::Message::Ping(Bytes::default())).await?;
226									continue;
227								}
228							}
229						}
230						continue;
231					}
232				}
233			};
234
235			// some response received, handle it
236			match resp {
237				Ok(succ_resp) => match succ_resp {
238					tungstenite::Message::Text(text) => {
239						let value: serde_json::Value =
240							serde_json::from_str(&text).expect("API sent invalid JSON, which is completely unexpected. Disappointment is immeasurable and the day is ruined.");
241						tracing::trace!("{value:#?}"); // only log it after the `handle_message` has ran, as we're assuming that if it takes any actions, it will handle logging itself. (and that will likely be at a different level of important too)
242						break match { self.handler.handle_jrpc(value)? } {
243							ResponseOrContent::Response(messages) => {
244								self.send_all(messages).await?;
245								continue; // only need to send responses when it's not yet the desired content.
246							}
247							ResponseOrContent::Content(content) => content,
248						};
249					}
250					tungstenite::Message::Binary(_) => {
251						panic!("Received binary. But exchanges are not smart enough to send this, what is happening");
252					}
253					tungstenite::Message::Ping(bytes) => {
254						self.send(tungstenite::Message::Pong(bytes)).await?; // Binance specifically requires the exact ping's payload to be returned here: https://developers.binance.com/docs/binance-spot-api-docs/web-socket-streams
255						tracing::debug!("ponged");
256						continue;
257					}
258					// in most cases these are not seen, as it's sufficient to just answer to their [pings](tungstenite::Message::Ping). Our own pings are sent only when we haven't heard from the exchange for a while, in which case it's likely that it will not [pong](tungstenite::Message::Pong) back either.
259					tungstenite::Message::Pong(_) => {
260						tracing::info!("Received pong");
261						continue;
262					}
263					tungstenite::Message::Close(maybe_reason) => {
264						match maybe_reason {
265							Some(close_frame) => {
266								//Q: maybe need to expose def of this for ind exchanges (so we can interpret the codes)
267								tracing::info!("Server closed connection; reason: {close_frame:?}");
268							}
269							None => {
270								tracing::info!("Server closed connection; no reason specified.");
271							}
272						}
273						self.stream = None;
274						self.reconnect().await?;
275						continue;
276					}
277					tungstenite::Message::Frame(_) => {
278						unreachable!("Can't get from reading");
279					}
280				},
281				Err(err) => match err {
282					tungstenite::Error::ConnectionClosed => {
283						tracing::error!("received `tungstenite::Error::ConnectionClosed` on polling. Will reconnect");
284						self.stream = None;
285						continue;
286					}
287					tungstenite::Error::AlreadyClosed => {
288						tracing::error!("received `tungstenite::Error::AlreadyClosed` from polling. Will reconnect");
289						self.stream = None;
290						continue;
291					}
292					tungstenite::Error::Io(e) => {
293						tracing::warn!("received `tungstenite::Error::Io` from polling: {e:?}. Likely indicates connection issues. Skipping.");
294						continue;
295					}
296					tungstenite::Error::Tls(_tls_error) => todo!(),
297					tungstenite::Error::Capacity(capacity_error) => {
298						tracing::warn!("received `tungstenite::Error::Capacity` from polling: {capacity_error:?}. Skipping.");
299						continue;
300					}
301					tungstenite::Error::Protocol(protocol_error) => {
302						tracing::warn!("received `tungstenite::Error::Protocol` from polling: {protocol_error:?}. Will reconnect");
303						self.stream = None;
304						continue;
305					}
306					tungstenite::Error::WriteBufferFull(_) => unreachable!("can only get from writing"),
307					tungstenite::Error::Utf8(e) => panic!("received `tungstenite::Error::Utf8` from polling: {e:?}. Exchange is going crazy, aborting"),
308					tungstenite::Error::AttackAttempt => {
309						tracing::warn!("received `tungstenite::Error::AttackAttempt` from polling. Don't have a reason to trust detection 100%, so just reconnecting.");
310						self.stream = None;
311						continue;
312					}
313					tungstenite::Error::Url(_url_error) => todo!(),
314					tungstenite::Error::Http(_response) => todo!(),
315					tungstenite::Error::HttpFormat(_error) => todo!(),
316				},
317			}
318		};
319		Ok(json_rpc_value)
320	}
321
322	#[instrument(skip_all)]
323	async fn send_all(&mut self, messages: Vec<tungstenite::Message>) -> Result<(), tungstenite::Error> {
324		if let Some(inner) = &mut self.stream {
325			match messages.len() {
326				0 => return Ok(()),
327				1 => {
328					tracing::debug!("sending to server: {:#?}", &messages[0]);
329					inner.send(messages.into_iter().next().unwrap()).await?;
330					inner.last_unanswered_communication = Some(SystemTime::now());
331				}
332				_ => {
333					tracing::debug!("sending to server: {messages:#?}");
334					let mut message_stream = futures_util::stream::iter(messages).map(Ok);
335					inner.send_all(&mut message_stream).await?;
336					inner.last_unanswered_communication = Some(SystemTime::now());
337				}
338			};
339			Ok(())
340		} else {
341			Err(tungstenite::Error::ConnectionClosed)
342		}
343	}
344
345	async fn send(&mut self, message: tungstenite::Message) -> Result<(), tungstenite::Error> {
346		self.send_all(vec![message]).await // Vec cost is negligible
347	}
348
349	async fn connect(&mut self) -> Result<(), WsError> {
350		tracing::info!("Connecting to {}...", self.url);
351		{
352			let now = SystemTime::now();
353			let timeout = self.config.reconnect_cooldown;
354			if self.last_reconnect_attempt + timeout > now {
355				tracing::warn!("Reconnect cooldown is triggered. Likely indicative of a bad connection.");
356				let duration = (self.last_reconnect_attempt + timeout).duration_since(now).unwrap();
357				tokio::time::sleep(duration).await;
358			}
359		}
360		self.last_reconnect_attempt = SystemTime::now();
361
362		let (stream, http_resp) = tokio_tungstenite::connect_async(self.url.as_str()).await?;
363		tracing::debug!("Ws handshake with server: {http_resp:#?}");
364
365		let now = SystemTime::now();
366		self.stream = Some(WsConnectionStream::new(stream, now));
367
368		let auth_messages = self.handler.handle_auth()?;
369		Ok(self.send_all(auth_messages).await?)
370	}
371
372	/// Sends the existing connection (if any) a `Close` message, and then simply drops it, opening a new one.
373	///
374	/// `pub` for testing only, does not {have to || is expected to} be exposed in any wrappers.
375	pub async fn reconnect(&mut self) -> Result<(), WsError> {
376		if self.stream.is_some() {
377			tracing::info!("Dropping old connection before reconnecting...");
378			{
379				let stream = self.stream.as_mut().unwrap();
380				stream.send(tungstenite::Message::Close(None)).await?;
381				self.stream = None;
382			}
383		}
384		self.connect().await
385	}
386}
387
/// Configuration for [WsHandler].
///
/// Should be returned by [WsHandler::config()].
///
/// The timing fields are private and have built-in defaults; change them through the `set_*` methods, which
/// reject zero durations.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct WsConfig {
	/// Whether the connection should be authenticated. Normally implemented through a "listen key"
	///
	/// NOTE(review): not read anywhere in this file's visible code — presumably consumed by handler implementations.
	pub auth: bool,
	/// Prefix which will be used for connections that started using this `WsConfig`.
	/// When `None`, the url suffix passed to [WsConnection::try_new] must itself be a complete URL.
	///
	/// Ex: `"wss://example.com"`
	pub base_url: Option<Url>,
	/// Duration that should elapse between each attempt to start a new connection.
	///
	/// This matters because the [WsConnection] reconnects on error. If the error
	/// continues to happen, it could spam the server if `reconnect_cooldown` is too short.
	reconnect_cooldown: Duration = Duration::from_secs(3),
	/// The [WsConnection] will automatically reconnect when `refresh_after` has elapsed since the last connection started.
	refresh_after: Duration = Duration::from_hours(12),
	/// If no messages are received within this amount of time, a `Ping` is sent to force a response; the
	/// reconnection itself is triggered when that ping then goes unanswered for `response_timeout`.
	message_timeout: Duration = Duration::from_mins(16), // assume all exchanges ping more frequently than this
	/// Timeout for the response to a message sent to the server.
	///
	/// Difference from the [message_timeout](Self::message_timeout) is that here we directly request communication. Eg: sending a Ping or attempting to auth.
	response_timeout: Duration = Duration::from_mins(2),
	/// The topics that will be subscribed to on creation of the connection. Note that we don't allow for passing anything that changes state here like [Order](Topic::Order) payloads, thus submissions are limited to [String]s
	///
	/// NOTE(review): not read anywhere in this file's visible code — confirm where the subscription happens.
	pub topics: HashSet<String>,
}
415
416impl WsConfig {
417	pub fn set_reconnect_cooldown(&mut self, reconnect_cooldown: Duration) -> Result<()> {
418		if reconnect_cooldown.is_zero() {
419			bail!("connect_cooldown must be greater than 0");
420		}
421		self.reconnect_cooldown = reconnect_cooldown;
422		Ok(())
423	}
424
425	pub fn set_refresh_after(&mut self, refresh_after: Duration) -> Result<()> {
426		if refresh_after.is_zero() {
427			bail!("refresh_after must be greater than 0");
428		}
429		self.refresh_after = refresh_after;
430		Ok(())
431	}
432
433	pub fn set_message_timeout(&mut self, message_timeout: Duration) -> Result<()> {
434		if message_timeout.is_zero() {
435			bail!("message_timeout must be greater than 0");
436		}
437		self.message_timeout = message_timeout;
438		Ok(())
439	}
440
441	pub fn set_response_timout(&mut self, response_timeout: Duration) -> Result<()> {
442		if response_timeout.is_zero() {
443			bail!("response_timeout must be greater than 0");
444		}
445		self.response_timeout = response_timeout;
446		Ok(())
447	}
448}
449
/// Everything that can go wrong while driving a [WsConnection].
///
/// `From` is derived for the payload-carrying variants, which is what lets `?` convert e.g. a
/// `tungstenite::Error` into [WsError::Tungstenite] throughout this file.
#[derive(Debug, derive_more::Display, thiserror::Error, derive_more::From)]
pub enum WsError {
	/// The connection was defined incorrectly (see [WsDefinitionError]).
	Definition(WsDefinitionError),
	/// Error reported by the underlying WebSocket library (handshake, sending, polling).
	Tungstenite(tungstenite::Error),
	/// Authentication failure. NOTE(review): produced by handler implementations, not by this file — confirm.
	Auth(AuthError),
	/// JSON (de)serialization failure.
	Parse(serde_json::Error),
	/// Subscription failure, with a description. NOTE(review): not constructed in this file — presumably by handlers.
	Subscription(String),
	/// NOTE(review): not constructed in this file; presumably signals a lost network connection — confirm.
	NetworkConnection,
	/// Invalid or unparsable URL.
	Url(UrlError),
	/// The server sent a JSON value the handler did not expect; carries that value.
	/// NOTE(review): not constructed in this file — presumably by handlers.
	UnexpectedEvent(serde_json::Value),
	/// Any other error.
	Other(eyre::Report),
}
/// Errors in how a WebSocket connection was defined, as opposed to runtime failures.
#[derive(Debug, derive_more::Display, thiserror::Error)]
pub enum WsDefinitionError {
	/// No URL was provided. NOTE(review): not constructed in this file's visible code — confirm the producer.
	MissingUrl,
}
466
467//DEPRECATE: or reinstate, - can't even remember what's this now
468//#[derive(Debug, derive_more::Display, thiserror::Error)]
469//pub enum SubscriptionError {
470//	Topic(String),
471//	Params(serde_json::Value),
472//	Incompatible(IncompatibleSubscriptionError),
473//}
474//#[derive(Debug, thiserror::Error)]
475//#[error("Incompatible subscription error: could not subscribe to {topic:#?} on {base_url}")]
476//pub struct IncompatibleSubscriptionError {
477//	topic: Topic,
478//	base_url: Url,
479//}
480
/// Something that can be subscribed to or submitted through [WsHandler::handle_subscribe].
#[derive(Clone, Debug, derive_more::Display, Eq, Hash, PartialEq, serde::Serialize)]
pub enum Topic {
	/// A plain named topic (e.g. a stream name).
	String(String),
	/// A structured JSON payload. NOTE(review): presumably an order request that may require signing, per the
	/// design notes on `WsHandler::handle_subscribe` — confirm.
	Order(serde_json::Value),
}