tqsdk_rs/
auth.rs

1//! 认证模块
2//!
3//! 实现天勤账号认证和权限检查
4
5use crate::errors::{Result, TqError};
6use async_trait::async_trait;
7// JWT 相关导入(暂时不用)
8// use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
9use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION, USER_AGENT};
10use serde::{Deserialize, Serialize};
11use std::collections::HashSet;
12use std::time::Duration;
13use tracing::{debug, info};
14
// ---- Client identity and endpoint constants -------------------------------------------

/// SDK version string reported in the `User-Agent` header (as `tqsdk-python <VERSION>`).
pub const VERSION: &str = "3.8.1";

/// Default base URL of the authentication server.
///
/// `TqAuth::new` lets the `TQ_AUTH_URL` environment variable override this value.
pub const TQ_AUTH_URL: &str = "https://auth.shinnytech.com";

/// OAuth client id sent as `client_id` in the token request.
pub const CLIENT_ID: &str = "shinny_tq";

/// OAuth client secret sent as `client_secret` in the token request.
///
/// NOTE(review): this value ships inside the client, so it is presumably a public-client
/// credential rather than a confidential secret — confirm.
pub const CLIENT_SECRET: &str = "be30b9f4-6862-488a-99ad-21bde0400081";
23
24/// 认证器接口
25#[async_trait]
26pub trait Authenticator: Send + Sync {
27    /// 获取包含认证信息的 HTTP Header
28    fn base_header(&self) -> HeaderMap;
29
30    /// 执行登录操作
31    async fn login(&mut self) -> Result<()>;
32
33    /// 获取指定期货公司的交易服务器地址
34    async fn get_td_url(&self, broker_id: &str, account_id: &str) -> Result<BrokerInfo>;
35
36    /// 获取行情服务器地址
37    async fn get_md_url(&self, stock: bool, backtest: bool) -> Result<String>;
38
39    /// 检查是否具有指定的功能权限
40    fn has_feature(&self, feature: &str) -> bool;
41
42    /// 检查是否有查看指定合约行情数据的权限
43    fn has_md_grants(&self, symbols: &[&str]) -> Result<()>;
44
45    /// 检查是否有交易指定合约的权限
46    fn has_td_grants(&self, symbol: &str) -> Result<()>;
47
48    /// 获取认证 ID
49    fn get_auth_id(&self) -> &str;
50
51    /// 获取访问令牌
52    fn get_access_token(&self) -> &str;
53}
54
/// Permissions granted to the logged-in account.
#[derive(Clone, Debug, Default)]
pub struct Grants {
    /// Names of the enabled features (e.g. `futr`, `sec`, `lmt_idx`).
    pub features: HashSet<String>,
    /// Accounts the user is authorized for.
    pub accounts: HashSet<String>,
}
63
64/// 认证响应
65#[allow(dead_code)]
66#[derive(Debug, Deserialize)]
67struct AuthResponse {
68    access_token: String,
69    expires_in: i64,
70    refresh_expires_in: i64,
71    refresh_token: String,
72    token_type: String,
73    #[serde(rename = "not-before-policy")]
74    not_before_policy: i32,
75    session_state: String,
76    scope: String,
77}
78
79/// Access Token Claims
80#[derive(Debug, Serialize, Deserialize)]
81struct AccessTokenClaims {
82    jti: String,
83    exp: i64,
84    nbf: i64,
85    iat: i64,
86    iss: String,
87    sub: String,
88    typ: String,
89    azp: String,
90    auth_time: i64,
91    session_state: String,
92    acr: String,
93    scope: String,
94    grants: GrantsClaims,
95    creation_time: i64,
96    setname: bool,
97    mobile: String,
98    #[serde(rename = "mobileVerified")]
99    mobile_verified: String,
100    preferred_username: String,
101    id: String,
102    username: String,
103}
104
105#[derive(Debug, Serialize, Deserialize)]
106struct GrantsClaims {
107    features: Vec<String>,
108    otg_ids: String,
109    expiry_date: String,
110    accounts: Vec<String>,
111}
112
113/// 期货公司信息
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct BrokerInfo {
116    pub category: Vec<String>,
117    pub url: String,
118    pub broker_type: Option<String>,
119    pub smtype: Option<String>,
120    pub smconfig: Option<String>,
121    pub condition_type: Option<String>,
122    pub condition_config: Option<String>,
123}
124
125/// 行情服务器 URL 响应
126#[derive(Debug, Deserialize)]
127struct MdUrlResponse {
128    mdurl: String,
129}
130
131/// 天勤认证实现
132pub struct TqAuth {
133    username: String,
134    password: String,
135    auth_url: String,
136    access_token: String,
137    refresh_token: String,
138    auth_id: String,
139    grants: Grants,
140}
141
142impl TqAuth {
143    /// 创建新的认证器
144    pub fn new(username: String, password: String) -> Self {
145        let auth_url = std::env::var("TQ_AUTH_URL").unwrap_or_else(|_| TQ_AUTH_URL.to_string());
146
147        TqAuth {
148            username,
149            password,
150            auth_url,
151            access_token: String::new(),
152            refresh_token: String::new(),
153            auth_id: String::new(),
154            grants: Grants::default(),
155        }
156    }
157
158    /// 请求 Token
159    async fn request_token(&mut self) -> Result<()> {
160        let url = format!(
161            "{}/auth/realms/shinnytech/protocol/openid-connect/token",
162            self.auth_url
163        );
164
165        let params = [
166            ("client_id", CLIENT_ID),
167            ("client_secret", CLIENT_SECRET),
168            ("username", &self.username),
169            ("password", &self.password),
170            ("grant_type", "password"),
171        ];
172
173        info!("正在请求认证 token...");
174
175        let client = reqwest::Client::builder()
176            .no_proxy()
177            .timeout(Duration::from_secs(30))
178            .build()?;
179
180        let response = client
181            .post(&url)
182            .form(&params)
183            .header(USER_AGENT, format!("tqsdk-python {}", VERSION))
184            .header(ACCEPT, "application/json")
185            .send()
186            .await?;
187
188        if !response.status().is_success() {
189            let status = response.status();
190            let body = response.text().await?;
191            return Err(TqError::AuthenticationError(format!(
192                "认证失败 ({}): {}",
193                status, body
194            )));
195        }
196
197        let auth_resp: AuthResponse = response.json().await?;
198        self.access_token = auth_resp.access_token;
199        self.refresh_token = auth_resp.refresh_token;
200
201        debug!("Token 获取成功");
202        Ok(())
203    }
204
205    /// 解析 JWT Token
206    fn parse_token(&mut self) -> Result<()> {
207        // 解析 JWT(不验证签名,因为我们信任天勤服务器)
208        use jsonwebtoken::dangerous::insecure_decode;
209        // 这里我们使用不验证签名的方式解析
210        let token_data = insecure_decode::<AccessTokenClaims>(&self.access_token)?;
211
212        let claims = token_data.claims;
213        self.auth_id = claims.sub;
214
215        // 提取权限
216        for feature in claims.grants.features {
217            self.grants.features.insert(feature);
218        }
219
220        for account in claims.grants.accounts {
221            self.grants.accounts.insert(account);
222        }
223
224        debug!(
225            "权限解析完成: {} 个功能, {} 个账户",
226            self.grants.features.len(),
227            self.grants.accounts.len()
228        );
229
230        Ok(())
231    }
232}
233
234#[async_trait]
235impl Authenticator for TqAuth {
236    fn base_header(&self) -> HeaderMap {
237        let mut headers = HeaderMap::new();
238        headers.insert(USER_AGENT, HeaderValue::from_static("tqsdk-python 3.8.1"));
239        headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
240        if !self.access_token.is_empty() {
241            if let Ok(value) = HeaderValue::from_str(&format!("Bearer {}", self.access_token)) {
242                headers.insert(AUTHORIZATION, value);
243            }
244        }
245        headers
246    }
247
248    async fn login(&mut self) -> Result<()> {
249        self.request_token().await?;
250        self.parse_token()?;
251        info!("TqAuth 登录成功, User: {},  AuthId: {}", self.username, self.auth_id);
252        Ok(())
253    }
254
255    async fn get_td_url(&self, broker_id: &str, account_id: &str) -> Result<BrokerInfo> {
256        let url = format!("https://files.shinnytech.com/{}.json", broker_id);
257
258        let client = reqwest::Client::builder()
259            .no_proxy()
260            .timeout(Duration::from_secs(30))
261            .build()?;
262
263        let response = client
264            .get(&url)
265            .query(&[("account_id", account_id), ("auth", &self.username)])
266            .headers(self.base_header())
267            .send()
268            .await?;
269
270        if !response.status().is_success() {
271            return Err(TqError::ConfigError(format!(
272                "不支持该期货公司: {}",
273                broker_id
274            )));
275        }
276
277        let broker_infos: std::collections::HashMap<String, BrokerInfo> = response.json().await?;
278
279        broker_infos.get(broker_id).cloned().ok_or_else(|| {
280            TqError::ConfigError(format!("该期货公司 {} 暂不支持 TqSdk 登录", broker_id))
281        })
282    }
283
284    async fn get_md_url(&self, stock: bool, backtest: bool) -> Result<String> {
285        let url = format!(
286            "https://api.shinnytech.com/ns?stock={}&backtest={}",
287            stock, backtest
288        );
289
290        let client = reqwest::Client::builder()
291            .no_proxy()
292            .timeout(Duration::from_secs(30))
293            .build()?;
294
295        let response = client.get(&url).headers(self.base_header()).send().await?;
296
297        if !response.status().is_success() {
298            let status = response.status();
299            let body = response.text().await?;
300            return Err(TqError::NetworkError(format!(
301                "获取行情服务器地址失败 ({}): {}",
302                status, body
303            )));
304        }
305
306        let md_url_resp: MdUrlResponse = response.json().await?;
307        Ok(md_url_resp.mdurl)
308    }
309
310    fn has_feature(&self, feature: &str) -> bool {
311        self.grants.features.contains(feature)
312    }
313
314    fn has_md_grants(&self, symbols: &[&str]) -> Result<()> {
315        for symbol in symbols {
316            let prefix = symbol.split('.').next().unwrap_or("");
317
318            // 检查期货、现货、KQ、KQD 交易所
319            if matches!(
320                prefix,
321                "CFFEX" | "SHFE" | "DCE" | "CZCE" | "INE" | "GFEX" | "SSWE" | "KQ" | "KQD"
322            ) {
323                if !self.has_feature("futr") {
324                    return Err(TqError::permission_denied_futures());
325                }
326                continue;
327            }
328
329            // 检查股票交易所
330            if prefix == "CSI" || matches!(prefix, "SSE" | "SZSE") {
331                if !self.has_feature("sec") {
332                    return Err(TqError::permission_denied_stocks());
333                }
334                continue;
335            }
336
337            // 检查限制指数
338            if matches!(
339                *symbol,
340                "SSE.000016" | "SSE.000300" | "SSE.000905" | "SSE.000852"
341            ) {
342                if !self.has_feature("lmt_idx") {
343                    return Err(TqError::PermissionDenied(format!(
344                        "您的账户不支持查看 {} 的行情数据",
345                        symbol
346                    )));
347                }
348                continue;
349            }
350
351            // 未知交易所
352            return Err(TqError::PermissionDenied(format!(
353                "不支持的合约: {}",
354                symbol
355            )));
356        }
357
358        Ok(())
359    }
360
361    fn has_td_grants(&self, symbol: &str) -> Result<()> {
362        let prefix = symbol.split('.').next().unwrap_or("");
363
364        // 检查期货、现货、KQ、KQD 交易所
365        if matches!(
366            prefix,
367            "CFFEX" | "SHFE" | "DCE" | "CZCE" | "INE" | "GFEX" | "SSWE" | "KQ" | "KQD"
368        ) {
369            if self.has_feature("futr") {
370                return Ok(());
371            }
372            return Err(TqError::PermissionDenied(format!(
373                "您的账户不支持交易 {}",
374                symbol
375            )));
376        }
377
378        // 检查股票交易所
379        if prefix == "CSI" || matches!(prefix, "SSE" | "SZSE") {
380            if self.has_feature("sec") {
381                return Ok(());
382            }
383            return Err(TqError::PermissionDenied(format!(
384                "您的账户不支持交易 {}",
385                symbol
386            )));
387        }
388
389        Err(TqError::PermissionDenied(format!(
390            "不支持的合约: {}",
391            symbol
392        )))
393    }
394
395    fn get_auth_id(&self) -> &str {
396        &self.auth_id
397    }
398
399    fn get_access_token(&self) -> &str {
400        &self.access_token
401    }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an authenticator whose granted features are exactly `features`.
    fn auth_with_features(features: &[&str]) -> TqAuth {
        let mut auth = TqAuth::new("test_user".to_string(), "test_pass".to_string());
        auth.grants.features = features.iter().map(|f| f.to_string()).collect();
        auth
    }

    #[test]
    fn test_tq_auth_creation() {
        let auth = TqAuth::new("test_user".to_string(), "test_pass".to_string());
        assert_eq!(auth.username, "test_user");
        assert_eq!(auth.password, "test_pass");
        assert!(auth.get_access_token().is_empty());
        assert!(auth.get_auth_id().is_empty());
        assert!(auth.grants.features.is_empty());
        assert!(auth.grants.accounts.is_empty());
    }

    #[test]
    fn test_has_feature() {
        let auth = auth_with_features(&["futr"]);
        assert!(auth.has_feature("futr"));
        assert!(!auth.has_feature("sec"));
    }

    #[test]
    fn test_md_grants_futures() {
        let auth = auth_with_features(&["futr"]);
        assert!(auth.has_md_grants(&["SHFE.cu2401", "KQ.m@SHFE.cu"]).is_ok());
        assert!(auth.has_md_grants(&[]).is_ok());

        let no_grants = auth_with_features(&[]);
        assert!(no_grants.has_md_grants(&["SHFE.cu2401"]).is_err());
    }

    #[test]
    fn test_md_grants_stocks_and_unknown_exchange() {
        let auth = auth_with_features(&["sec"]);
        assert!(auth
            .has_md_grants(&["SSE.600000", "SZSE.000001", "SSE.000300"])
            .is_ok());
        assert!(auth.has_md_grants(&["SHFE.cu2401"]).is_err());
        assert!(auth.has_md_grants(&["NYSE.IBM"]).is_err());
    }

    #[test]
    fn test_td_grants() {
        let auth = auth_with_features(&["sec"]);
        assert!(auth.has_td_grants("SSE.600000").is_ok());
        assert!(auth.has_td_grants("SHFE.cu2401").is_err());
        assert!(auth.has_td_grants("NYSE.IBM").is_err());

        let futures = auth_with_features(&["futr"]);
        assert!(futures.has_td_grants("DCE.m2405").is_ok());
        assert!(futures.has_td_grants("SZSE.000001").is_err());
    }

    #[test]
    fn test_base_header() {
        let mut auth = auth_with_features(&[]);
        let headers = auth.base_header();
        let expected_agent = format!("tqsdk-python {}", VERSION);
        assert_eq!(
            headers.get(USER_AGENT).and_then(|v| v.to_str().ok()),
            Some(expected_agent.as_str())
        );
        assert!(headers.get(AUTHORIZATION).is_none());

        auth.access_token = "abc".to_string();
        let headers = auth.base_header();
        assert_eq!(
            headers.get(AUTHORIZATION).and_then(|v| v.to_str().ok()),
            Some("Bearer abc")
        );
    }
}