// xapi_binance/common/ws/stream.rs
use crate::common::{
    payload::{BnWsStreamMethod, BnWsStreamPayload},
    response::{BnWsStreamData, BnWsStreamResponse},
};
use ezsockets::{Bytes, Client, ClientConfig, ClientExt, Error, Utf8Bytes};
use std::collections::{HashMap, HashSet};
use tokio::sync::{mpsc, oneshot};
use ulid::Ulid;
use xapi_shared::ws::{api::SharedWsApiTrait, error::SharedWsError, stream::SharedWsStreamTrait};

11pub struct BnWsStream {
12 client: Client<Self>,
13 on_connect_tx: Option<oneshot::Sender<()>>,
14 oneshot_tx_map: HashMap<String, oneshot::Sender<Result<BnWsStreamResponse, SharedWsError>>>,
15 stream_tx_map: HashMap<String, mpsc::Sender<Result<BnWsStreamData, SharedWsError>>>,
16}
17
18pub enum BnWsStreamCall {
19 SubscribeStream {
20 streams: Vec<(String, mpsc::Sender<Result<BnWsStreamData, SharedWsError>>)>,
21 tx: oneshot::Sender<Result<BnWsStreamResponse, SharedWsError>>,
22 },
23}
24
25#[async_trait::async_trait]
26impl ClientExt for BnWsStream {
27 type Call = BnWsStreamCall;
28
29 async fn on_text(&mut self, text: Utf8Bytes) -> Result<(), Error> {
30 let msg = text.to_string();
31
32 if let Some(result) = self.recv_stream_resp(&msg).await {
33 return result.map_err(|err| err.into());
34 }
35
36 if let Some(result) = self.recv_oneshot_resp(&msg) {
37 return result.map_err(|err| err.into());
38 }
39
40 tracing::error!(?msg, "unhandled bn ws message");
41 Err(SharedWsError::AppError("unhandled bn ws message".to_string()).into())
42 }
43
44 async fn on_binary(&mut self, _bytes: Bytes) -> Result<(), Error> {
45 unimplemented!()
46 }
47
48 async fn on_call(&mut self, call: Self::Call) -> Result<(), Error> {
49 match call {
50 BnWsStreamCall::SubscribeStream { streams, tx } => {
51 self.subscribe_streams(streams, tx)?
52 }
53 }
54
55 Ok(())
56 }
57
58 async fn on_connect(&mut self) -> Result<(), Error> {
59 if let Some(tx) = self.on_connect_tx.take() {
60 tx.send(())
61 .inspect_err(|err| {
62 tracing::error!(?err, "failed to send on_connect signal");
63 })
64 .map_err(|_| {
65 SharedWsError::ChannelClosedError("first on connect channel closed".to_string())
66 })?;
67 }
68 Ok(())
69 }
70}
71
72impl SharedWsApiTrait<String, BnWsStreamPayload, BnWsStreamResponse> for BnWsStream {
73 fn get_client(&self) -> &Client<Self> {
74 &self.client
75 }
76
77 fn get_oneshot_tx_map(
78 &mut self,
79 ) -> &mut HashMap<String, oneshot::Sender<Result<BnWsStreamResponse, SharedWsError>>> {
80 &mut self.oneshot_tx_map
81 }
82}
83
84#[async_trait::async_trait]
85impl SharedWsStreamTrait<String, BnWsStreamData> for BnWsStream {
86 fn get_stream_tx_map(
87 &mut self,
88 ) -> &mut HashMap<String, mpsc::Sender<Result<BnWsStreamData, SharedWsError>>> {
89 &mut self.stream_tx_map
90 }
91}
92
93impl BnWsStream {
94 pub async fn connect(config: ClientConfig) -> Client<Self> {
95 let (on_connect_tx, on_connect_rx) = oneshot::channel();
96
97 let (client, future) = ezsockets::connect(
98 |client| Self {
99 client,
100 on_connect_tx: Some(on_connect_tx),
101 oneshot_tx_map: Default::default(),
102 stream_tx_map: Default::default(),
103 },
104 config,
105 )
106 .await;
107
108 tokio::spawn(async move {
109 future.await.inspect_err(|err| {
110 tracing::error!(?err, "bn ws client connection error");
111 })
112 });
113
114 _ = on_connect_rx.await;
115
116 client
117 }
118
119 fn subscribe_streams(
120 &mut self,
121 streams: Vec<(String, mpsc::Sender<Result<BnWsStreamData, SharedWsError>>)>,
122 tx: oneshot::Sender<Result<BnWsStreamResponse, SharedWsError>>,
123 ) -> Result<(), SharedWsError> {
124 if streams.is_empty() {
125 tracing::warn!("no streams to subscribe");
126 return Ok(());
127 }
128
129 for (stream, _) in &streams {
130 if self.stream_tx_map.contains_key(stream) {
131 tracing::error!(stream, "duplicated stream in ws subscribe stream request");
132 return Err(SharedWsError::InvalidIdError(stream.clone()));
133 }
134 }
135
136 let id = Ulid::new().to_string();
137
138 let payload = BnWsStreamPayload {
139 id,
140 method: BnWsStreamMethod::Subscribe,
141 params: Some(serde_json::Value::Array(
142 streams
143 .iter()
144 .map(|(stream, _)| serde_json::Value::String(stream.clone()))
145 .collect::<Vec<_>>(),
146 )),
147 };
148
149 for (stream, tx) in streams {
150 self.stream_tx_map.insert(stream, tx);
151 }
152
153 self.send_oneshot(payload, tx)
154 }
155}