v_exchanges_api_generics/
ws.rs

1use std::{
2	collections::HashSet,
3	time::{Duration, SystemTime},
4	vec,
5};
6
7use eyre::{Result, bail};
8use futures_util::{SinkExt as _, StreamExt as _};
9use jiff::Timestamp;
10use reqwest::Url;
11use tokio::net::TcpStream;
12use tokio_tungstenite::{
13	MaybeTlsStream, WebSocketStream,
14	tungstenite::{self, Bytes},
15};
16use tracing::instrument;
17
18use crate::{AuthError, UrlError};
19
20type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
21
22/// handle exchange-level events on the [WsConnection].
23pub trait WsHandler: std::fmt::Debug {
24	/// Returns a [WsConfig] that will be applied for all WebSocket connections handled by this handler.
25	fn config(&self) -> Result<WsConfig, UrlError> {
26		Ok(WsConfig::default())
27	}
28
29	/// Called when the [WsConnection] is created and on reconnection. Returned messages will be sent back to the server as-is.
30	///
31	/// Handling of `listen-key`s or any other authentication methods exchange demands should be done here. Although oftentimes handling the auth will spread into the [handle_message](Self::handle_message) too.
32	/// Can be ran multiple times (on every reconnect). Thus this inherently cannot be used to initiate connectionions based on a change of state (ie order creation).
33	#[allow(unused_variables)]
34	fn handle_auth(&mut self) -> Result<Vec<tungstenite::Message>, WsError> {
35		Ok(vec![])
36	}
37
38	//Q: problem: can be either {String, serde_json::Value} //? other things?
39	/*
40	  "position"
41	  ||
42	  json!{
43	"id": "56374a46-3061-486b-a311-99ee972eb648",
44	"method": "order.place",
45	"params": {
46	  "symbol": "BTCUSDT",
47	  "side": "SELL",
48	  "type": "LIMIT",
49	  "timeInForce": "GTC",
50	  "price": "23416.10000000",
51	  "quantity": "0.00847000",
52	  "apiKey": "vmPUZE6mv9SD5VNHk4HlWFsOr6aKE2zvsw0MuIgwCIPy6utIco14y7Ju91duEh8A",
53	  "signature": "15af09e41c36f3cc61378c2fbe2c33719a03dd5eba8d0f9206fbda44de717c88",
54	  "timestamp": 1660801715431
55	  }
56	  }
57	  - and then the latter could be requiring signing
58	  */
59	#[allow(unused_variables)]
60	fn handle_subscribe(&mut self, topics: HashSet<Topic>) -> Result<Vec<tungstenite::Message>, WsError>;
61
62	/// Called when the [WsConnection] received a JSON-RPC value, returns messages to be sent to the server or the content with parsed event name. If not the desired content and no respose is to be sent (like after a confirmation for a subscription), return a Response with an empty Vec.
63	#[allow(unused_variables)]
64	fn handle_jrpc(&mut self, jrpc: serde_json::Value) -> Result<ResponseOrContent, WsError>;
65	//A: use this iff spot&&perp binance accept listen-key refresh through stream
66	///// Additional POST communication with the exchange, not conditional on received messages, can be handled here.
67	///// Really this is just for damn Binance with their stupid `listn-key` standard.
68	//fn handle_post(&mut self) -> Result<Option<Vec<tungstenite::Message>>, WsError> {
69	//	Ok(None)
70	//}
71
72	//#[allow(unused_variables)]
73	//fn handle_jrpc(&mut self, jrpc: &serde_json::Value) -> Result<Option<Vec<tungstenite::Message>>, WsError> {
74	//	Ok(None)
75	//}
76}
77
78#[derive(Clone, Debug)]
79pub enum ResponseOrContent {
80	/// Response to a message sent to the server.
81	Response(Vec<tungstenite::Message>),
82	/// Content received from the server.
83	Content(ContentEvent),
84}
85#[derive(Clone, Debug)]
86pub struct ContentEvent {
87	pub data: serde_json::Value,
88	pub topic: String,
89	pub time: Timestamp,
90	pub event_type: String,
91}
92
93#[derive(Clone, Debug, Eq)]
94pub struct TopicInterpreter<T> {
95	/// Only one interpreter for this name is allowed to exist // enforced through `Hash` impl defined over `event_name` only
96	pub event_name: String,
97	/// When name matches, interpretation should succeed.
98	pub interpret: fn(&serde_json::Value) -> Result<T, WsError>,
99}
100impl<T> std::hash::Hash for TopicInterpreter<T> {
101	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
102		self.event_name.hash(state);
103	}
104}
105impl<T> PartialEq for TopicInterpreter<T> {
106	fn eq(&self, other: &Self) -> bool {
107		self.event_name == other.event_name
108	}
109}
110
111/// Main way to interact with the WebSocket APIs.
112#[derive(Debug)]
113pub struct WsConnection<H: WsHandler> {
114	url: Url,
115	config: WsConfig,
116	handler: H,
117	stream: Option<WsConnectionStream>,
118	last_reconnect_attempt: SystemTime, // not Tz-aware, as it will not escape the application boundary
119}
120#[derive(Debug, derive_more::Deref, derive_more::DerefMut)]
121struct WsConnectionStream {
122	#[deref_mut]
123	#[deref]
124	stream: WsStream,
125	connected_since: SystemTime,
126	last_unanswered_communication: Option<SystemTime>,
127}
128impl WsConnectionStream {
129	fn new(stream: WsStream, connected_since: SystemTime) -> Self {
130		Self {
131			stream,
132			connected_since,
133			last_unanswered_communication: None,
134		}
135	}
136}
137impl<H: WsHandler> WsConnection<H> {
138	#[allow(missing_docs)]
139	pub fn try_new(url_suffix: &str, handler: H) -> Result<Self, UrlError> {
140		let config = handler.config()?;
141		let url = match &config.base_url {
142			Some(base_url) => base_url.join(url_suffix)?,
143			None => Url::parse(url_suffix)?,
144		};
145
146		Ok(Self {
147			url,
148			config,
149			handler,
150			stream: None,
151			last_reconnect_attempt: SystemTime::UNIX_EPOCH,
152		})
153	}
154
155	/// The main interface. All ws operations are hidden, only thing getting through are the content messages or the lack thereof.
156	pub async fn next(&mut self) -> Result<ContentEvent, WsError> {
157		if let Some(inner) = &self.stream
158			&& inner.connected_since + self.config.refresh_after < SystemTime::now()
159		{
160			tracing::info!("Refreshing connection, as `refresh_after` specified in WsConfig has elapsed ({:?})", self.config.refresh_after);
161			self.reconnect().await?;
162		}
163		if self.stream.is_none() {
164			self.connect().await?;
165		}
166		//- at this point self.inner is Some
167
168		// loop until we get actual content
169		let json_rpc_value = loop {
170			// force a response out of the server.
171			let resp = {
172				let timeout = match self.stream.as_ref().unwrap().last_unanswered_communication {
173					Some(last_unanswered) => {
174						let now = SystemTime::now();
175						match last_unanswered + self.config.response_timeout > now {
176							true => self.config.response_timeout,
177							false => {
178								tracing::error!(
179									"Timeout for last unanswered communication ended before `.next()` was called. This likely indicates an implementation error on the clientside."
180								);
181								self.reconnect().await?;
182								continue;
183							}
184						}
185					}
186					None => self.config.message_timeout,
187				};
188
189				let timeout_handle = tokio::time::timeout(timeout, {
190					let stream = self.stream.as_mut().unwrap();
191					stream.next()
192				});
193				match timeout_handle.await {
194					Ok(Some(resp)) => {
195						self.stream.as_mut().unwrap().last_unanswered_communication = None;
196						resp
197					}
198					Ok(None) => {
199						tracing::warn!("tungstenite couldn't read from the stream. Restarting.");
200						self.reconnect().await?;
201						continue;
202					}
203					Err(timeout_error) => {
204						tracing::warn!("Message reception timed out after {:?} seconds. // {timeout_error}", timeout);
205						{
206							let stream = self.stream.as_mut().unwrap();
207							match stream.last_unanswered_communication.is_some() {
208								true => self.reconnect().await?,
209								false => {
210									// Reached standard message_timeout (one for messages sent when we're not forcing communication). So let's force it.
211									self.send(tungstenite::Message::Ping(Bytes::default())).await?;
212									continue;
213								}
214							}
215						}
216						continue;
217					}
218				}
219			};
220
221			// some response received, handle it
222			match resp {
223				Ok(succ_resp) => match succ_resp {
224					tungstenite::Message::Text(text) => {
225						let value: serde_json::Value =
226							serde_json::from_str(&text).expect("API sent invalid JSON, which is completely unexpected. Disappointment is immeasurable and the day is ruined.");
227						tracing::trace!("{value:#?}"); // only log it after the `handle_message` has ran, as we're assuming that if it takes any actions, it will handle logging itself. (and that will likely be at a different level of important too)
228						break match { self.handler.handle_jrpc(value)? } {
229							ResponseOrContent::Response(messages) => {
230								self.send_all(messages).await?;
231								continue; // only need to send responses when it's not yet the desired content.
232							}
233							ResponseOrContent::Content(content) => content,
234						};
235					}
236					tungstenite::Message::Binary(_) => {
237						panic!("Received binary. But exchanges are not smart enough to send this, what is happening");
238					}
239					tungstenite::Message::Ping(bytes) => {
240						self.send(tungstenite::Message::Pong(bytes)).await?; // Binance specifically requires the exact ping's payload to be returned here: https://developers.binance.com/docs/binance-spot-api-docs/web-socket-streams
241						tracing::debug!("ponged");
242						continue;
243					}
244					// in most cases these are not seen, as it's sufficient to just answer to their [pings](tungstenite::Message::Ping). Our own pings are sent only when we haven't heard from the exchange for a while, in which case it's likely that it will not [pong](tungstenite::Message::Pong) back either.
245					tungstenite::Message::Pong(_) => {
246						tracing::info!("Received pong");
247						continue;
248					}
249					tungstenite::Message::Close(maybe_reason) => {
250						match maybe_reason {
251							Some(close_frame) => {
252								//Q: maybe need to expose def of this for ind exchanges (so we can interpret the codes)
253								tracing::info!("Server closed connection; reason: {close_frame:?}");
254							}
255							None => {
256								tracing::info!("Server closed connection; no reason specified.");
257							}
258						}
259						self.stream = None;
260						self.reconnect().await?;
261						continue;
262					}
263					tungstenite::Message::Frame(_) => {
264						unreachable!("Can't get from reading");
265					}
266				},
267				Err(err) => match err {
268					tungstenite::Error::ConnectionClosed => {
269						tracing::error!("received `tungstenite::Error::ConnectionClosed` on polling. Will reconnect");
270						self.stream = None;
271						continue;
272					}
273					tungstenite::Error::AlreadyClosed => {
274						tracing::error!("received `tungstenite::Error::AlreadyClosed` from polling. Will reconnect");
275						self.stream = None;
276						continue;
277					}
278					tungstenite::Error::Io(e) => {
279						tracing::warn!("received `tungstenite::Error::Io` from polling: {e:?}. Likely indicates connection issues. Skipping.");
280						continue;
281					}
282					tungstenite::Error::Tls(_tls_error) => todo!(),
283					tungstenite::Error::Capacity(capacity_error) => {
284						tracing::warn!("received `tungstenite::Error::Capacity` from polling: {capacity_error:?}. Skipping.");
285						continue;
286					}
287					tungstenite::Error::Protocol(protocol_error) => {
288						tracing::warn!("received `tungstenite::Error::Protocol` from polling: {protocol_error:?}. Will reconnect");
289						self.stream = None;
290						continue;
291					}
292					tungstenite::Error::WriteBufferFull(_) => unreachable!("can only get from writing"),
293					tungstenite::Error::Utf8(e) => panic!("received `tungstenite::Error::Utf8` from polling: {e:?}. Exchange is going crazy, aborting"),
294					tungstenite::Error::AttackAttempt => {
295						tracing::warn!("received `tungstenite::Error::AttackAttempt` from polling. Don't have a reason to trust detection 100%, so just reconnecting.");
296						self.stream = None;
297						continue;
298					}
299					tungstenite::Error::Url(_url_error) => todo!(),
300					tungstenite::Error::Http(_response) => todo!(),
301					tungstenite::Error::HttpFormat(_error) => todo!(),
302				},
303			}
304		};
305		Ok(json_rpc_value)
306	}
307
308	#[instrument(skip_all)]
309	async fn send_all(&mut self, messages: Vec<tungstenite::Message>) -> Result<(), tungstenite::Error> {
310		if let Some(inner) = &mut self.stream {
311			match messages.len() {
312				0 => return Ok(()),
313				1 => {
314					tracing::debug!("sending to server: {:#?}", &messages[0]);
315					inner.send(messages.into_iter().next().unwrap()).await?;
316					inner.last_unanswered_communication = Some(SystemTime::now());
317				}
318				_ => {
319					tracing::debug!("sending to server: {messages:#?}");
320					let mut message_stream = futures_util::stream::iter(messages).map(Ok);
321					inner.send_all(&mut message_stream).await?;
322					inner.last_unanswered_communication = Some(SystemTime::now());
323				}
324			};
325			Ok(())
326		} else {
327			Err(tungstenite::Error::ConnectionClosed)
328		}
329	}
330
331	async fn send(&mut self, message: tungstenite::Message) -> Result<(), tungstenite::Error> {
332		self.send_all(vec![message]).await // Vec cost is negligible
333	}
334
335	async fn connect(&mut self) -> Result<(), WsError> {
336		tracing::info!("Connecting to {}...", self.url);
337		{
338			let now = SystemTime::now();
339			let timeout = self.config.reconnect_cooldown;
340			if self.last_reconnect_attempt + timeout > now {
341				tracing::warn!("Reconnect cooldown is triggered. Likely indicative of a bad connection.");
342				let duration = (self.last_reconnect_attempt + timeout).duration_since(now).unwrap();
343				tokio::time::sleep(duration).await;
344			}
345		}
346		self.last_reconnect_attempt = SystemTime::now();
347
348		let (stream, http_resp) = tokio_tungstenite::connect_async(self.url.as_str()).await?;
349		tracing::debug!("Ws handshake with server: {http_resp:#?}");
350
351		let now = SystemTime::now();
352		self.stream = Some(WsConnectionStream::new(stream, now));
353
354		let auth_messages = self.handler.handle_auth()?;
355		Ok(self.send_all(auth_messages).await?)
356	}
357
358	/// Sends the existing connection (if any) a `Close` message, and then simply drops it, opening a new one.
359	///
360	/// `pub` for testing only, does not {have to || is expected to} be exposed in any wrappers.
361	pub async fn reconnect(&mut self) -> Result<(), WsError> {
362		if self.stream.is_some() {
363			tracing::info!("Dropping old connection before reconnecting...");
364			{
365				let stream = self.stream.as_mut().unwrap();
366				stream.send(tungstenite::Message::Close(None)).await?;
367				self.stream = None;
368			}
369		}
370		self.connect().await
371	}
372}
373
374/// Configuration for [WsHandler].
375///
376/// Should be returned by [WsHandler::ws_config()].
377#[derive(Clone, Debug, Default, Eq, PartialEq)]
378pub struct WsConfig {
379	/// Whether the connection should be authenticated. Normally implemented through a "listen key"
380	pub auth: bool,
381	/// Prefix which will be used for connections that started using this `WebSocketConfig`.
382	///
383	/// Ex: `"wss://example.com"`
384	pub base_url: Option<Url>,
385	/// Duration that should elapse between each attempt to start a new connection.
386	///
387	/// This matters because the [WebSocketConnection] reconnects on error. If the error
388	/// continues to happen, it could spam the server if `connect_cooldown` is too short.
389	reconnect_cooldown: Duration = Duration::from_secs(3),
390	/// The [WebSocketConnection] will automatically reconnect when `refresh_after` has elapsed since the last connection started.
391	refresh_after: Duration = Duration::from_hours(12),
392	/// A reconnection will be triggered if no messages are received within this amount of time.
393	message_timeout: Duration = Duration::from_mins(16), // assume all exchanges ping more frequently than this
394	/// Timeout for the response to a message sent to the server.
395	///
396	/// Difference from the [message_timeout](Self::message_timeout) is that here we directly request communication. Eg: sending a Ping or attempting to auth.
397	response_timeout: Duration = Duration::from_mins(2),
398	/// The topics that will be subscribed to on creation of the connection. Note that we don't allow for passing anything that changes state here like [Trade](Topic::Trade) payloads, thus submissions are limited to [String]s
399	pub topics: HashSet<String>,
400}
401
402impl WsConfig {
403	pub fn set_reconnect_cooldown(&mut self, reconnect_cooldown: Duration) -> Result<()> {
404		if reconnect_cooldown.is_zero() {
405			bail!("connect_cooldown must be greater than 0");
406		}
407		self.reconnect_cooldown = reconnect_cooldown;
408		Ok(())
409	}
410
411	pub fn set_refresh_after(&mut self, refresh_after: Duration) -> Result<()> {
412		if refresh_after.is_zero() {
413			bail!("refresh_after must be greater than 0");
414		}
415		self.refresh_after = refresh_after;
416		Ok(())
417	}
418
419	pub fn set_message_timeout(&mut self, message_timeout: Duration) -> Result<()> {
420		if message_timeout.is_zero() {
421			bail!("message_timeout must be greater than 0");
422		}
423		self.message_timeout = message_timeout;
424		Ok(())
425	}
426
427	pub fn set_response_timout(&mut self, response_timeout: Duration) -> Result<()> {
428		if response_timeout.is_zero() {
429			bail!("response_timeout must be greater than 0");
430		}
431		self.response_timeout = response_timeout;
432		Ok(())
433	}
434}
435
436#[derive(Debug, derive_more::Display, thiserror::Error, derive_more::From)]
437pub enum WsError {
438	Definition(WsDefinitionError),
439	Tungstenite(tungstenite::Error),
440	Auth(AuthError),
441	Parse(serde_json::Error),
442	Subscription(String),
443	NetworkConnection,
444	Url(UrlError),
445	UnexpectedEvent(serde_json::Value),
446	Other(eyre::Report),
447}
448#[derive(Debug, derive_more::Display, thiserror::Error)]
449pub enum WsDefinitionError {
450	MissingUrl,
451}
452
453//DEPRECATE: or reinstate, - can't even remember what's this now
454//#[derive(Debug, derive_more::Display, thiserror::Error)]
455//pub enum SubscriptionError {
456//	Topic(String),
457//	Params(serde_json::Value),
458//	Incompatible(IncompatibleSubscriptionError),
459//}
460//#[derive(Debug, thiserror::Error)]
461//#[error("Incompatible subscription error: could not subscribe to {topic:#?} on {base_url}")]
462//pub struct IncompatibleSubscriptionError {
463//	topic: Topic,
464//	base_url: Url,
465//}
466
467#[derive(Clone, Debug, derive_more::Display, Eq, Hash, PartialEq, serde::Serialize)]
468pub enum Topic {
469	String(String),
470	Order(serde_json::Value),
471}