wechat_minapp/
client.rs

1use crate::{
2    Result,
3    access_token::{AccessToken, get_access_token, get_stable_access_token},
4    constants,
5    credential::{Credential, CredentialBuilder},
6    error::Error::InternalServer,
7    response::Response,
8};
9use chrono::{Duration, Utc};
10use std::{
11    collections::HashMap,
12    sync::{
13        Arc,
14        atomic::{AtomicBool, Ordering},
15    },
16};
17use tokio::sync::{Notify, RwLock};
18use tracing::{debug, instrument};
19
20///
21/// 提供与微信小程序后端 API 交互的核心功能,包括用户登录、访问令牌管理等。
22///
23/// # 功能特性
24///
25/// - 用户登录凭证校验
26/// - 访问令牌自动管理(支持普通令牌和稳定版令牌)
27/// - 线程安全的令牌刷新机制
28/// - 内置 HTTP 客户端
29///
30/// # 快速开始
31///
32/// ```no_run
33/// use wechat_minapp::Client;
34///
35/// #[tokio::main]
36/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
37///     // 初始化客户端
38///     let app_id = "your_app_id";
39///     let secret = "your_app_secret";
40///     let client = Client::new(app_id, secret);
41///
42///     // 用户登录
43///     let code = "user_login_code_from_frontend";
44///     let credential = client.login(code).await?;
45///     println!("用户OpenID: {}", credential.openid());
46///
47///     // 获取访问令牌
48///     let access_token = client.access_token().await?;
49///     println!("访问令牌: {}", access_token);
50///
51///     Ok(())
52/// }
53/// ```
54///
55/// # 令牌管理
56///
57/// 客户端自动管理访问令牌的生命周期:
58///
59/// - 令牌过期前自动刷新
60/// - 多线程环境下的安全并发访问
61/// - 避免重复刷新(令牌锁机制)
62/// - 支持强制刷新选项
63///
64/// # 线程安全
65///
66/// `Client` 实现了 `Send` 和 `Sync`,可以在多线程环境中安全使用。
67#[derive(Debug, Clone)]
68pub struct Client {
69    inner: Arc<ClientInner>,
70    access_token: Arc<RwLock<AccessToken>>,
71    refreshing: Arc<AtomicBool>,
72    notify: Arc<Notify>,
73}
74
75impl Client {
76    /// 创建新的微信小程序客户端
77    ///
78    /// # 参数
79    ///
80    /// - `app_id`: 小程序 AppID
81    /// - `secret`: 小程序 AppSecret
82    ///
83    /// # 返回
84    ///
85    /// 新的 `Client` 实例
86    ///
87    /// # 示例
88    ///
89    /// ```
90    /// use wechat_minapp::Client;
91    ///
92    /// let client = Client::new("wx1234567890abcdef", "your_app_secret_here");
93    /// ```
94    pub fn new(app_id: &str, secret: &str) -> Self {
95        let client = reqwest::Client::new();
96
97        Self {
98            inner: Arc::new(ClientInner {
99                app_id: app_id.into(),
100                secret: secret.into(),
101                client,
102            }),
103            access_token: Arc::new(RwLock::new(AccessToken {
104                access_token: "".to_string(),
105                expired_at: Utc::now(),
106                force_refresh: None,
107            })),
108            refreshing: Arc::new(AtomicBool::new(false)),
109            notify: Arc::new(Notify::new()),
110        }
111    }
112
113    pub(crate) fn request(&self) -> &reqwest::Client {
114        &self.inner.client
115    }
116
117    /// 用户登录凭证校验
118    ///
119    /// 通过微信前端获取的临时登录凭证 code,换取用户的唯一标识 OpenID 和会话密钥。
120    ///
121    /// # 参数
122    ///
123    /// - `code`: 微信前端通过 `wx.login()` 获取的临时登录凭证
124    ///
125    /// # 返回
126    ///
127    /// 成功返回 `Ok(Credential)`,包含用户身份信息
128    ///
129    /// # 错误
130    ///
131    /// - 网络错误
132    /// - 微信 API 返回错误
133    /// - 响应解析错误
134    ///
135    /// # 示例
136    ///
137    /// ```no_run
138    /// use wechat_minapp::Client;
139    ///
140    /// #[tokio::main]
141    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
142    ///     let client = Client::new("app_id", "secret");
143    ///     let code = "0816abc123def456";
144    ///     let credential = client.login(code).await?;
145    ///
146    ///     println!("用户OpenID: {}", credential.openid());
147    ///     println!("会话密钥: {}", credential.session_key());
148    ///     
149    ///     Ok(())
150    /// }
151    /// ```
152    ///
153    /// # API 文档
154    ///
155    /// [微信官方文档 - code2Session](https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/user-login/code2Session.html)
156    #[instrument(skip(self, code))]
157    pub async fn login(&self, code: &str) -> Result<Credential> {
158        debug!("code: {}", code);
159
160        let mut map: HashMap<&str, &str> = HashMap::new();
161
162        map.insert("appid", &self.inner.app_id);
163        map.insert("secret", &self.inner.secret);
164        map.insert("js_code", code);
165        map.insert("grant_type", "authorization_code");
166
167        let response = self
168            .inner
169            .client
170            .get(constants::AUTHENTICATION_END_POINT)
171            .query(&map)
172            .send()
173            .await?;
174
175        debug!("authentication response: {:#?}", response);
176
177        if response.status().is_success() {
178            let response = response.json::<Response<CredentialBuilder>>().await?;
179
180            let credential = response.extract()?.build();
181
182            debug!("credential: {:#?}", credential);
183
184            Ok(credential)
185        } else {
186            Err(InternalServer(response.text().await?))
187        }
188    }
189
190    /// 获取访问令牌
191    ///
192    /// 获取用于调用微信小程序接口的访问令牌。如果当前令牌已过期或即将过期,会自动刷新。
193    ///
194    /// # 返回
195    ///
196    /// 成功返回 `Ok(String)`,包含有效的访问令牌
197    ///
198    /// # 错误
199    ///
200    /// - 网络错误
201    /// - 微信 API 返回错误
202    /// - 令牌刷新失败
203    ///
204    /// # 示例
205    ///
206    /// ```no_run
207    /// use wechat_minapp::Client;
208    ///
209    /// #[tokio::main]
210    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
211    ///     let client = Client::new("app_id", "secret");
212    ///     let access_token = client.access_token().await?;
213    ///     
214    ///     println!("访问令牌: {}", access_token);
215    ///     Ok(())
216    /// }
217    /// ```
218    ///
219    /// # 注意
220    ///
221    /// - 令牌有效期为 2 小时
222    /// - 客户端会自动管理令牌刷新,无需手动处理
223    /// - 多线程环境下安全
224    pub async fn access_token(&self) -> Result<String> {
225        // 第一次检查:快速路径
226        {
227            let guard = self.access_token.read().await;
228            if !is_token_expired(&guard) {
229                return Ok(guard.access_token.clone());
230            }
231        }
232
233        // 使用CAS竞争刷新权
234        if self
235            .refreshing
236            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
237            .is_ok()
238        {
239            // 获得刷新权
240            match self.refresh_access_token().await {
241                Ok(token) => {
242                    self.refreshing.store(false, Ordering::Release);
243                    self.notify.notify_waiters();
244                    Ok(token)
245                }
246                Err(e) => {
247                    self.refreshing.store(false, Ordering::Release);
248                    self.notify.notify_waiters();
249                    Err(e)
250                }
251            }
252        } else {
253            // 等待其他线程刷新完成
254            self.notify.notified().await;
255            // 刷新完成后重新读取
256            let guard = self.access_token.read().await;
257            Ok(guard.access_token.clone())
258        }
259    }
260
261    async fn refresh_access_token(&self) -> Result<String> {
262        let mut guard = self.access_token.write().await;
263
264        if !is_token_expired(&guard) {
265            debug!("token already refreshed by another thread");
266            return Ok(guard.access_token.clone());
267        }
268
269        debug!("performing network request to refresh token");
270
271        let builder = get_access_token(
272            self.inner.client.clone(),
273            &self.inner.app_id,
274            &self.inner.secret,
275        )
276        .await?;
277
278        guard.access_token = builder.access_token.clone();
279        guard.expired_at = builder.expired_at;
280
281        debug!("fresh access token: {:#?}", guard);
282
283        Ok(guard.access_token.clone())
284    }
285
286    /// 获取稳定版访问令牌
287    ///
288    /// 获取稳定版的访问令牌,相比普通令牌有更长的有效期和更好的稳定性。
289    ///
290    /// # 参数
291    ///
292    /// - `force_refresh`: 是否强制刷新令牌
293    ///   - `Some(true)`: 强制从微信服务器获取最新令牌
294    ///   - `Some(false)` 或 `None`: 仅在令牌过期时刷新
295    ///
296    /// # 返回
297    ///
298    /// 成功返回 `Ok(String)`,包含有效的稳定版访问令牌
299    ///
300    /// # 错误
301    ///
302    /// - 网络错误
303    /// - 微信 API 返回错误
304    /// - 令牌刷新失败
305    ///
306    /// # 示例
307    ///
308    /// ```no_run
309    /// use wechat_minapp::Client;
310    ///
311    /// #[tokio::main]
312    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
313    ///     let client = Client::new("app_id", "secret");
314    ///     
315    ///     // 仅在过期时刷新
316    ///     let token1 = client.stable_access_token(None).await?;
317    ///     
318    ///     // 强制刷新
319    ///     let token2 = client.stable_access_token(true).await?;
320    ///     
321    ///     Ok(())
322    /// }
323    /// ```
324    ///
325    /// # 注意
326    ///
327    /// - 稳定版令牌有效期更长,推荐在生产环境使用
328    /// - 强制刷新会忽略本地缓存,直接请求新令牌
329    pub async fn stable_access_token(
330        &self,
331        force_refresh: impl Into<Option<bool>> + Clone + Send,
332    ) -> Result<String> {
333        // 第一次检查:快速路径
334        {
335            let guard = self.access_token.read().await;
336            if !is_token_expired(&guard) {
337                return Ok(guard.access_token.clone());
338            }
339        }
340
341        // 使用CAS竞争刷新权
342        if self
343            .refreshing
344            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
345            .is_ok()
346        {
347            // 获得刷新权
348            match self.refresh_stable_access_token(force_refresh).await {
349                Ok(token) => {
350                    self.refreshing.store(false, Ordering::Release);
351                    self.notify.notify_waiters();
352                    Ok(token)
353                }
354                Err(e) => {
355                    self.refreshing.store(false, Ordering::Release);
356                    self.notify.notify_waiters();
357                    Err(e)
358                }
359            }
360        } else {
361            // 等待其他线程刷新完成
362            self.notify.notified().await;
363            // 刷新完成后重新读取
364            let guard = self.access_token.read().await;
365            Ok(guard.access_token.clone())
366        }
367    }
368
369    async fn refresh_stable_access_token(
370        &self,
371        force_refresh: impl Into<Option<bool>> + Clone + Send,
372    ) -> Result<String> {
373        // 1. Acquire the write lock. This blocks if another thread won CAS but is refreshing.
374        let mut guard = self.access_token.write().await;
375
376        // 2. Double-check expiration under the write lock (CRITICAL)
377        // If another CAS-winner refreshed the token while we were waiting for the write lock,
378        // we return the new token without performing a new network call.
379        if !is_token_expired(&guard) {
380            // Token is now fresh, return it
381            debug!("token already refreshed by another thread");
382            return Ok(guard.access_token.clone());
383        }
384
385        // 3. Perform the network request since the token is still stale
386        debug!("performing network request to refresh token");
387
388        let builder = get_stable_access_token(
389            self.inner.client.clone(),
390            &self.inner.app_id,
391            &self.inner.secret,
392            force_refresh,
393        )
394        .await?;
395
396        // 4. Update the token
397        guard.access_token = builder.access_token.clone();
398        guard.expired_at = builder.expired_at;
399
400        debug!("fresh access token: {:#?}", guard);
401
402        // Return the newly fetched token (cloned here for consistency)
403        Ok(guard.access_token.clone())
404    }
405}
406
407#[derive(Debug)]
408struct ClientInner {
409    app_id: String,
410    secret: String,
411    client: reqwest::Client,
412}
413
414/// 检查令牌是否过期
415///
416/// 添加安全边界,在令牌过期前5分钟就认为需要刷新
417fn is_token_expired(token: &AccessToken) -> bool {
418    // 添加安全边界,提前刷新
419    let now = Utc::now();
420    token.expired_at.signed_duration_since(now) < Duration::minutes(5)
421}