tokio_binance/
ws_stream.rs

1use async_tungstenite::{
2    stream::Stream as StreamSwitcher,
3    tokio::{connect_async, TokioAdapter},
4    tungstenite::{
5        handshake::client::Response,
6        protocol::{frame::coding::CloseCode, CloseFrame},
7        Message,
8    },
9    WebSocketStream as WsStream,
10};
11use core::pin::Pin;
12use futures::{
13    sink::Sink,
14    stream::TryStreamExt,
15    task::{Context, Poll},
16    SinkExt, Stream,
17};
18use tokio::net::TcpStream;
19use tokio_native_tls::TlsStream;
20
21use crate::error::{Error, Kind, WsCloseError};
22use crate::param::Interval;
23use serde::de::DeserializeOwned;
24use serde::Serialize;
25use serde_json::Value;
26use std::fmt;
27
/// Base websocket endpoint for Binance.US: `wss://stream.binance.us:9443`.
///
/// Pass it to [`WebSocketStream::connect`], which appends `/ws/<channel>` to it.
pub const BINANCE_US_WSS_URL: &str = "wss://stream.binance.us:9443";
30
/// A Binance websocket channel (aka stream).
///
/// Its `Display` form is the stream name used both in the connect URL
/// (`<base>/ws/<name>`) and in `SUBSCRIBE`/`UNSUBSCRIBE` requests. Symbols are
/// lower-cased when rendered, so `"BNBUSDT"` and `"bnbusdt"` are equivalent.
#[derive(Copy, Clone)]
pub enum Channel<'c> {
    /// `<symbol>@aggTrade`
    AggTrade(&'c str),
    /// `<symbol>@depth@<speed>`, e.g. `bnbusdt@depth@100ms`.
    Depth(&'c str, Speed),
    /// `<symbol>@trade`
    Trade(&'c str),
    /// `<symbol>@kline_<interval>`, where `<interval>` is the serde name of the `Interval`.
    Kline(&'c str, Interval),
    /// `<symbol>@miniTicker`
    MiniTicker(&'c str),
    /// `!miniTicker@arr` (all symbols).
    AllMiniTickers,
    /// `<symbol>@ticker`
    Ticker(&'c str),
    /// `!ticker@arr` (all symbols).
    AllTickers,
    /// `<symbol>@bookTicker`
    BookTicker(&'c str),
    /// `!bookTicker` (all symbols).
    AllBookTickers,
    /// `<symbol>@depth<level>@<speed>`, e.g. `bnbusdt@depth5@100ms`.
    PartialDepth(&'c str, Level, Speed),
    /// The only channel that takes a listen-key instead of a symbol.
    /// The key is rendered verbatim (not lower-cased).
    UserData(&'c str),
}
47
48impl<'c> fmt::Display for Channel<'c> {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        match self {
51            Self::AggTrade(symbol) => write!(f, "{}", symbol.to_lowercase() + "@aggTrade"),
52            Self::Trade(symbol) => write!(f, "{}", symbol.to_lowercase() + "@trade"),
53            Self::Kline(symbol, interval) => {
54                let interval = serde_json::to_value(interval).unwrap();
55                write!(
56                    f,
57                    "{}",
58                    symbol.to_lowercase() + "@kline_" + interval.as_str().unwrap()
59                )
60            }
61            Self::MiniTicker(symbol) => write!(f, "{}", symbol.to_lowercase() + "@miniTicker"),
62            Self::AllMiniTickers => write!(f, "!miniTicker@arr"),
63            Self::Ticker(symbol) => write!(f, "{}", symbol.to_lowercase() + "@ticker"),
64            Self::AllTickers => write!(f, "!ticker@arr"),
65            Self::BookTicker(symbol) => write!(f, "{}", symbol.to_lowercase() + "@bookTicker"),
66            Self::AllBookTickers => write!(f, "!bookTicker"),
67            Self::PartialDepth(symbol, level, speed) => {
68                let level = serde_json::to_value(level).unwrap();
69                let speed = serde_json::to_value(speed).unwrap();
70                write!(
71                    f,
72                    "{}",
73                    symbol.to_lowercase()
74                        + "@depth"
75                        + level.as_str().unwrap()
76                        + "@"
77                        + speed.as_str().unwrap()
78                )
79            }
80            Self::Depth(symbol, speed) => {
81                let speed = serde_json::to_value(speed).unwrap();
82                write!(
83                    f,
84                    "{}",
85                    symbol.to_lowercase() + "@depth@" + speed.as_str().unwrap()
86                )
87            }
88            Self::UserData(listen_key) => write!(f, "{}", listen_key),
89        }
90    }
91}
92
93impl<'a, 'c> PartialEq<&'a str> for Channel<'c> {
94    fn eq(&self, other: &&str) -> bool {
95        self.to_string() == *other
96    }
97}
98
99impl<'c> PartialEq<String> for Channel<'c> {
100    fn eq(&self, other: &String) -> bool {
101        self.to_string() == *other
102    }
103}
104
105impl<'c> PartialEq<Value> for Channel<'c> {
106    fn eq(&self, other: &Value) -> bool {
107        self.to_string() == *other
108    }
109}
110
111#[derive(Copy, Clone, Serialize)]
112pub enum Level {
113    #[serde(rename = "5")]
114    Five,
115    #[serde(rename = "10")]
116    Ten,
117    #[serde(rename = "20")]
118    Twenty,
119}
120
121#[derive(Copy, Clone, Serialize)]
122pub enum Speed {
123    #[serde(rename = "100ms")]
124    HundredMillis,
125    #[serde(rename = "1000ms")]
126    ThousandMillis,
127}
128
/// JSON request sent over the socket as a text frame:
/// `{"method": ..., "params": [...], "id": ...}` (fields serialize in declaration order).
#[derive(Serialize)]
struct SubscribeMessage<'a> {
    // Request method; this file sends "SET_PROPERTY", "SUBSCRIBE" and "UNSUBSCRIBE".
    method: &'a str,
    // Method arguments: stream names for (UN)SUBSCRIBE, or a property name and value.
    params: &'a [Value],
    // Request id, taken from `WebSocketStream::id` at send time.
    id: u64,
}
135
/// What `connect_async` yields: the websocket over either a plain or a TLS
/// TCP stream (via the `StreamSwitcher` enum), plus the handshake response.
type InnerStream = (
    WsStream<StreamSwitcher<TokioAdapter<TcpStream>, TokioAdapter<TlsStream<TcpStream>>>>,
    Response,
);
140
/// Websocket stream for the various binance channels aka streams.
///
/// Implements `Stream` for incoming `Message`s and `Sink<Message>` for
/// outgoing frames, on top of the raw async-tungstenite socket.
pub struct WebSocketStream {
    /// Socket plus handshake response, as returned by `connect_async`.
    inner: InnerStream,
    /// Id attached to the next JSON request; incremented after each successful send.
    id: u64,
}
146
147impl WebSocketStream {
148    /// Start websocket stream by connecting to a channel.
149    /// # Example
150    ///
151    /// ```no_run
152    /// use tokio_binance::{WebSocketStream, BINANCE_US_WSS_URL, Channel};
153    ///
154    /// #[tokio::main]
155    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
156    ///     let channel = Channel::Ticker("BNBUSDT");
157    ///     let mut stream = WebSocketStream::connect(channel, BINANCE_US_WSS_URL).await?;
158    ///     Ok(())
159    /// }
160    /// ```
161    pub async fn connect<U: Into<String>>(
162        channel: Channel<'_>,
163        url: U,
164    ) -> crate::error::Result<Self> {
165        let url = url.into() + "/ws/" + &channel.to_string();
166
167        let inner = connect_async(url).await?;
168        let mut stream = Self { inner, id: 0 };
169
170        let message = SubscribeMessage {
171            method: "SET_PROPERTY",
172            params: &["combined".into(), true.into()],
173            id: stream.id,
174        };
175        let message = serde_json::to_string(&message)?;
176        stream.send(Message::Text(message)).await?;
177        stream.id += 1;
178
179        Ok(stream)
180    }
181    /// Helper method for getting messages as text.
182    /// # Example
183    ///
184    /// ```no_run
185    /// # use tokio_binance::{WebSocketStream, BINANCE_US_WSS_URL, Channel};
186    /// # #[tokio::main]
187    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
188    /// # let channel = Channel::Ticker("BNBUSDT");
189    /// # let mut stream = WebSocketStream::connect(channel, BINANCE_US_WSS_URL).await?;
190    /// while let Some(text) = stream.text().await? {
191    ///     println!("{}", text);
192    /// }
193    /// # Ok(())
194    /// # }
195    /// ```
196    pub async fn text(&mut self) -> crate::error::Result<Option<String>> {
197        match self.try_next().await? {
198            Some(msg) => match msg {
199                Message::Text(text) => Ok(Some(text)),
200                Message::Ping(ref value) => {
201                    self.send(Message::Pong(value.clone())).await?;
202                    let ping = serde_json::json!({
203                        "ping": msg.into_text()?,
204                    });
205                    Ok(Some(serde_json::to_string(&ping)?))
206                }
207                Message::Pong(ref value) => {
208                    self.send(Message::Ping(value.clone())).await?;
209                    let pong = serde_json::json!({
210                        "pong": msg.into_text()?,
211                    });
212                    Ok(Some(serde_json::to_string(&pong)?))
213                }
214                Message::Binary(_) => Ok(Some(msg.into_text()?)),
215                Message::Close(Some(frame)) => {
216                    Err(WsCloseError::new(frame.code, frame.reason).into())
217                }
218                Message::Close(None) => Err(WsCloseError::new(
219                    CloseCode::Abnormal,
220                    "Close message with no frame received",
221                )
222                .into()),
223            },
224            None => Ok(None),
225        }
226    }
227    /// Helper method for getting messages as a serde deserializable.
228    /// # Example
229    ///
230    /// ```no_run
231    /// # use tokio_binance::{WebSocketStream, BINANCE_US_WSS_URL, Channel};
232    /// use serde_json::Value;
233    ///
234    /// # #[tokio::main]
235    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
236    /// # let channel = Channel::Ticker("BNBUSDT");
237    /// # let mut stream = WebSocketStream::connect(channel, BINANCE_US_WSS_URL).await?;
238    /// while let Some(value) = stream.json::<Value>().await? {
239    ///     // filter the messages before accessing a field.
240    ///     if channel == value["stream"] {
241    ///         println!("{}", serde_json::to_string_pretty(&value)?);
242    ///     }
243    /// }
244    /// # Ok(())
245    /// # }
246    /// ```
247    pub async fn json<J: DeserializeOwned>(&mut self) -> crate::error::Result<Option<J>> {
248        match self.text().await? {
249            Some(text) => Ok(Some(serde_json::from_str(&text)?)),
250            None => Ok(None),
251        }
252    }
253    /// Subscribe to one or more channels aka streams.
254    /// # Example
255    ///
256    /// ```no_run
257    /// # use tokio_binance::{WebSocketStream, BINANCE_US_WSS_URL};
258    /// use tokio_binance::{Channel, Interval};
259    ///
260    /// # #[tokio::main]
261    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
262    /// # let channel = Channel::Ticker("BNBUSDT");
263    /// # let mut stream = WebSocketStream::connect(channel, BINANCE_US_WSS_URL).await?;
264    /// stream.subscribe(&[
265    ///     Channel::AggTrade("BNBUSDT"),
266    ///     Channel::Ticker("BTCUSDT"),
267    ///     Channel::Kline("BNBUSDT", Interval::OneMinute)
268    ///     // and so on
269    /// ]).await?;
270    /// # Ok(())
271    /// # }
272    /// ```
273    pub async fn subscribe(&mut self, channels: &[Channel<'_>]) -> crate::error::Result<()> {
274        self.send_msg("SUBSCRIBE", channels).await
275    }
276    /// Unsubscribe from one or more channels aka streams.
277    /// # Example
278    ///
279    /// ```no_run
280    /// # use tokio_binance::{WebSocketStream, BINANCE_US_WSS_URL};
281    /// use tokio_binance::{Channel, Interval};
282    ///
283    /// # #[tokio::main]
284    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
285    /// # let channel = Channel::Ticker("BNBUSDT");
286    /// # let mut stream = WebSocketStream::connect(channel, BINANCE_US_WSS_URL).await?;
287    /// stream.unsubscribe(&[
288    ///     Channel::AggTrade("BNBUSDT"),
289    ///     Channel::Kline("BNBUSDT", Interval::OneMinute)
290    ///     // and so on
291    /// ]).await?;
292    /// # Ok(())
293    /// # }
294    /// ```
295    pub async fn unsubscribe(&mut self, channels: &[Channel<'_>]) -> crate::error::Result<()> {
296        self.send_msg("UNSUBSCRIBE", channels).await
297    }
298    /// Returns a shared reference to the inner stream.
299    pub fn get_ref(&self) -> &InnerStream {
300        &self.inner
301    }
302    /// Returns a mutable reference to the inner stream.
303    pub fn get_mut(&mut self) -> &mut InnerStream {
304        &mut self.inner
305    }
306    /// Close the underlying web socket
307    pub async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> crate::error::Result<()> {
308        self.inner.0.close(msg).await?;
309        Ok(())
310    }
311
312    async fn send_msg(
313        &mut self,
314        method: &str,
315        channels: &[Channel<'_>],
316    ) -> crate::error::Result<()> {
317        let params: Vec<_> = channels
318            .iter()
319            .map(|channel| Value::String(channel.to_string()))
320            .collect();
321
322        let message = SubscribeMessage {
323            method,
324            params: &params,
325            id: self.id,
326        };
327        let message = serde_json::to_string(&message)?;
328        self.send(Message::Text(message)).await?;
329        self.id += 1;
330        Ok(())
331    }
332}
333
334impl Stream for WebSocketStream {
335    type Item = crate::error::Result<Message>;
336
337    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
338        match self.inner.0.try_poll_next_unpin(cx) {
339            Poll::Ready(Some(val)) => Poll::Ready(Some(Ok(val?))),
340            Poll::Ready(None) => Poll::Ready(None),
341            Poll::Pending => Poll::Pending,
342        }
343    }
344}
345
346impl Sink<Message> for WebSocketStream {
347    type Error = Error;
348
349    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
350        match self.inner.0.poll_ready_unpin(cx) {
351            Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)),
352            Poll::Ready(Err(val)) => Poll::Ready(Err(Error::new(Kind::Tungstenite, Some(val)))),
353            Poll::Pending => Poll::Pending,
354        }
355    }
356
357    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
358        match self.inner.0.start_send_unpin(item) {
359            Ok(val) => Ok(val),
360            Err(val) => Err(Error::new(Kind::Tungstenite, Some(val))),
361        }
362    }
363
364    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
365        match self.inner.0.poll_flush_unpin(cx) {
366            Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)),
367            Poll::Ready(Err(val)) => Poll::Ready(Err(Error::new(Kind::Tungstenite, Some(val)))),
368            Poll::Pending => Poll::Pending,
369        }
370    }
371    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
372        match self.inner.0.poll_close_unpin(cx) {
373            Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)),
374            Poll::Ready(Err(val)) => Poll::Ready(Err(Error::new(Kind::Tungstenite, Some(val)))),
375            Poll::Pending => Poll::Pending,
376        }
377    }
378}