1use async_tungstenite::{
2 stream::Stream as StreamSwitcher,
3 tokio::{connect_async, TokioAdapter},
4 tungstenite::{
5 handshake::client::Response,
6 protocol::{frame::coding::CloseCode, CloseFrame},
7 Message,
8 },
9 WebSocketStream as WsStream,
10};
11use core::pin::Pin;
12use futures::{
13 sink::Sink,
14 stream::TryStreamExt,
15 task::{Context, Poll},
16 SinkExt, Stream,
17};
18use tokio::net::TcpStream;
19use tokio_native_tls::TlsStream;
20
21use crate::error::{Error, Kind, WsCloseError};
22use crate::param::Interval;
23use serde::de::DeserializeOwned;
24use serde::Serialize;
25use serde_json::Value;
26use std::fmt;
27
/// Base websocket endpoint for Binance.US streams; stream paths such as `/ws/<channel>`
/// are appended to it.
pub const BINANCE_US_WSS_URL: &str = "wss://stream.binance.us:9443";
30
31#[derive(Copy, Clone)]
32pub enum Channel<'c> {
33 AggTrade(&'c str),
34 Depth(&'c str, Speed),
35 Trade(&'c str),
36 Kline(&'c str, Interval),
37 MiniTicker(&'c str),
38 AllMiniTickers,
39 Ticker(&'c str),
40 AllTickers,
41 BookTicker(&'c str),
42 AllBookTickers,
43 PartialDepth(&'c str, Level, Speed),
44 UserData(&'c str),
46}
47
48impl<'c> fmt::Display for Channel<'c> {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 match self {
51 Self::AggTrade(symbol) => write!(f, "{}", symbol.to_lowercase() + "@aggTrade"),
52 Self::Trade(symbol) => write!(f, "{}", symbol.to_lowercase() + "@trade"),
53 Self::Kline(symbol, interval) => {
54 let interval = serde_json::to_value(interval).unwrap();
55 write!(
56 f,
57 "{}",
58 symbol.to_lowercase() + "@kline_" + interval.as_str().unwrap()
59 )
60 }
61 Self::MiniTicker(symbol) => write!(f, "{}", symbol.to_lowercase() + "@miniTicker"),
62 Self::AllMiniTickers => write!(f, "!miniTicker@arr"),
63 Self::Ticker(symbol) => write!(f, "{}", symbol.to_lowercase() + "@ticker"),
64 Self::AllTickers => write!(f, "!ticker@arr"),
65 Self::BookTicker(symbol) => write!(f, "{}", symbol.to_lowercase() + "@bookTicker"),
66 Self::AllBookTickers => write!(f, "!bookTicker"),
67 Self::PartialDepth(symbol, level, speed) => {
68 let level = serde_json::to_value(level).unwrap();
69 let speed = serde_json::to_value(speed).unwrap();
70 write!(
71 f,
72 "{}",
73 symbol.to_lowercase()
74 + "@depth"
75 + level.as_str().unwrap()
76 + "@"
77 + speed.as_str().unwrap()
78 )
79 }
80 Self::Depth(symbol, speed) => {
81 let speed = serde_json::to_value(speed).unwrap();
82 write!(
83 f,
84 "{}",
85 symbol.to_lowercase() + "@depth@" + speed.as_str().unwrap()
86 )
87 }
88 Self::UserData(listen_key) => write!(f, "{}", listen_key),
89 }
90 }
91}
92
93impl<'a, 'c> PartialEq<&'a str> for Channel<'c> {
94 fn eq(&self, other: &&str) -> bool {
95 self.to_string() == *other
96 }
97}
98
99impl<'c> PartialEq<String> for Channel<'c> {
100 fn eq(&self, other: &String) -> bool {
101 self.to_string() == *other
102 }
103}
104
105impl<'c> PartialEq<Value> for Channel<'c> {
106 fn eq(&self, other: &Value) -> bool {
107 self.to_string() == *other
108 }
109}
110
111#[derive(Copy, Clone, Serialize)]
112pub enum Level {
113 #[serde(rename = "5")]
114 Five,
115 #[serde(rename = "10")]
116 Ten,
117 #[serde(rename = "20")]
118 Twenty,
119}
120
121#[derive(Copy, Clone, Serialize)]
122pub enum Speed {
123 #[serde(rename = "100ms")]
124 HundredMillis,
125 #[serde(rename = "1000ms")]
126 ThousandMillis,
127}
128
129#[derive(Serialize)]
130struct SubscribeMessage<'a> {
131 method: &'a str,
132 params: &'a [Value],
133 id: u64,
134}
135
136type InnerStream = (
137 WsStream<StreamSwitcher<TokioAdapter<TcpStream>, TokioAdapter<TlsStream<TcpStream>>>>,
138 Response,
139);
140
141pub struct WebSocketStream {
143 inner: InnerStream,
144 id: u64,
145}
146
147impl WebSocketStream {
148 pub async fn connect<U: Into<String>>(
162 channel: Channel<'_>,
163 url: U,
164 ) -> crate::error::Result<Self> {
165 let url = url.into() + "/ws/" + &channel.to_string();
166
167 let inner = connect_async(url).await?;
168 let mut stream = Self { inner, id: 0 };
169
170 let message = SubscribeMessage {
171 method: "SET_PROPERTY",
172 params: &["combined".into(), true.into()],
173 id: stream.id,
174 };
175 let message = serde_json::to_string(&message)?;
176 stream.send(Message::Text(message)).await?;
177 stream.id += 1;
178
179 Ok(stream)
180 }
181 pub async fn text(&mut self) -> crate::error::Result<Option<String>> {
197 match self.try_next().await? {
198 Some(msg) => match msg {
199 Message::Text(text) => Ok(Some(text)),
200 Message::Ping(ref value) => {
201 self.send(Message::Pong(value.clone())).await?;
202 let ping = serde_json::json!({
203 "ping": msg.into_text()?,
204 });
205 Ok(Some(serde_json::to_string(&ping)?))
206 }
207 Message::Pong(ref value) => {
208 self.send(Message::Ping(value.clone())).await?;
209 let pong = serde_json::json!({
210 "pong": msg.into_text()?,
211 });
212 Ok(Some(serde_json::to_string(&pong)?))
213 }
214 Message::Binary(_) => Ok(Some(msg.into_text()?)),
215 Message::Close(Some(frame)) => {
216 Err(WsCloseError::new(frame.code, frame.reason).into())
217 }
218 Message::Close(None) => Err(WsCloseError::new(
219 CloseCode::Abnormal,
220 "Close message with no frame received",
221 )
222 .into()),
223 },
224 None => Ok(None),
225 }
226 }
227 pub async fn json<J: DeserializeOwned>(&mut self) -> crate::error::Result<Option<J>> {
248 match self.text().await? {
249 Some(text) => Ok(Some(serde_json::from_str(&text)?)),
250 None => Ok(None),
251 }
252 }
253 pub async fn subscribe(&mut self, channels: &[Channel<'_>]) -> crate::error::Result<()> {
274 self.send_msg("SUBSCRIBE", channels).await
275 }
276 pub async fn unsubscribe(&mut self, channels: &[Channel<'_>]) -> crate::error::Result<()> {
296 self.send_msg("UNSUBSCRIBE", channels).await
297 }
298 pub fn get_ref(&self) -> &InnerStream {
300 &self.inner
301 }
302 pub fn get_mut(&mut self) -> &mut InnerStream {
304 &mut self.inner
305 }
306 pub async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> crate::error::Result<()> {
308 self.inner.0.close(msg).await?;
309 Ok(())
310 }
311
312 async fn send_msg(
313 &mut self,
314 method: &str,
315 channels: &[Channel<'_>],
316 ) -> crate::error::Result<()> {
317 let params: Vec<_> = channels
318 .iter()
319 .map(|channel| Value::String(channel.to_string()))
320 .collect();
321
322 let message = SubscribeMessage {
323 method,
324 params: ¶ms,
325 id: self.id,
326 };
327 let message = serde_json::to_string(&message)?;
328 self.send(Message::Text(message)).await?;
329 self.id += 1;
330 Ok(())
331 }
332}
333
334impl Stream for WebSocketStream {
335 type Item = crate::error::Result<Message>;
336
337 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
338 match self.inner.0.try_poll_next_unpin(cx) {
339 Poll::Ready(Some(val)) => Poll::Ready(Some(Ok(val?))),
340 Poll::Ready(None) => Poll::Ready(None),
341 Poll::Pending => Poll::Pending,
342 }
343 }
344}
345
346impl Sink<Message> for WebSocketStream {
347 type Error = Error;
348
349 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
350 match self.inner.0.poll_ready_unpin(cx) {
351 Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)),
352 Poll::Ready(Err(val)) => Poll::Ready(Err(Error::new(Kind::Tungstenite, Some(val)))),
353 Poll::Pending => Poll::Pending,
354 }
355 }
356
357 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
358 match self.inner.0.start_send_unpin(item) {
359 Ok(val) => Ok(val),
360 Err(val) => Err(Error::new(Kind::Tungstenite, Some(val))),
361 }
362 }
363
364 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
365 match self.inner.0.poll_flush_unpin(cx) {
366 Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)),
367 Poll::Ready(Err(val)) => Poll::Ready(Err(Error::new(Kind::Tungstenite, Some(val)))),
368 Poll::Pending => Poll::Pending,
369 }
370 }
371 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
372 match self.inner.0.poll_close_unpin(cx) {
373 Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)),
374 Poll::Ready(Err(val)) => Poll::Ready(Err(Error::new(Kind::Tungstenite, Some(val)))),
375 Poll::Pending => Poll::Pending,
376 }
377 }
378}