1use crate::auth::{Authenticator, TqAuth};
6use crate::datamanager::{DataManager, DataManagerConfig};
7use crate::errors::{Result, TqError};
8use crate::quote::QuoteSubscription;
9use crate::series::SeriesAPI;
10use crate::trade_session::TradeSession;
11use crate::websocket::{TqQuoteWebsocket, WebSocketConfig};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
/// Tunable settings for a [`Client`].
///
/// Start from [`ClientConfig::default`] and override individual fields, or set
/// them through the corresponding [`ClientBuilder`] methods.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Log level string passed to the logger when the client is built.
    pub log_level: String,
    /// Default view width handed to the data manager's configuration.
    pub view_width: usize,
    /// Development-mode flag.
    /// NOTE(review): not consulted by the build code visible here — confirm its consumer.
    pub development: bool,
}

impl Default for ClientConfig {
    /// `"info"` logging, a view width of 10 000 and development mode off.
    fn default() -> Self {
        Self {
            log_level: String::from("info"),
            view_width: 10_000,
            development: false,
        }
    }
}

/// A boxed callback that edits a [`ClientConfig`] in place.
pub type ClientOption = Box<dyn Fn(&mut ClientConfig)>;
39
40pub struct ClientBuilder {
42 username: String,
43 password: String,
44 config: ClientConfig,
45 auth: Option<Arc<RwLock<dyn Authenticator>>>,
46}
47
48impl ClientBuilder {
49 pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
70 ClientBuilder {
71 username: username.into(),
72 password: password.into(),
73 config: ClientConfig::default(),
74 auth: None,
75 }
76 }
77
78 pub fn log_level(mut self, level: impl Into<String>) -> Self {
80 self.config.log_level = level.into();
81 self
82 }
83
84 pub fn view_width(mut self, width: usize) -> Self {
86 self.config.view_width = width;
87 self
88 }
89
90 pub fn development(mut self, dev: bool) -> Self {
92 self.config.development = dev;
93 self
94 }
95
96 pub fn config(mut self, config: ClientConfig) -> Self {
98 self.config = config;
99 self
100 }
101
102 pub fn auth<A: Authenticator + 'static>(mut self, auth: A) -> Self {
123 self.auth = Some(Arc::new(RwLock::new(auth)));
124 self
125 }
126
127 pub async fn build(self) -> Result<Client> {
133 crate::logger::init_logger(&self.config.log_level, true);
135
136 let auth: Arc<RwLock<dyn Authenticator>> = if let Some(custom_auth) = self.auth {
138 custom_auth
139 } else {
140 let mut auth = TqAuth::new(self.username.clone(), self.password.clone());
141 auth.login().await?;
142 Arc::new(RwLock::new(auth))
143 };
144
145 let dm_config = DataManagerConfig {
147 default_view_width: self.config.view_width,
148 enable_auto_cleanup: true,
149 };
150 let initial_data = HashMap::new();
151 let dm = Arc::new(DataManager::new(initial_data, dm_config));
152
153 Ok(Client {
154 _username: self.username,
155 _config: self.config,
156 auth,
157 dm,
158 quotes_ws: None,
159 series_api: None,
160 trade_sessions: Arc::new(RwLock::new(HashMap::new())),
161 })
162 }
163}
164
165pub struct Client {
167 _username: String,
168 _config: ClientConfig,
169 auth: Arc<RwLock<dyn Authenticator>>,
170 dm: Arc<DataManager>,
171 quotes_ws: Option<Arc<TqQuoteWebsocket>>,
172 series_api: Option<Arc<SeriesAPI>>,
173 trade_sessions: Arc<RwLock<HashMap<String, Arc<TradeSession>>>>,
174}
175
176impl Client {
177 pub async fn new(username: &str, password: &str, config: ClientConfig) -> Result<Self> {
193 ClientBuilder::new(username, password)
194 .config(config)
195 .build()
196 .await
197 }
198
199 pub fn builder(username: impl Into<String>, password: impl Into<String>) -> ClientBuilder {
216 ClientBuilder::new(username, password)
217 }
218
219 pub async fn init_market(&mut self) -> Result<()> {
221 let auth = self.auth.read().await;
222 let md_url = auth.get_md_url(false, false).await?;
223
224 let mut ws_config = WebSocketConfig::default();
225 ws_config.headers = auth.base_header();
226
227 let quotes_ws = Arc::new(TqQuoteWebsocket::new(
228 md_url,
229 Arc::clone(&self.dm),
230 ws_config,
231 ));
232
233 quotes_ws.init(false).await?;
234
235 self.quotes_ws = Some(Arc::clone("es_ws));
236
237 let series_api = Arc::new(SeriesAPI::new(
239 Arc::clone(&self.dm),
240 quotes_ws,
241 Arc::clone(&self.auth),
242 ));
243 self.series_api = Some(series_api);
244
245 Ok(())
246 }
247
248 pub async fn set_auth<A: Authenticator + 'static>(&mut self, auth: A) {
276 self.auth = Arc::new(RwLock::new(auth));
277 }
278
279 pub async fn get_auth(&self) -> tokio::sync::RwLockReadGuard<'_, dyn Authenticator> {
296 self.auth.read().await
297 }
298
299 pub fn series(&self) -> Result<Arc<SeriesAPI>> {
301 self.series_api
302 .clone()
303 .ok_or_else(|| TqError::InternalError("Series API 未初始化".to_string()))
304 }
305
306 pub async fn subscribe_quote(&self, symbols: &[&str]) -> Result<Arc<QuoteSubscription>> {
308 if self.quotes_ws.is_none() {
309 return Err(TqError::InternalError(
310 "行情 WebSocket 未初始化".to_string(),
311 ));
312 }
313 {
314 let auth = self.auth.read().await;
315 auth.has_md_grants(symbols)?
316 }
317 let symbol_list: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
318 let qs = Arc::new(QuoteSubscription::new(
319 Arc::clone(&self.dm),
320 self.quotes_ws.as_ref().unwrap().clone(),
321 symbol_list,
322 ));
323
324 qs.start().await?;
326
327 Ok(qs)
328 }
329
330 pub async fn create_trade_session(
364 &self,
365 broker: &str,
366 user_id: &str,
367 password: &str,
368 ) -> Result<Arc<TradeSession>> {
369 let auth = self.auth.read().await;
371 let broker_info = auth.get_td_url(broker, user_id).await?;
372
373 let mut ws_config = WebSocketConfig::default();
374 ws_config.headers = auth.base_header();
375 drop(auth);
376
377 let session = Arc::new(TradeSession::new(
379 broker.to_string(),
380 user_id.to_string(),
381 password.to_string(),
382 Arc::clone(&self.dm),
383 broker_info.url,
384 ws_config,
385 ));
386
387 let key = format!("{}:{}", broker, user_id);
389 let mut sessions = self.trade_sessions.write().await;
390 sessions.insert(key, Arc::clone(&session));
391
392 Ok(session)
393 }
394
395 pub async fn register_trade_session(&self, key: &str, session: Arc<TradeSession>) {
397 let mut sessions = self.trade_sessions.write().await;
398 sessions.insert(key.to_string(), session);
399 }
400
401 pub async fn get_trade_session(&self, key: &str) -> Option<Arc<TradeSession>> {
403 let sessions = self.trade_sessions.read().await;
404 sessions.get(key).cloned()
405 }
406
407 pub async fn close(&self) -> Result<()> {
409 if let Some(ws) = &self.quotes_ws {
410 ws.close().await?;
411 }
412
413 let sessions = self.trade_sessions.read().await;
414 for (_key, trader) in sessions.iter() {
415 trader.close().await?;
416 }
417
418 Ok(())
419 }
420}