1use crate::datamanager::DataManager;
6use crate::errors::{Result, TqError};
7use crate::types::{
8 Account, InsertOrderRequest, Notification, Order, Position, PositionUpdate, Trade,
9};
10use crate::websocket::{TqTradeWebsocket, WebSocketConfig};
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14use async_channel::{Sender, Receiver, unbounded};
15use tokio::sync::RwLock;
16use tracing::{debug, error, info};
17
18#[allow(unused)]
20pub struct TradeSession {
21 broker: String,
22 user_id: String,
23 password: String,
24 dm: Arc<DataManager>,
25 ws: Arc<TqTradeWebsocket>,
26
27 account_tx: Sender<Account>,
29 account_rx: Receiver<Account>,
30 position_tx: Sender<PositionUpdate>,
31 position_rx: Receiver<PositionUpdate>,
32 order_tx: Sender<Order>,
33 order_rx: Receiver<Order>,
34 trade_tx: Sender<Trade>,
35 trade_rx: Receiver<Trade>,
36 notification_tx: Sender<Notification>,
37 notification_rx: Receiver<Notification>,
38
39 on_account: Arc<RwLock<Option<Arc<dyn Fn(Account) + Send + Sync>>>>,
41 on_position: Arc<RwLock<Option<Arc<dyn Fn(String, Position) + Send + Sync>>>>,
42 on_order: Arc<RwLock<Option<Arc<dyn Fn(Order) + Send + Sync>>>>,
43 on_trade: Arc<RwLock<Option<Arc<dyn Fn(Trade) + Send + Sync>>>>,
44 on_notification: Arc<RwLock<Option<Arc<dyn Fn(Notification) + Send + Sync>>>>,
45 on_error: Arc<RwLock<Option<Arc<dyn Fn(String) + Send + Sync>>>>,
46
47 logged_in: Arc<AtomicBool>,
49 running: Arc<AtomicBool>,
50}
51
52impl TradeSession {
53 pub fn new(
55 broker: String,
56 user_id: String,
57 password: String,
58 dm: Arc<DataManager>,
59 ws_url: String,
60 ws_config: WebSocketConfig,
61 ) -> Self {
62 let (account_tx, account_rx) = unbounded();
64 let (position_tx, position_rx) = unbounded();
65 let (order_tx, order_rx) = unbounded();
66 let (trade_tx, trade_rx) = unbounded();
67 let (notification_tx, notification_rx) = unbounded();
68
69 let ws = Arc::new(TqTradeWebsocket::new(ws_url, Arc::clone(&dm), ws_config));
70
71 let noti_tx = notification_tx.clone();
72 ws.on_notify(move |noti| {
73 let noti2 = noti.clone();
74 let noti_tx2 = noti_tx.clone();
75 tokio::spawn(async move {
76 match noti_tx2.send(noti2).await {
77 Ok(_) => {
78 debug!("通知发送成功: {:?}", noti);
79 }
80 Err(e) => {
81 error!("通知发送失败: {:?}, error={}", noti, e);
82 }
83 }
84 });
85 });
86
87 TradeSession {
88 broker,
89 user_id,
90 password,
91 dm,
92 ws,
93 account_tx,
94 account_rx,
95 position_tx,
96 position_rx,
97 order_tx,
98 order_rx,
99 trade_tx,
100 trade_rx,
101 notification_tx,
102 notification_rx,
103 on_account: Arc::new(RwLock::new(None)),
104 on_position: Arc::new(RwLock::new(None)),
105 on_order: Arc::new(RwLock::new(None)),
106 on_trade: Arc::new(RwLock::new(None)),
107 on_notification: Arc::new(RwLock::new(None)),
108 on_error: Arc::new(RwLock::new(None)),
109 logged_in: Arc::new(AtomicBool::new(false)),
110 running: Arc::new(AtomicBool::new(false)),
111 }
112 }
113
114 async fn send_login(&self) -> Result<()> {
116 let login_req = serde_json::json!({
117 "aid": "req_login",
118 "bid": self.broker,
119 "user_name": self.user_id,
120 "password": self.password
121 });
122
123 info!(
124 "发送交易登录请求: broker={}, user_id={}",
125 self.broker, self.user_id
126 );
127 self.ws.send(&login_req).await?;
128 Ok(())
129 }
130
131 async fn start_watching(&self) {
133 let dm_clone = Arc::clone(&self.dm);
134 let user_id = self.user_id.clone();
135 let logged_in = Arc::clone(&self.logged_in);
136 let running = Arc::clone(&self.running);
137 let account_tx = self.account_tx.clone();
138 let position_tx = self.position_tx.clone();
139 let order_tx = self.order_tx.clone();
140 let trade_tx = self.trade_tx.clone();
141 let on_account = Arc::clone(&self.on_account);
142 let on_position = Arc::clone(&self.on_position);
143 let on_order = Arc::clone(&self.on_order);
144 let on_trade = Arc::clone(&self.on_trade);
145
146 info!("TradeSession 开始监听数据更新");
147
148 let dm_for_callback = Arc::clone(&dm_clone);
150
151 let processing_lock = Arc::new(tokio::sync::Mutex::new(()));
154
155 dm_clone.on_data(move || {
156 let dm = Arc::clone(&dm_for_callback);
157 let user_id = user_id.clone();
158 let logged_in = Arc::clone(&logged_in);
159 let running = Arc::clone(&running);
160 let account_tx = account_tx.clone();
161 let position_tx = position_tx.clone();
162 let order_tx = order_tx.clone();
163 let trade_tx = trade_tx.clone();
164 let on_account = Arc::clone(&on_account);
165 let on_position = Arc::clone(&on_position);
166 let on_order = Arc::clone(&on_order);
167 let on_trade = Arc::clone(&on_trade);
168 let lock = Arc::clone(&processing_lock);
169
170 tokio::spawn(async move {
171 let _guard = lock.lock().await;
173
174 if !running.load(Ordering::SeqCst) {
175 return;
176 }
177
178 if let Some(session_data) = dm.get_by_path(&["trade", &user_id, "session"]) {
180 if let serde_json::Value::Object(session_map) = session_data {
181 if let Some(trading_day) = session_map.get("trading_day") {
182 if !trading_day.is_null() {
183 if logged_in
186 .compare_exchange(
187 false,
188 true,
189 Ordering::SeqCst,
190 Ordering::SeqCst,
191 )
192 .is_ok()
193 {
194 info!("交易会话已登录: user_id={}", user_id);
195 }
196 }
197 }
198 }
199 }
200
201 if dm.is_changing(&["trade", &user_id, "accounts", "CNY"]) {
203 match dm.get_account_data(&user_id, "CNY") {
204 Ok(account) => {
205 debug!("账户更新: balance={}", account.balance);
206 let _ = account_tx.send(account.clone()).await;
207
208 if let Some(callback) = on_account.read().await.as_ref() {
209 let cb = Arc::clone(callback);
210 tokio::spawn(async move {
211 cb(account);
212 });
213 }
214 }
215 Err(e) => {
216 error!("获取账户数据失败(这不应该发生): {}", e);
218 }
219 }
220 }
221
222 if dm.is_changing(&["trade", &user_id, "positions"]) {
224 Self::process_position_update(&dm, &user_id, &position_tx, &on_position).await;
225 }
226
227 if dm.is_changing(&["trade", &user_id, "orders"]) {
229 Self::process_order_update(&dm, &user_id, &order_tx, &on_order).await;
230 }
231
232 if dm.is_changing(&["trade", &user_id, "trades"]) {
234 Self::process_trade_update(&dm, &user_id, &trade_tx, &on_trade).await;
235 }
236 });
237 });
238 }
239
240 async fn process_position_update(
242 dm: &Arc<DataManager>,
243 user_id: &str,
244 position_tx: &Sender<PositionUpdate>,
245 on_position: &Arc<RwLock<Option<Arc<dyn Fn(String, Position) + Send + Sync>>>>,
246 ) {
247 if let Some(positions_data) = dm.get_by_path(&["trade", user_id, "positions"]) {
248 if let serde_json::Value::Object(positions_map) = positions_data {
249 for (symbol, _) in positions_map.iter() {
250 if symbol.starts_with('_') {
252 continue;
253 }
254
255 if dm.is_changing(&["trade", user_id, "positions", symbol]) {
257 match dm.get_position_data(user_id, symbol) {
258 Ok(position) => {
259 debug!("持仓更新: symbol={}", symbol);
260
261 let update = PositionUpdate {
262 symbol: symbol.clone(),
263 position: position.clone(),
264 };
265
266 let _ = position_tx.send(update).await;
268
269 if let Some(callback) = on_position.read().await.as_ref() {
271 let cb = Arc::clone(callback);
272 let symbol = symbol.clone();
273 tokio::spawn(async move {
274 cb(symbol, position);
275 });
276 }
277 }
278 Err(e) => {
279 error!("获取持仓数据失败: symbol={}, error={}", symbol, e);
280 }
281 }
282 }
283 }
284 }
285 }
286 }
287
288 async fn process_order_update(
290 dm: &Arc<DataManager>,
291 user_id: &str,
292 order_tx: &Sender<Order>,
293 on_order: &Arc<RwLock<Option<Arc<dyn Fn(Order) + Send + Sync>>>>,
294 ) {
295 if let Some(orders_data) = dm.get_by_path(&["trade", user_id, "orders"]) {
296 if let serde_json::Value::Object(orders_map) = orders_data {
297 for (order_id, order_data) in orders_map.iter() {
298 if order_id.starts_with('_') {
300 continue;
301 }
302
303 if dm.is_changing(&["trade", user_id, "orders", order_id]) {
305 if let Ok(order) = serde_json::from_value::<Order>(order_data.clone()) {
306 debug!("订单更新: order_id={}", order_id);
307
308 let _ = order_tx.send(order.clone()).await;
310
311 if let Some(callback) = on_order.read().await.as_ref() {
313 let cb = Arc::clone(callback);
314 tokio::spawn(async move {
315 cb(order);
316 });
317 }
318 }
319 }
320 }
321 }
322 }
323 }
324
325 async fn process_trade_update(
327 dm: &Arc<DataManager>,
328 user_id: &str,
329 trade_tx: &Sender<Trade>,
330 on_trade: &Arc<RwLock<Option<Arc<dyn Fn(Trade) + Send + Sync>>>>,
331 ) {
332 if let Some(trades_data) = dm.get_by_path(&["trade", user_id, "trades"]) {
333 if let serde_json::Value::Object(trades_map) = trades_data {
334 for (trade_id, trade_data) in trades_map.iter() {
335 if trade_id.starts_with('_') {
337 continue;
338 }
339
340 if dm.is_changing(&["trade", user_id, "trades", trade_id]) {
342 if let Ok(trade) = serde_json::from_value::<Trade>(trade_data.clone()) {
343 debug!("成交更新: trade_id={}", trade_id);
344
345 let _ = trade_tx.send(trade.clone()).await;
347
348 if let Some(callback) = on_trade.read().await.as_ref() {
350 let cb = Arc::clone(callback);
351 tokio::spawn(async move {
352 cb(trade);
353 });
354 }
355 }
356 }
357 }
358 }
359 }
360 }
361
362 pub async fn on_account<F>(&self, handler: F)
364 where
365 F: Fn(Account) + Send + Sync + 'static,
366 {
367 let mut guard = self.on_account.write().await;
368 *guard = Some(Arc::new(handler));
369 }
370
371 pub async fn on_position<F>(&self, handler: F)
373 where
374 F: Fn(String, Position) + Send + Sync + 'static,
375 {
376 let mut guard = self.on_position.write().await;
377 *guard = Some(Arc::new(handler));
378 }
379
380 pub async fn on_order<F>(&self, handler: F)
382 where
383 F: Fn(Order) + Send + Sync + 'static,
384 {
385 let mut guard = self.on_order.write().await;
386 *guard = Some(Arc::new(handler));
387 }
388
389 pub async fn on_trade<F>(&self, handler: F)
391 where
392 F: Fn(Trade) + Send + Sync + 'static,
393 {
394 let mut guard = self.on_trade.write().await;
395 *guard = Some(Arc::new(handler));
396 }
397
398 pub async fn on_notification<F>(&self, handler: F)
400 where
401 F: Fn(Notification) + Send + Sync + 'static,
402 {
403 let mut guard = self.on_notification.write().await;
404 *guard = Some(Arc::new(handler));
405 }
406
407 pub async fn on_error<F>(&self, handler: F)
409 where
410 F: Fn(String) + Send + Sync + 'static,
411 {
412 let mut guard = self.on_error.write().await;
413 *guard = Some(Arc::new(handler));
414 }
415
416 pub fn account_channel(&self) -> Receiver<Account> {
418 self.account_rx.clone()
419 }
420
421 pub fn position_channel(&self) -> Receiver<PositionUpdate> {
423 self.position_rx.clone()
424 }
425
426 pub fn order_channel(&self) -> Receiver<Order> {
428 self.order_rx.clone()
429 }
430
431 pub fn trade_channel(&self) -> Receiver<Trade> {
433 self.trade_rx.clone()
434 }
435
436 pub fn notification_channel(&self) -> Receiver<Notification> {
438 self.notification_rx.clone()
439 }
440}
441
442impl TradeSession {
443 pub async fn insert_order(&self, req: &InsertOrderRequest) -> Result<Order> {
445 if !self.is_ready() {
446 return Err(TqError::InternalError("交易会话未就绪".to_string()));
447 }
448
449 let exchange_id = req.get_exchange_id();
450 let instrument_id = req.get_instrument_id();
451
452 use uuid::Uuid;
454 let order_id = format!(
455 "TQRS_{}",
456 Uuid::new_v4().simple().to_string()[..8].to_uppercase()
457 );
458
459 let time_condition = if req.price_type == "ANY" {
461 "IOC"
462 } else {
463 "GFD"
464 };
465
466 let order_req = serde_json::json!({
468 "aid": "insert_order",
469 "user_id": self.user_id,
470 "order_id": order_id,
471 "exchange_id": exchange_id,
472 "instrument_id": instrument_id,
473 "direction": req.direction,
474 "offset": req.offset,
475 "volume": req.volume,
476 "price_type": req.price_type,
477 "limit_price": req.limit_price,
478 "volume_condition": "ANY",
479 "time_condition": time_condition,
480 });
481
482 info!("发送下单请求: order_id={}, symbol={}", order_id, req.symbol);
483 self.ws.send(&order_req).await?;
484
485 let order_init = serde_json::json!({
487 "user_id": self.user_id,
488 "order_id": order_id,
489 "exchange_id": exchange_id,
490 "instrument_id": instrument_id,
491 "direction": req.direction,
492 "offset": req.offset,
493 "volume_orign": req.volume,
494 "volume_left": req.volume,
495 "price_type": req.price_type,
496 "limit_price": req.limit_price,
497 "status": "ALIVE",
498 });
499
500 self.dm.merge_data(
501 serde_json::json!({
502 "trade": {
503 self.user_id.clone(): {
504 "orders": {
505 order_id.clone(): order_init
506 }
507 }
508 }
509 }),
510 false,
511 false,
512 );
513
514 Ok(Order {
516 order_id: order_id.clone(),
517 exchange_id: exchange_id.to_string(),
518 instrument_id: instrument_id.to_string(),
519 direction: req.direction.clone(),
520 offset: req.offset.clone(),
521 volume_orign: req.volume,
522 volume_left: req.volume,
523 limit_price: req.limit_price,
524 price_type: req.price_type.clone(),
525 volume_condition: "ANY".to_string(),
526 time_condition: time_condition.to_string(),
527 insert_date_time: 0,
528 status: "ALIVE".to_string(),
529 frozen_margin: 0f64,
530 epoch: None,
531 seqno: 0,
532 user_id: self.user_id.clone(),
533 exchange_order_id: String::default(),
534 last_msg: String::default(),
535 })
536 }
537
538 pub async fn cancel_order(&self, order_id: &str) -> Result<()> {
540 if !self.is_ready() {
541 return Err(TqError::InternalError("交易会话未就绪".to_string()));
542 }
543
544 let cancel_req = serde_json::json!({
546 "aid": "cancel_order",
547 "user_id": self.user_id,
548 "order_id": order_id
549 });
550
551 info!("发送撤单请求: order_id={}", order_id);
552 self.ws.send(&cancel_req).await?;
553 Ok(())
554 }
555
556 pub async fn get_account(&self) -> Result<Account> {
558 self.dm.get_account_data(&self.user_id, "CNY")
559 }
560
561 pub async fn get_position(&self, symbol: &str) -> Result<Position> {
563 self.dm.get_position_data(&self.user_id, symbol)
564 }
565
566 pub async fn get_positions(&self) -> Result<HashMap<String, Position>> {
568 let data = self.dm.get_by_path(&["trade", &self.user_id, "positions"]);
569 if data.is_none() {
570 return Ok(HashMap::new());
571 }
572
573 let positions_map = match data.unwrap() {
574 serde_json::Value::Object(map) => map,
575 _ => return Ok(HashMap::new()),
576 };
577
578 let mut positions = HashMap::new();
579 for (symbol, pos_data) in positions_map.iter() {
580 if symbol.starts_with('_') {
582 continue;
583 }
584
585 if let Ok(position) = serde_json::from_value::<Position>(pos_data.clone()) {
586 positions.insert(symbol.clone(), position);
587 }
588 }
589
590 Ok(positions)
591 }
592
593 pub async fn get_orders(&self) -> Result<HashMap<String, Order>> {
595 let data = self.dm.get_by_path(&["trade", &self.user_id, "orders"]);
596 if data.is_none() {
597 return Ok(HashMap::new());
598 }
599
600 let orders_map = match data.unwrap() {
601 serde_json::Value::Object(map) => map,
602 _ => return Ok(HashMap::new()),
603 };
604
605 let mut orders = HashMap::new();
606 for (order_id, order_data) in orders_map.iter() {
607 if order_id.starts_with('_') {
609 continue;
610 }
611
612 if let Ok(order) = serde_json::from_value::<Order>(order_data.clone()) {
613 orders.insert(order_id.clone(), order);
614 }
615 }
616
617 Ok(orders)
618 }
619
620 pub async fn get_trades(&self) -> Result<HashMap<String, Trade>> {
622 let data = self.dm.get_by_path(&["trade", &self.user_id, "trades"]);
623 if data.is_none() {
624 return Ok(HashMap::new());
625 }
626
627 let trades_map = match data.unwrap() {
628 serde_json::Value::Object(map) => map,
629 _ => return Ok(HashMap::new()),
630 };
631
632 let mut trades = HashMap::new();
633 for (trade_id, trade_data) in trades_map.iter() {
634 if trade_id.starts_with('_') {
636 continue;
637 }
638
639 if let Ok(trade) = serde_json::from_value::<Trade>(trade_data.clone()) {
640 trades.insert(trade_id.clone(), trade);
641 }
642 }
643
644 Ok(trades)
645 }
646
647 pub async fn connect(&self) -> Result<()> {
649 if self
651 .running
652 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
653 .is_err()
654 {
655 return Ok(()); }
657
658 info!(
659 "连接交易服务器: broker={}, user_id={}",
660 self.broker, self.user_id
661 );
662
663 self.ws.init(false).await?;
665
666 self.start_watching().await;
668
669 self.send_login().await?;
671
672 Ok(())
673 }
674
675 pub fn is_ready(&self) -> bool {
677 self.logged_in.load(Ordering::SeqCst) && self.ws.is_ready()
680 }
681
682 pub async fn close(&self) -> Result<()> {
684 if self
686 .running
687 .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
688 .is_err()
689 {
690 return Ok(()); }
692
693 info!("关闭交易会话");
694 self.ws.close().await?;
695 Ok(())
696 }
697}