tqsdk_rs/
trade_session.rs

1//! 交易会话实现
2//!
3//! 实现实盘交易功能
4
5use crate::datamanager::DataManager;
6use crate::errors::{Result, TqError};
7use crate::types::{
8    Account, InsertOrderRequest, Notification, Order, Position, PositionUpdate, Trade,
9};
10use crate::websocket::{TqTradeWebsocket, WebSocketConfig};
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14use async_channel::{Sender, Receiver, unbounded};
15use tokio::sync::RwLock;
16use tracing::{debug, error, info};
17
18/// 交易会话
19#[allow(unused)]
20pub struct TradeSession {
21    broker: String,
22    user_id: String,
23    password: String,
24    dm: Arc<DataManager>,
25    ws: Arc<TqTradeWebsocket>,
26
27    // Channels(使用 async-channel)
28    account_tx: Sender<Account>,
29    account_rx: Receiver<Account>,
30    position_tx: Sender<PositionUpdate>,
31    position_rx: Receiver<PositionUpdate>,
32    order_tx: Sender<Order>,
33    order_rx: Receiver<Order>,
34    trade_tx: Sender<Trade>,
35    trade_rx: Receiver<Trade>,
36    notification_tx: Sender<Notification>,
37    notification_rx: Receiver<Notification>,
38
39    // 回调
40    on_account: Arc<RwLock<Option<Arc<dyn Fn(Account) + Send + Sync>>>>,
41    on_position: Arc<RwLock<Option<Arc<dyn Fn(String, Position) + Send + Sync>>>>,
42    on_order: Arc<RwLock<Option<Arc<dyn Fn(Order) + Send + Sync>>>>,
43    on_trade: Arc<RwLock<Option<Arc<dyn Fn(Trade) + Send + Sync>>>>,
44    on_notification: Arc<RwLock<Option<Arc<dyn Fn(Notification) + Send + Sync>>>>,
45    on_error: Arc<RwLock<Option<Arc<dyn Fn(String) + Send + Sync>>>>,
46
47    // 状态(使用 Arc<AtomicBool> 避免锁开销,支持跨线程共享)
48    logged_in: Arc<AtomicBool>,
49    running: Arc<AtomicBool>,
50}
51
52impl TradeSession {
53    /// 创建交易会话
54    pub fn new(
55        broker: String,
56        user_id: String,
57        password: String,
58        dm: Arc<DataManager>,
59        ws_url: String,
60        ws_config: WebSocketConfig,
61    ) -> Self {
62        // 创建 async-channel channels(使用 unbounded)
63        let (account_tx, account_rx) = unbounded();
64        let (position_tx, position_rx) = unbounded();
65        let (order_tx, order_rx) = unbounded();
66        let (trade_tx, trade_rx) = unbounded();
67        let (notification_tx, notification_rx) = unbounded();
68
69        let ws = Arc::new(TqTradeWebsocket::new(ws_url, Arc::clone(&dm), ws_config));
70
71        let noti_tx = notification_tx.clone();
72        ws.on_notify(move |noti| {
73            let noti2 = noti.clone();
74            let noti_tx2 = noti_tx.clone();
75            tokio::spawn(async move {
76                match noti_tx2.send(noti2).await {
77                    Ok(_) => {
78                        debug!("通知发送成功: {:?}", noti);
79                    }
80                    Err(e) => {
81                        error!("通知发送失败: {:?}, error={}", noti, e);
82                    }
83                }
84            });
85        });
86
87        TradeSession {
88            broker,
89            user_id,
90            password,
91            dm,
92            ws,
93            account_tx,
94            account_rx,
95            position_tx,
96            position_rx,
97            order_tx,
98            order_rx,
99            trade_tx,
100            trade_rx,
101            notification_tx,
102            notification_rx,
103            on_account: Arc::new(RwLock::new(None)),
104            on_position: Arc::new(RwLock::new(None)),
105            on_order: Arc::new(RwLock::new(None)),
106            on_trade: Arc::new(RwLock::new(None)),
107            on_notification: Arc::new(RwLock::new(None)),
108            on_error: Arc::new(RwLock::new(None)),
109            logged_in: Arc::new(AtomicBool::new(false)),
110            running: Arc::new(AtomicBool::new(false)),
111        }
112    }
113
114    /// 发送登录请求
115    async fn send_login(&self) -> Result<()> {
116        let login_req = serde_json::json!({
117            "aid": "req_login",
118            "bid": self.broker,
119            "user_name": self.user_id,
120            "password": self.password
121        });
122
123        info!(
124            "发送交易登录请求: broker={}, user_id={}",
125            self.broker, self.user_id
126        );
127        self.ws.send(&login_req).await?;
128        Ok(())
129    }
130
131    /// 启动数据监听
132    async fn start_watching(&self) {
133        let dm_clone = Arc::clone(&self.dm);
134        let user_id = self.user_id.clone();
135        let logged_in = Arc::clone(&self.logged_in);
136        let running = Arc::clone(&self.running);
137        let account_tx = self.account_tx.clone();
138        let position_tx = self.position_tx.clone();
139        let order_tx = self.order_tx.clone();
140        let trade_tx = self.trade_tx.clone();
141        let on_account = Arc::clone(&self.on_account);
142        let on_position = Arc::clone(&self.on_position);
143        let on_order = Arc::clone(&self.on_order);
144        let on_trade = Arc::clone(&self.on_trade);
145
146        info!("TradeSession 开始监听数据更新");
147
148        // 注册数据更新回调
149        let dm_for_callback = Arc::clone(&dm_clone);
150
151        // 使用 Mutex 防止多个任务并发访问数据,避免竞态条件
152        // 这确保了 is_changing() 和数据处理是串行的
153        let processing_lock = Arc::new(tokio::sync::Mutex::new(()));
154
155        dm_clone.on_data(move || {
156            let dm = Arc::clone(&dm_for_callback);
157            let user_id = user_id.clone();
158            let logged_in = Arc::clone(&logged_in);
159            let running = Arc::clone(&running);
160            let account_tx = account_tx.clone();
161            let position_tx = position_tx.clone();
162            let order_tx = order_tx.clone();
163            let trade_tx = trade_tx.clone();
164            let on_account = Arc::clone(&on_account);
165            let on_position = Arc::clone(&on_position);
166            let on_order = Arc::clone(&on_order);
167            let on_trade = Arc::clone(&on_trade);
168            let lock = Arc::clone(&processing_lock);
169
170            tokio::spawn(async move {
171                // 获取锁,确保同一时刻只有一个任务在处理数据
172                let _guard = lock.lock().await;
173
174                if !running.load(Ordering::SeqCst) {
175                    return;
176                }
177
178                // 检查登录状态
179                if let Some(session_data) = dm.get_by_path(&["trade", &user_id, "session"]) {
180                    if let serde_json::Value::Object(session_map) = session_data {
181                        if let Some(trading_day) = session_map.get("trading_day") {
182                            if !trading_day.is_null() {
183                                // 使用 compare_exchange 实现原子的 check-and-set
184                                // 只有当 logged_in 从 false 变为 true 时才打印日志
185                                if logged_in
186                                    .compare_exchange(
187                                        false,
188                                        true,
189                                        Ordering::SeqCst,
190                                        Ordering::SeqCst,
191                                    )
192                                    .is_ok()
193                                {
194                                    info!("交易会话已登录: user_id={}", user_id);
195                                }
196                            }
197                        }
198                    }
199                }
200
201                // 检查账户更新
202                if dm.is_changing(&["trade", &user_id, "accounts", "CNY"]) {
203                    match dm.get_account_data(&user_id, "CNY") {
204                        Ok(account) => {
205                            debug!("账户更新: balance={}", account.balance);
206                            let _ = account_tx.send(account.clone()).await;
207
208                            if let Some(callback) = on_account.read().await.as_ref() {
209                                let cb = Arc::clone(callback);
210                                tokio::spawn(async move {
211                                    cb(account);
212                                });
213                            }
214                        }
215                        Err(e) => {
216                            // 理论上不应该到这里,因为 is_changing 已经验证了数据存在
217                            error!("获取账户数据失败(这不应该发生): {}", e);
218                        }
219                    }
220                }
221
222                // 检查持仓更新
223                if dm.is_changing(&["trade", &user_id, "positions"]) {
224                    Self::process_position_update(&dm, &user_id, &position_tx, &on_position).await;
225                }
226
227                // 检查委托单更新
228                if dm.is_changing(&["trade", &user_id, "orders"]) {
229                    Self::process_order_update(&dm, &user_id, &order_tx, &on_order).await;
230                }
231
232                // 检查成交更新
233                if dm.is_changing(&["trade", &user_id, "trades"]) {
234                    Self::process_trade_update(&dm, &user_id, &trade_tx, &on_trade).await;
235                }
236            });
237        });
238    }
239
240    /// 处理持仓更新
241    async fn process_position_update(
242        dm: &Arc<DataManager>,
243        user_id: &str,
244        position_tx: &Sender<PositionUpdate>,
245        on_position: &Arc<RwLock<Option<Arc<dyn Fn(String, Position) + Send + Sync>>>>,
246    ) {
247        if let Some(positions_data) = dm.get_by_path(&["trade", user_id, "positions"]) {
248            if let serde_json::Value::Object(positions_map) = positions_data {
249                for (symbol, _) in positions_map.iter() {
250                    // 跳过内部元数据字段(以 _ 开头)
251                    if symbol.starts_with('_') {
252                        continue;
253                    }
254
255                    // 检查单个持仓是否有更新
256                    if dm.is_changing(&["trade", user_id, "positions", symbol]) {
257                        match dm.get_position_data(user_id, symbol) {
258                            Ok(position) => {
259                                debug!("持仓更新: symbol={}", symbol);
260
261                                let update = PositionUpdate {
262                                    symbol: symbol.clone(),
263                                    position: position.clone(),
264                                };
265
266                                // 发送到 async-channel
267                                let _ = position_tx.send(update).await;
268
269                                // 调用回调
270                                if let Some(callback) = on_position.read().await.as_ref() {
271                                    let cb = Arc::clone(callback);
272                                    let symbol = symbol.clone();
273                                    tokio::spawn(async move {
274                                        cb(symbol, position);
275                                    });
276                                }
277                            }
278                            Err(e) => {
279                                error!("获取持仓数据失败: symbol={}, error={}", symbol, e);
280                            }
281                        }
282                    }
283                }
284            }
285        }
286    }
287
288    /// 处理委托单更新
289    async fn process_order_update(
290        dm: &Arc<DataManager>,
291        user_id: &str,
292        order_tx: &Sender<Order>,
293        on_order: &Arc<RwLock<Option<Arc<dyn Fn(Order) + Send + Sync>>>>,
294    ) {
295        if let Some(orders_data) = dm.get_by_path(&["trade", user_id, "orders"]) {
296            if let serde_json::Value::Object(orders_map) = orders_data {
297                for (order_id, order_data) in orders_map.iter() {
298                    // 跳过内部元数据字段(以 _ 开头)
299                    if order_id.starts_with('_') {
300                        continue;
301                    }
302
303                    // 检查单个订单是否有更新
304                    if dm.is_changing(&["trade", user_id, "orders", order_id]) {
305                        if let Ok(order) = serde_json::from_value::<Order>(order_data.clone()) {
306                            debug!("订单更新: order_id={}", order_id);
307
308                            // 发送到 async-channel
309                            let _ = order_tx.send(order.clone()).await;
310
311                            // 调用回调
312                            if let Some(callback) = on_order.read().await.as_ref() {
313                                let cb = Arc::clone(callback);
314                                tokio::spawn(async move {
315                                    cb(order);
316                                });
317                            }
318                        }
319                    }
320                }
321            }
322        }
323    }
324
325    /// 处理成交更新
326    async fn process_trade_update(
327        dm: &Arc<DataManager>,
328        user_id: &str,
329        trade_tx: &Sender<Trade>,
330        on_trade: &Arc<RwLock<Option<Arc<dyn Fn(Trade) + Send + Sync>>>>,
331    ) {
332        if let Some(trades_data) = dm.get_by_path(&["trade", user_id, "trades"]) {
333            if let serde_json::Value::Object(trades_map) = trades_data {
334                for (trade_id, trade_data) in trades_map.iter() {
335                    // 跳过内部元数据字段(以 _ 开头)
336                    if trade_id.starts_with('_') {
337                        continue;
338                    }
339
340                    // 检查单个成交是否有更新
341                    if dm.is_changing(&["trade", user_id, "trades", trade_id]) {
342                        if let Ok(trade) = serde_json::from_value::<Trade>(trade_data.clone()) {
343                            debug!("成交更新: trade_id={}", trade_id);
344
345                            // 发送到 async-channel
346                            let _ = trade_tx.send(trade.clone()).await;
347
348                            // 调用回调
349                            if let Some(callback) = on_trade.read().await.as_ref() {
350                                let cb = Arc::clone(callback);
351                                tokio::spawn(async move {
352                                    cb(trade);
353                                });
354                            }
355                        }
356                    }
357                }
358            }
359        }
360    }
361
362    /// 注册账户更新回调
363    pub async fn on_account<F>(&self, handler: F)
364    where
365        F: Fn(Account) + Send + Sync + 'static,
366    {
367        let mut guard = self.on_account.write().await;
368        *guard = Some(Arc::new(handler));
369    }
370
371    /// 注册持仓更新回调
372    pub async fn on_position<F>(&self, handler: F)
373    where
374        F: Fn(String, Position) + Send + Sync + 'static,
375    {
376        let mut guard = self.on_position.write().await;
377        *guard = Some(Arc::new(handler));
378    }
379
380    /// 注册委托单更新回调
381    pub async fn on_order<F>(&self, handler: F)
382    where
383        F: Fn(Order) + Send + Sync + 'static,
384    {
385        let mut guard = self.on_order.write().await;
386        *guard = Some(Arc::new(handler));
387    }
388
389    /// 注册成交记录回调
390    pub async fn on_trade<F>(&self, handler: F)
391    where
392        F: Fn(Trade) + Send + Sync + 'static,
393    {
394        let mut guard = self.on_trade.write().await;
395        *guard = Some(Arc::new(handler));
396    }
397
398    /// 注册通知回调
399    pub async fn on_notification<F>(&self, handler: F)
400    where
401        F: Fn(Notification) + Send + Sync + 'static,
402    {
403        let mut guard = self.on_notification.write().await;
404        *guard = Some(Arc::new(handler));
405    }
406
407    /// 注册错误回调
408    pub async fn on_error<F>(&self, handler: F)
409    where
410        F: Fn(String) + Send + Sync + 'static,
411    {
412        let mut guard = self.on_error.write().await;
413        *guard = Some(Arc::new(handler));
414    }
415
416    /// 获取账户更新 Channel(克隆接收端)
417    pub fn account_channel(&self) -> Receiver<Account> {
418        self.account_rx.clone()
419    }
420
421    /// 获取持仓更新 Channel(克隆接收端)
422    pub fn position_channel(&self) -> Receiver<PositionUpdate> {
423        self.position_rx.clone()
424    }
425
426    /// 获取委托单更新 Channel(克隆接收端)
427    pub fn order_channel(&self) -> Receiver<Order> {
428        self.order_rx.clone()
429    }
430
431    /// 获取成交记录 Channel(克隆接收端)
432    pub fn trade_channel(&self) -> Receiver<Trade> {
433        self.trade_rx.clone()
434    }
435
436    /// 获取通知 Channel(克隆接收端)
437    pub fn notification_channel(&self) -> Receiver<Notification> {
438        self.notification_rx.clone()
439    }
440}
441
442impl TradeSession {
443    /// 下单
444    pub async fn insert_order(&self, req: &InsertOrderRequest) -> Result<Order> {
445        if !self.is_ready() {
446            return Err(TqError::InternalError("交易会话未就绪".to_string()));
447        }
448
449        let exchange_id = req.get_exchange_id();
450        let instrument_id = req.get_instrument_id();
451
452        // 生成订单 ID
453        use uuid::Uuid;
454        let order_id = format!(
455            "TQRS_{}",
456            Uuid::new_v4().simple().to_string()[..8].to_uppercase()
457        );
458
459        // 确定时间条件
460        let time_condition = if req.price_type == "ANY" {
461            "IOC"
462        } else {
463            "GFD"
464        };
465
466        // 发送下单请求
467        let order_req = serde_json::json!({
468            "aid": "insert_order",
469            "user_id": self.user_id,
470            "order_id": order_id,
471            "exchange_id": exchange_id,
472            "instrument_id": instrument_id,
473            "direction": req.direction,
474            "offset": req.offset,
475            "volume": req.volume,
476            "price_type": req.price_type,
477            "limit_price": req.limit_price,
478            "volume_condition": "ANY",
479            "time_condition": time_condition,
480        });
481
482        info!("发送下单请求: order_id={}, symbol={}", order_id, req.symbol);
483        self.ws.send(&order_req).await?;
484
485        // 初始化订单状态到 DataManager
486        let order_init = serde_json::json!({
487            "user_id": self.user_id,
488            "order_id": order_id,
489            "exchange_id": exchange_id,
490            "instrument_id": instrument_id,
491            "direction": req.direction,
492            "offset": req.offset,
493            "volume_orign": req.volume,
494            "volume_left": req.volume,
495            "price_type": req.price_type,
496            "limit_price": req.limit_price,
497            "status": "ALIVE",
498        });
499
500        self.dm.merge_data(
501            serde_json::json!({
502                "trade": {
503                    self.user_id.clone(): {
504                        "orders": {
505                            order_id.clone(): order_init
506                        }
507                    }
508                }
509            }),
510            false,
511            false,
512        );
513
514        // 返回初始订单对象
515        Ok(Order {
516            order_id: order_id.clone(),
517            exchange_id: exchange_id.to_string(),
518            instrument_id: instrument_id.to_string(),
519            direction: req.direction.clone(),
520            offset: req.offset.clone(),
521            volume_orign: req.volume,
522            volume_left: req.volume,
523            limit_price: req.limit_price,
524            price_type: req.price_type.clone(),
525            volume_condition: "ANY".to_string(),
526            time_condition: time_condition.to_string(),
527            insert_date_time: 0,
528            status: "ALIVE".to_string(),
529            frozen_margin: 0f64,
530            epoch: None,
531            seqno: 0,
532            user_id: self.user_id.clone(),
533            exchange_order_id: String::default(),
534            last_msg: String::default(),
535        })
536    }
537
538    /// 撤单
539    pub async fn cancel_order(&self, order_id: &str) -> Result<()> {
540        if !self.is_ready() {
541            return Err(TqError::InternalError("交易会话未就绪".to_string()));
542        }
543
544        // 发送撤单请求
545        let cancel_req = serde_json::json!({
546            "aid": "cancel_order",
547            "user_id": self.user_id,
548            "order_id": order_id
549        });
550
551        info!("发送撤单请求: order_id={}", order_id);
552        self.ws.send(&cancel_req).await?;
553        Ok(())
554    }
555
556    /// 获取账户信息
557    pub async fn get_account(&self) -> Result<Account> {
558        self.dm.get_account_data(&self.user_id, "CNY")
559    }
560
561    /// 获取指定合约的持仓
562    pub async fn get_position(&self, symbol: &str) -> Result<Position> {
563        self.dm.get_position_data(&self.user_id, symbol)
564    }
565
566    /// 获取所有持仓
567    pub async fn get_positions(&self) -> Result<HashMap<String, Position>> {
568        let data = self.dm.get_by_path(&["trade", &self.user_id, "positions"]);
569        if data.is_none() {
570            return Ok(HashMap::new());
571        }
572
573        let positions_map = match data.unwrap() {
574            serde_json::Value::Object(map) => map,
575            _ => return Ok(HashMap::new()),
576        };
577
578        let mut positions = HashMap::new();
579        for (symbol, pos_data) in positions_map.iter() {
580            // 跳过内部元数据字段(以 _ 开头)
581            if symbol.starts_with('_') {
582                continue;
583            }
584
585            if let Ok(position) = serde_json::from_value::<Position>(pos_data.clone()) {
586                positions.insert(symbol.clone(), position);
587            }
588        }
589
590        Ok(positions)
591    }
592
593    /// 获取所有委托单
594    pub async fn get_orders(&self) -> Result<HashMap<String, Order>> {
595        let data = self.dm.get_by_path(&["trade", &self.user_id, "orders"]);
596        if data.is_none() {
597            return Ok(HashMap::new());
598        }
599
600        let orders_map = match data.unwrap() {
601            serde_json::Value::Object(map) => map,
602            _ => return Ok(HashMap::new()),
603        };
604
605        let mut orders = HashMap::new();
606        for (order_id, order_data) in orders_map.iter() {
607            // 跳过内部元数据字段(以 _ 开头)
608            if order_id.starts_with('_') {
609                continue;
610            }
611
612            if let Ok(order) = serde_json::from_value::<Order>(order_data.clone()) {
613                orders.insert(order_id.clone(), order);
614            }
615        }
616
617        Ok(orders)
618    }
619
620    /// 获取所有成交记录
621    pub async fn get_trades(&self) -> Result<HashMap<String, Trade>> {
622        let data = self.dm.get_by_path(&["trade", &self.user_id, "trades"]);
623        if data.is_none() {
624            return Ok(HashMap::new());
625        }
626
627        let trades_map = match data.unwrap() {
628            serde_json::Value::Object(map) => map,
629            _ => return Ok(HashMap::new()),
630        };
631
632        let mut trades = HashMap::new();
633        for (trade_id, trade_data) in trades_map.iter() {
634            // 跳过内部元数据字段(以 _ 开头)
635            if trade_id.starts_with('_') {
636                continue;
637            }
638
639            if let Ok(trade) = serde_json::from_value::<Trade>(trade_data.clone()) {
640                trades.insert(trade_id.clone(), trade);
641            }
642        }
643
644        Ok(trades)
645    }
646
647    /// 连接交易服务器
648    pub async fn connect(&self) -> Result<()> {
649        // 使用 compare_exchange 确保只连接一次
650        if self
651            .running
652            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
653            .is_err()
654        {
655            return Ok(()); // 已经在运行
656        }
657
658        info!(
659            "连接交易服务器: broker={}, user_id={}",
660            self.broker, self.user_id
661        );
662
663        // 初始化 WebSocket
664        self.ws.init(false).await?;
665
666        // 先启动数据监听(注册数据更新回调)
667        self.start_watching().await;
668
669        // 再发送登录请求(避免错过初始数据)
670        self.send_login().await?;
671
672        Ok(())
673    }
674
675    /// 检查交易会话是否已就绪
676    pub fn is_ready(&self) -> bool {
677        // 简化实现:检查是否已登录(原子操作,无需异步)
678        // TODO: 添加更详细的就绪检查
679        self.logged_in.load(Ordering::SeqCst) && self.ws.is_ready()
680    }
681
682    /// 关闭交易会话
683    pub async fn close(&self) -> Result<()> {
684        // 使用 compare_exchange 确保只关闭一次
685        if self
686            .running
687            .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
688            .is_err()
689        {
690            return Ok(()); // 已经关闭
691        }
692
693        info!("关闭交易会话");
694        self.ws.close().await?;
695        Ok(())
696    }
697}