1use crate::errors::{Result, TqError};
6use async_trait::async_trait;
7use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION, USER_AGENT};
10use serde::{Deserialize, Serialize};
11use std::collections::HashSet;
12use std::time::Duration;
13use tracing::{debug, info};
14
/// Client version string; the token request sends it in the `User-Agent`
/// header as `tqsdk-python <VERSION>`.
pub const VERSION: &str = "3.8.1";

/// Default base URL of the authentication service. `TqAuth::new` falls back
/// to this value when the `TQ_AUTH_URL` environment variable cannot be read.
pub const TQ_AUTH_URL: &str = "https://auth.shinnytech.com";

/// Value sent as the `client_id` form field of the token request.
pub const CLIENT_ID: &str = "shinny_tq";

/// Value sent as the `client_secret` form field of the token request.
// NOTE(review): this secret is compiled into the binary; presumably it is the
// well-known public-client secret of the SDK rather than a private credential
// -- confirm before publishing.
pub const CLIENT_SECRET: &str = "be30b9f4-6862-488a-99ad-21bde0400081";
23
24#[async_trait]
26pub trait Authenticator: Send + Sync {
27 fn base_header(&self) -> HeaderMap;
29
30 async fn login(&mut self) -> Result<()>;
32
33 async fn get_td_url(&self, broker_id: &str, account_id: &str) -> Result<BrokerInfo>;
35
36 async fn get_md_url(&self, stock: bool, backtest: bool) -> Result<String>;
38
39 fn has_feature(&self, feature: &str) -> bool;
41
42 fn has_md_grants(&self, symbols: &[&str]) -> Result<()>;
44
45 fn has_td_grants(&self, symbol: &str) -> Result<()>;
47
48 fn get_auth_id(&self) -> &str;
50
51 fn get_access_token(&self) -> &str;
53}
54
/// Permission sets held for a logged-in user.
///
/// Both sets start empty (`Default`) and are filled from the access-token
/// claims when the token is parsed.
#[derive(Clone, Debug, Default)]
pub struct Grants {
    /// Names of the granted features; this is the set `has_feature` consults.
    pub features: HashSet<String>,
    /// Account entries taken from the `accounts` claim of the access token.
    pub accounts: HashSet<String>,
}
63
64#[allow(dead_code)]
66#[derive(Debug, Deserialize)]
67struct AuthResponse {
68 access_token: String,
69 expires_in: i64,
70 refresh_expires_in: i64,
71 refresh_token: String,
72 token_type: String,
73 #[serde(rename = "not-before-policy")]
74 not_before_policy: i32,
75 session_state: String,
76 scope: String,
77}
78
79#[derive(Debug, Serialize, Deserialize)]
81struct AccessTokenClaims {
82 jti: String,
83 exp: i64,
84 nbf: i64,
85 iat: i64,
86 iss: String,
87 sub: String,
88 typ: String,
89 azp: String,
90 auth_time: i64,
91 session_state: String,
92 acr: String,
93 scope: String,
94 grants: GrantsClaims,
95 creation_time: i64,
96 setname: bool,
97 mobile: String,
98 #[serde(rename = "mobileVerified")]
99 mobile_verified: String,
100 preferred_username: String,
101 id: String,
102 username: String,
103}
104
105#[derive(Debug, Serialize, Deserialize)]
106struct GrantsClaims {
107 features: Vec<String>,
108 otg_ids: String,
109 expiry_date: String,
110 accounts: Vec<String>,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct BrokerInfo {
116 pub category: Vec<String>,
117 pub url: String,
118 pub broker_type: Option<String>,
119 pub smtype: Option<String>,
120 pub smconfig: Option<String>,
121 pub condition_type: Option<String>,
122 pub condition_config: Option<String>,
123}
124
125#[derive(Debug, Deserialize)]
127struct MdUrlResponse {
128 mdurl: String,
129}
130
131pub struct TqAuth {
133 username: String,
134 password: String,
135 auth_url: String,
136 access_token: String,
137 refresh_token: String,
138 auth_id: String,
139 grants: Grants,
140}
141
142impl TqAuth {
143 pub fn new(username: String, password: String) -> Self {
145 let auth_url = std::env::var("TQ_AUTH_URL").unwrap_or_else(|_| TQ_AUTH_URL.to_string());
146
147 TqAuth {
148 username,
149 password,
150 auth_url,
151 access_token: String::new(),
152 refresh_token: String::new(),
153 auth_id: String::new(),
154 grants: Grants::default(),
155 }
156 }
157
158 async fn request_token(&mut self) -> Result<()> {
160 let url = format!(
161 "{}/auth/realms/shinnytech/protocol/openid-connect/token",
162 self.auth_url
163 );
164
165 let params = [
166 ("client_id", CLIENT_ID),
167 ("client_secret", CLIENT_SECRET),
168 ("username", &self.username),
169 ("password", &self.password),
170 ("grant_type", "password"),
171 ];
172
173 info!("正在请求认证 token...");
174
175 let client = reqwest::Client::builder()
176 .no_proxy()
177 .timeout(Duration::from_secs(30))
178 .build()?;
179
180 let response = client
181 .post(&url)
182 .form(¶ms)
183 .header(USER_AGENT, format!("tqsdk-python {}", VERSION))
184 .header(ACCEPT, "application/json")
185 .send()
186 .await?;
187
188 if !response.status().is_success() {
189 let status = response.status();
190 let body = response.text().await?;
191 return Err(TqError::AuthenticationError(format!(
192 "认证失败 ({}): {}",
193 status, body
194 )));
195 }
196
197 let auth_resp: AuthResponse = response.json().await?;
198 self.access_token = auth_resp.access_token;
199 self.refresh_token = auth_resp.refresh_token;
200
201 debug!("Token 获取成功");
202 Ok(())
203 }
204
205 fn parse_token(&mut self) -> Result<()> {
207 use jsonwebtoken::dangerous::insecure_decode;
209 let token_data = insecure_decode::<AccessTokenClaims>(&self.access_token)?;
211
212 let claims = token_data.claims;
213 self.auth_id = claims.sub;
214
215 for feature in claims.grants.features {
217 self.grants.features.insert(feature);
218 }
219
220 for account in claims.grants.accounts {
221 self.grants.accounts.insert(account);
222 }
223
224 debug!(
225 "权限解析完成: {} 个功能, {} 个账户",
226 self.grants.features.len(),
227 self.grants.accounts.len()
228 );
229
230 Ok(())
231 }
232}
233
234#[async_trait]
235impl Authenticator for TqAuth {
236 fn base_header(&self) -> HeaderMap {
237 let mut headers = HeaderMap::new();
238 headers.insert(USER_AGENT, HeaderValue::from_static("tqsdk-python 3.8.1"));
239 headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
240 if !self.access_token.is_empty() {
241 if let Ok(value) = HeaderValue::from_str(&format!("Bearer {}", self.access_token)) {
242 headers.insert(AUTHORIZATION, value);
243 }
244 }
245 headers
246 }
247
248 async fn login(&mut self) -> Result<()> {
249 self.request_token().await?;
250 self.parse_token()?;
251 info!("TqAuth 登录成功, User: {}, AuthId: {}", self.username, self.auth_id);
252 Ok(())
253 }
254
255 async fn get_td_url(&self, broker_id: &str, account_id: &str) -> Result<BrokerInfo> {
256 let url = format!("https://files.shinnytech.com/{}.json", broker_id);
257
258 let client = reqwest::Client::builder()
259 .no_proxy()
260 .timeout(Duration::from_secs(30))
261 .build()?;
262
263 let response = client
264 .get(&url)
265 .query(&[("account_id", account_id), ("auth", &self.username)])
266 .headers(self.base_header())
267 .send()
268 .await?;
269
270 if !response.status().is_success() {
271 return Err(TqError::ConfigError(format!(
272 "不支持该期货公司: {}",
273 broker_id
274 )));
275 }
276
277 let broker_infos: std::collections::HashMap<String, BrokerInfo> = response.json().await?;
278
279 broker_infos.get(broker_id).cloned().ok_or_else(|| {
280 TqError::ConfigError(format!("该期货公司 {} 暂不支持 TqSdk 登录", broker_id))
281 })
282 }
283
284 async fn get_md_url(&self, stock: bool, backtest: bool) -> Result<String> {
285 let url = format!(
286 "https://api.shinnytech.com/ns?stock={}&backtest={}",
287 stock, backtest
288 );
289
290 let client = reqwest::Client::builder()
291 .no_proxy()
292 .timeout(Duration::from_secs(30))
293 .build()?;
294
295 let response = client.get(&url).headers(self.base_header()).send().await?;
296
297 if !response.status().is_success() {
298 let status = response.status();
299 let body = response.text().await?;
300 return Err(TqError::NetworkError(format!(
301 "获取行情服务器地址失败 ({}): {}",
302 status, body
303 )));
304 }
305
306 let md_url_resp: MdUrlResponse = response.json().await?;
307 Ok(md_url_resp.mdurl)
308 }
309
310 fn has_feature(&self, feature: &str) -> bool {
311 self.grants.features.contains(feature)
312 }
313
314 fn has_md_grants(&self, symbols: &[&str]) -> Result<()> {
315 for symbol in symbols {
316 let prefix = symbol.split('.').next().unwrap_or("");
317
318 if matches!(
320 prefix,
321 "CFFEX" | "SHFE" | "DCE" | "CZCE" | "INE" | "GFEX" | "SSWE" | "KQ" | "KQD"
322 ) {
323 if !self.has_feature("futr") {
324 return Err(TqError::permission_denied_futures());
325 }
326 continue;
327 }
328
329 if prefix == "CSI" || matches!(prefix, "SSE" | "SZSE") {
331 if !self.has_feature("sec") {
332 return Err(TqError::permission_denied_stocks());
333 }
334 continue;
335 }
336
337 if matches!(
339 *symbol,
340 "SSE.000016" | "SSE.000300" | "SSE.000905" | "SSE.000852"
341 ) {
342 if !self.has_feature("lmt_idx") {
343 return Err(TqError::PermissionDenied(format!(
344 "您的账户不支持查看 {} 的行情数据",
345 symbol
346 )));
347 }
348 continue;
349 }
350
351 return Err(TqError::PermissionDenied(format!(
353 "不支持的合约: {}",
354 symbol
355 )));
356 }
357
358 Ok(())
359 }
360
361 fn has_td_grants(&self, symbol: &str) -> Result<()> {
362 let prefix = symbol.split('.').next().unwrap_or("");
363
364 if matches!(
366 prefix,
367 "CFFEX" | "SHFE" | "DCE" | "CZCE" | "INE" | "GFEX" | "SSWE" | "KQ" | "KQD"
368 ) {
369 if self.has_feature("futr") {
370 return Ok(());
371 }
372 return Err(TqError::PermissionDenied(format!(
373 "您的账户不支持交易 {}",
374 symbol
375 )));
376 }
377
378 if prefix == "CSI" || matches!(prefix, "SSE" | "SZSE") {
380 if self.has_feature("sec") {
381 return Ok(());
382 }
383 return Err(TqError::PermissionDenied(format!(
384 "您的账户不支持交易 {}",
385 symbol
386 )));
387 }
388
389 Err(TqError::PermissionDenied(format!(
390 "不支持的合约: {}",
391 symbol
392 )))
393 }
394
395 fn get_auth_id(&self) -> &str {
396 &self.auth_id
397 }
398
399 fn get_access_token(&self) -> &str {
400 &self.access_token
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// `TqAuth::new` must keep the credentials exactly as supplied.
    #[test]
    fn test_tq_auth_creation() {
        let (user, pass) = ("test_user", "test_pass");
        let auth = TqAuth::new(user.to_string(), pass.to_string());

        assert_eq!(auth.username, user);
        assert_eq!(auth.password, pass);
    }
}