// File: tqsdk_rs/quote.rs

//! Quote subscription module (Quote 订阅模块).
//!
//! Implements market-data (quote) subscriptions (实现行情订阅功能).

use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use async_channel::{unbounded, Receiver, Sender};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};

use crate::datamanager::DataManager;
use crate::errors::Result;
use crate::types::Quote;
use crate::websocket::TqQuoteWebsocket;

15/// Quote 订阅
16pub struct QuoteSubscription {
17    dm: Arc<DataManager>,
18    ws: Arc<TqQuoteWebsocket>,
19    symbols: Arc<RwLock<HashSet<String>>>,
20    quote_tx: Sender<Quote>,
21    quote_rx: Receiver<Quote>,
22    on_quote: Arc<RwLock<Option<Arc<dyn Fn(Arc<Quote>) + Send + Sync>>>>,
23    on_error: Arc<RwLock<Option<Arc<dyn Fn(Arc<String>) + Send + Sync>>>>,
24    running: Arc<RwLock<bool>>,
25}
26
27impl QuoteSubscription {
28    /// 创建新的 Quote 订阅
29    pub fn new(
30        dm: Arc<DataManager>,
31        ws: Arc<TqQuoteWebsocket>,
32        initial_symbols: Vec<String>,
33    ) -> Self {
34        let symbols: HashSet<String> = initial_symbols.into_iter().collect();
35
36        // 创建 async-channel(使用 unbounded)
37        let (quote_tx, quote_rx) = unbounded();
38
39        QuoteSubscription {
40            dm,
41            ws,
42            symbols: Arc::new(RwLock::new(symbols)),
43            quote_tx,
44            quote_rx,
45            on_quote: Arc::new(RwLock::new(None)),
46            on_error: Arc::new(RwLock::new(None)),
47            running: Arc::new(RwLock::new(false)),
48        }
49    }
50
51    /// 启动订阅监听
52    pub async fn start(&self) -> Result<()> {
53        let mut running = self.running.write().await;
54        if *running {
55            return Ok(());
56        }
57        *running = true;
58        drop(running);
59
60        debug!("启动 Quote 订阅");
61
62        // 先启动监听(注册数据更新回调)
63        self.start_watching().await;
64
65        // 再发送订阅请求(避免错过初始数据)
66        self.send_subscription().await?;
67
68        Ok(())
69    }
70
71    /// 添加合约
72    pub async fn add_symbols(&self, symbols: &[&str]) -> Result<()> {
73        if symbols.is_empty() {
74            return Ok(());
75        }
76
77        let mut symbol_set = self.symbols.write().await;
78        for &symbol in symbols {
79            symbol_set.insert(symbol.to_string());
80        }
81        drop(symbol_set);
82
83        self.send_subscription().await
84    }
85
86    /// 移除合约
87    pub async fn remove_symbols(&self, symbols: &[&str]) -> Result<()> {
88        if symbols.is_empty() {
89            return Ok(());
90        }
91
92        let mut symbol_set = self.symbols.write().await;
93        for &symbol in symbols {
94            symbol_set.remove(symbol);
95        }
96        drop(symbol_set);
97
98        self.send_subscription().await
99    }
100
101    /// 发送订阅请求
102    async fn send_subscription(&self) -> Result<()> {
103        let symbols = self.symbols.read().await;
104        let ins_list: Vec<String> = symbols.iter().cloned().collect();
105        let ins_list_str = ins_list.join(",");
106        drop(symbols);
107
108        debug!("发送 Quote 订阅请求: {} 个合约", ins_list.len());
109        debug!("订阅合约列表: {}", ins_list_str);
110
111        let req = serde_json::json!({
112            "aid": "subscribe_quote",
113            "ins_list": ins_list_str
114        });
115
116        self.ws.send(&req).await?;
117        Ok(())
118    }
119
120    /// 获取 Quote 更新通道(克隆接收端)
121    pub fn quote_channel(&self) -> Receiver<Quote> {
122        self.quote_rx.clone()
123    }
124
125    /// 注册回调
126    pub async fn on_quote<F>(&self, handler: F)
127    where
128        F: Fn(Arc<Quote>) + Send + Sync + 'static,
129    {
130        let mut guard = self.on_quote.write().await;
131        *guard = Some(Arc::new(handler));
132    }
133
134    /// 注册错误回调
135    pub async fn on_error<F>(&self, handler: F)
136    where
137        F: Fn(Arc<String>) + Send + Sync + 'static,
138    {
139        let mut guard = self.on_error.write().await;
140        *guard = Some(Arc::new(handler));
141    }
142
143    /// 启动监听
144    async fn start_watching(&self) {
145        let dm_clone = Arc::clone(&self.dm);
146        let symbols = Arc::clone(&self.symbols);
147        let quote_tx = self.quote_tx.clone();
148        let on_quote = Arc::clone(&self.on_quote);
149        let running = Arc::clone(&self.running);
150
151        info!("QuoteSubscription 开始监听数据更新");
152
153        // 注册数据更新回调
154        let dm_for_callback = Arc::clone(&dm_clone);
155        dm_clone.on_data(move || {
156            let dm = Arc::clone(&dm_for_callback);
157            let symbols = Arc::clone(&symbols);
158            let quote_tx = quote_tx.clone();
159            let on_quote = Arc::clone(&on_quote);
160            let running = Arc::clone(&running);
161
162            tokio::spawn(async move {
163                let is_running = *running.read().await;
164                if !is_running {
165                    return;
166                }
167
168                let symbol_list: Vec<String> = {
169                    let s = symbols.read().await;
170                    s.iter().cloned().collect()
171                };
172
173                for symbol in symbol_list {
174                    // 检查是否有更新
175                    let path: Vec<&str> = vec!["quotes", &symbol];
176                    if dm.is_changing(&path) {
177                        match dm.get_quote_data(&symbol) {
178                            Ok(quote) => {
179                                debug!(
180                                    "获取到 Quote 更新: symbol={}, last_price={}",
181                                    symbol, quote.last_price
182                                );
183
184                                // 包装为 Arc(零拷贝共享)
185                                let quote_arc = Arc::new(quote);
186
187                                // 发送到 async-channel(支持多个订阅者)
188                                // 注意:channel 仍然发送 Quote 而不是 Arc<Quote>,保持向后兼容
189                                let _ = quote_tx.send((*quote_arc).clone()).await;
190
191                                // 调用回调(使用 Arc)
192                                if let Some(callback) = on_quote.read().await.as_ref() {
193                                    let cb = Arc::clone(callback);
194                                    let q = Arc::clone(&quote_arc);
195                                    tokio::spawn(async move {
196                                        cb(q);
197                                    });
198                                }
199                            }
200                            Err(e) => {
201                                warn!("获取 Quote 失败: symbol={}, error={}", symbol, e);
202                            }
203                        }
204                    }
205                }
206            });
207        });
208    }
209
210    /// 关闭订阅
211    pub async fn close(&self) -> Result<()> {
212        let mut running = self.running.write().await;
213        if !*running {
214            return Ok(());
215        }
216        *running = false;
217
218        info!("关闭 Quote 订阅");
219        Ok(())
220    }
221}