xapi_binance/common/ws/stream.rs

1use crate::common::{
2    payload::{BnWsStreamMethod, BnWsStreamPayload},
3    response::{BnWsStreamData, BnWsStreamResponse},
4};
5use ezsockets::{Bytes, Client, ClientConfig, ClientExt, Error, Utf8Bytes};
6use std::collections::HashMap;
7use tokio::sync::{mpsc, oneshot};
8use ulid::Ulid;
9use xapi_shared::ws::{api::SharedWsApiTrait, error::SharedWsError, stream::SharedWsStreamTrait};
10
11pub struct BnWsStream {
12    client: Client<Self>,
13    on_connect_tx: Option<oneshot::Sender<()>>,
14    oneshot_tx_map: HashMap<String, oneshot::Sender<Result<BnWsStreamResponse, SharedWsError>>>,
15    stream_tx_map: HashMap<String, mpsc::Sender<Result<BnWsStreamData, SharedWsError>>>,
16}
17
18pub enum BnWsStreamCall {
19    SubscribeStream {
20        streams: Vec<(String, mpsc::Sender<Result<BnWsStreamData, SharedWsError>>)>,
21        tx: oneshot::Sender<Result<BnWsStreamResponse, SharedWsError>>,
22    },
23}
24
25#[async_trait::async_trait]
26impl ClientExt for BnWsStream {
27    type Call = BnWsStreamCall;
28
29    async fn on_text(&mut self, text: Utf8Bytes) -> Result<(), Error> {
30        let msg = text.to_string();
31
32        if let Some(result) = self.recv_stream_resp(&msg).await {
33            return result.map_err(|err| err.into());
34        }
35
36        if let Some(result) = self.recv_oneshot_resp(&msg) {
37            return result.map_err(|err| err.into());
38        }
39
40        tracing::error!(?msg, "unhandled bn ws message");
41        Err(SharedWsError::AppError("unhandled bn ws message".to_string()).into())
42    }
43
44    async fn on_binary(&mut self, _bytes: Bytes) -> Result<(), Error> {
45        unimplemented!()
46    }
47
48    async fn on_call(&mut self, call: Self::Call) -> Result<(), Error> {
49        match call {
50            BnWsStreamCall::SubscribeStream { streams, tx } => {
51                self.subscribe_streams(streams, tx)?
52            }
53        }
54
55        Ok(())
56    }
57
58    async fn on_connect(&mut self) -> Result<(), Error> {
59        if let Some(tx) = self.on_connect_tx.take() {
60            tx.send(())
61                .inspect_err(|err| {
62                    tracing::error!(?err, "failed to send on_connect signal");
63                })
64                .map_err(|_| {
65                    SharedWsError::ChannelClosedError("first on connect channel closed".to_string())
66                })?;
67        }
68        Ok(())
69    }
70}
71
72impl SharedWsApiTrait<String, BnWsStreamPayload, BnWsStreamResponse> for BnWsStream {
73    fn get_client(&self) -> &Client<Self> {
74        &self.client
75    }
76
77    fn get_oneshot_tx_map(
78        &mut self,
79    ) -> &mut HashMap<String, oneshot::Sender<Result<BnWsStreamResponse, SharedWsError>>> {
80        &mut self.oneshot_tx_map
81    }
82}
83
84#[async_trait::async_trait]
85impl SharedWsStreamTrait<String, BnWsStreamData> for BnWsStream {
86    fn get_stream_tx_map(
87        &mut self,
88    ) -> &mut HashMap<String, mpsc::Sender<Result<BnWsStreamData, SharedWsError>>> {
89        &mut self.stream_tx_map
90    }
91}
92
93impl BnWsStream {
94    pub async fn connect(config: ClientConfig) -> Client<Self> {
95        let (on_connect_tx, on_connect_rx) = oneshot::channel();
96
97        let (client, future) = ezsockets::connect(
98            |client| Self {
99                client,
100                on_connect_tx: Some(on_connect_tx),
101                oneshot_tx_map: Default::default(),
102                stream_tx_map: Default::default(),
103            },
104            config,
105        )
106        .await;
107
108        tokio::spawn(async move {
109            future.await.inspect_err(|err| {
110                tracing::error!(?err, "bn ws client connection error");
111            })
112        });
113
114        _ = on_connect_rx.await;
115
116        client
117    }
118
119    fn subscribe_streams(
120        &mut self,
121        streams: Vec<(String, mpsc::Sender<Result<BnWsStreamData, SharedWsError>>)>,
122        tx: oneshot::Sender<Result<BnWsStreamResponse, SharedWsError>>,
123    ) -> Result<(), SharedWsError> {
124        if streams.is_empty() {
125            tracing::warn!("no streams to subscribe");
126            return Ok(());
127        }
128
129        for (stream, _) in &streams {
130            if self.stream_tx_map.contains_key(stream) {
131                tracing::error!(stream, "duplicated stream in ws subscribe stream request");
132                return Err(SharedWsError::InvalidIdError(stream.clone()));
133            }
134        }
135
136        let id = Ulid::new().to_string();
137
138        let payload = BnWsStreamPayload {
139            id,
140            method: BnWsStreamMethod::Subscribe,
141            params: Some(serde_json::Value::Array(
142                streams
143                    .iter()
144                    .map(|(stream, _)| serde_json::Value::String(stream.clone()))
145                    .collect::<Vec<_>>(),
146            )),
147        };
148
149        for (stream, tx) in streams {
150            self.stream_tx_map.insert(stream, tx);
151        }
152
153        self.send_oneshot(payload, tx)
154    }
155}