xapi_binance/common/executor.rs

1use crate::{
2    common::{
3        endpoint::BnEndpoint,
4        payload::BnWsApiPayload,
5        ratelimiter::BnRatelimiter,
6        response::BnWsApiRespType,
7        signer::BnSigner,
8        ws::{
9            api::{BnWsApi, BnWsApiCall},
10            stream::{BnWsStream, BnWsStreamCall},
11        },
12    },
13    data::enums::ratelimit::BnRateLimitType,
14};
15use serde::{Serialize, de::DeserializeOwned};
16use std::{num::NonZeroU32, sync::Arc};
17use tokio::sync::{Mutex, OnceCell, mpsc, oneshot};
18use typed_builder::TypedBuilder;
19use ulid::Ulid;
20use xapi_shared::{
21    ratelimiter::SharedRatelimiterTrait, rest::SharedRestClientTrait, signer::SharedSignerTrait,
22    ws::error::SharedWsError,
23};
24
25#[derive(TypedBuilder)]
26pub struct BnExecutor {
27    endpoint: BnEndpoint,
28    #[builder(default = reqwest::Client::new())]
29    rest_client: reqwest::Client,
30    #[builder(default = None, setter(strip_option))]
31    signer: Option<BnSigner>,
32    #[builder(default = Arc::new(BnRatelimiter::default()))]
33    ratelimiter: Arc<BnRatelimiter>,
34    #[builder(default)]
35    ws_api: OnceCell<ezsockets::Client<BnWsApi>>,
36    #[builder(default)]
37    streams: Mutex<Vec<ezsockets::Client<BnWsStream>>>,
38}
39
40impl SharedRestClientTrait<BnRateLimitType> for BnExecutor {
41    fn get_client(&self) -> &reqwest::Client {
42        &self.rest_client
43    }
44
45    fn get_signer(&self) -> &dyn SharedSignerTrait {
46        if let Some(signer) = &self.signer {
47            signer
48        } else {
49            tracing::error!("signer is not set for BnExecutor");
50            panic!("signer is not set for BnExecutor");
51        }
52    }
53
54    fn get_ratelimiter(&self) -> Arc<dyn SharedRatelimiterTrait<BnRateLimitType> + Sync + Send> {
55        self.ratelimiter.clone()
56    }
57}
58
59impl BnExecutor {
60    pub fn get_endpoint(&self) -> &BnEndpoint {
61        &self.endpoint
62    }
63
64    pub async fn call_ws_api<ReqType: Serialize, ResType: DeserializeOwned>(
65        &self,
66        limits: &[(BnRateLimitType, NonZeroU32)],
67        signed: bool,
68        method: &str,
69        params: Option<ReqType>,
70    ) -> BnWsApiRespType<ResType> {
71        let params = params
72            .map(|p| {
73                serde_json::to_value(p)
74                    .inspect_err(|err| tracing::error!(?err, "failed to serialize ws api params"))
75                    .map_err(|err| SharedWsError::SerdeError(err.to_string()))
76            })
77            .transpose()?;
78
79        let params = match signed {
80            true => match &self.signer {
81                Some(signer) => Some(signer.sign_ws_payload(params)?),
82                None => {
83                    tracing::error!("signer is not set for BnExecutor");
84                    return Err(SharedWsError::AppError("signer is not set".to_string()));
85                }
86            },
87            false => params,
88        };
89
90        for (rate_limit_type, limit) in limits.iter() {
91            self.get_ratelimiter()
92                .limit_on(rate_limit_type, *limit)
93                .await?;
94        }
95
96        let api = self.get_ws_api().await;
97
98        let (tx, rx) = oneshot::channel();
99
100        api.call(BnWsApiCall::SendApi {
101            payload: BnWsApiPayload {
102                id: Ulid::new().to_string(),
103                method: method.to_string(),
104                params,
105            },
106            tx,
107        })
108        .inspect_err(|err| tracing::error!(?err, "failed to send ws api request"))
109        .map_err(|err| SharedWsError::ChannelClosedError(err.to_string()))?;
110
111        let resp = rx
112            .await
113            .inspect_err(|err| {
114                tracing::error!(?err, "failed to receive ws api response");
115            })
116            .map_err(|err| SharedWsError::ChannelClosedError(err.to_string()))??;
117
118        serde_json::from_value(resp.result)
119            .inspect_err(|err| tracing::error!(?err, "failed to parse ws api response result"))
120            .map_err(|err| SharedWsError::SerdeError(err.to_string()))
121    }
122
123    pub async fn subscribe_stream<T: DeserializeOwned + Send + 'static>(
124        &self,
125        stream: String,
126    ) -> Result<mpsc::Receiver<Result<T, SharedWsError>>, SharedWsError> {
127        let ws_stream_base_url = self.endpoint.get_ws_stream_base_url().clone();
128
129        let client = BnWsStream::connect(ezsockets::ClientConfig::new(ws_stream_base_url)).await;
130
131        let (raw_tx, mut raw_rx) = mpsc::channel(128);
132
133        let (oneshot_tx, oneshot_rx) = oneshot::channel();
134
135        let message = BnWsStreamCall::SubscribeStream {
136            streams: vec![(stream, raw_tx)],
137            tx: oneshot_tx,
138        };
139
140        client
141            .call(message)
142            .inspect_err(|err| tracing::error!(?err, "failed to subscribe to stream"))
143            .map_err(|err| SharedWsError::AppError(err.to_string()))?;
144
145        self.streams.lock().await.push(client);
146
147        oneshot_rx
148            .await
149            .map_err(|err| SharedWsError::ChannelClosedError(err.to_string()))??;
150
151        let (tx, rx) = mpsc::channel(128);
152
153        tokio::spawn(async move {
154            while let Some(result) = raw_rx.recv().await {
155                let msg = match result {
156                    Ok(resp) => match serde_json::from_value::<T>(resp.data) {
157                        Ok(data) => Ok(data),
158                        Err(err) => {
159                            tracing::error!(?err, "failed to parse message");
160                            Err(SharedWsError::SerdeError(err.to_string()))
161                        }
162                    },
163                    Err(err) => Err(err),
164                };
165
166                if let Err(err) = tx.send(msg).await {
167                    tracing::error!(?err, "failed to send message");
168                }
169            }
170        });
171
172        Ok(rx)
173    }
174
175    async fn get_ws_api(&self) -> &ezsockets::Client<BnWsApi> {
176        self.ws_api
177            .get_or_init(async || {
178                let ws_api_base_url = self.endpoint.get_ws_api_base_url().clone();
179
180                BnWsApi::connect(ezsockets::ClientConfig::new(ws_api_base_url)).await
181            })
182            .await
183    }
184}