wechat_minapp/
client.rs

1use crate::{
2    Result,
3    access_token::{AccessToken, get_access_token, get_stable_access_token},
4    constants,
5    credential::{Credential, CredentialBuilder},
6    error::Error::InternalServer,
7    response::Response,
8};
9use chrono::{Duration, Utc};
10use std::{
11    collections::HashMap,
12    sync::{
13        Arc,
14        atomic::{AtomicBool, Ordering},
15    },
16};
17use tokio::sync::{Notify, RwLock};
18use tracing::{debug, instrument};
19
20/// 存储微信小程序的 appid 和 secret
21#[derive(Debug, Clone)]
22pub struct Client {
23    inner: Arc<ClientInner>,
24    access_token: Arc<RwLock<AccessToken>>,
25    refreshing: Arc<AtomicBool>,
26    notify: Arc<Notify>,
27}
28
29impl Client {
30    /// ```ignore
31    /// use wechat_minapp::Client;
32    ///
33    /// #[tokio::main]
34    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
35    ///     let app_id = "your app id";
36    ///     let secret = "your app secret";
37    ///     
38    ///     let client = Client::new(app_id, secret);
39    ///
40    ///     Ok(())
41    /// }
42    /// ```
43    pub fn new(app_id: &str, secret: &str) -> Self {
44        let client = reqwest::Client::new();
45
46        Self {
47            inner: Arc::new(ClientInner {
48                app_id: app_id.into(),
49                secret: secret.into(),
50                client,
51            }),
52            access_token: Arc::new(RwLock::new(AccessToken {
53                access_token: "".to_string(),
54                expired_at: Utc::now(),
55                force_refresh: None,
56            })),
57            refreshing: Arc::new(AtomicBool::new(false)),
58            notify: Arc::new(Notify::new()),
59        }
60    }
61
62    pub(crate) fn request(&self) -> &reqwest::Client {
63        &self.inner.client
64    }
65
66    /// 登录凭证校验
67    /// https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/user-login/code2Session.html
68    /// ```rust
69    /// use wechat_minapp::Client;
70    /// use serde::Deserialize;
71    /// use crate::{Error, state::AppState};
72    /// use actix_web::{Responder, web};
73    /// #[derive(Deserialize, Default)]
74    /// #[serde(default)]
75    /// pub struct Logger {
76    ///     code: String,
77    /// }
78    /// pub async fn login(
79    ///     state: web::Data<AppState>,
80    ///     logger: web::Json<Logger>,
81    /// ) -> Result<impl Responder, Error> {
82    ///     let credential = state.client.login(&logger.code).await?;
83    ///     Ok(())
84    /// }
85    /// ```
86    #[instrument(skip(self, code))]
87    pub async fn login(&self, code: &str) -> Result<Credential> {
88        debug!("code: {}", code);
89
90        let mut map: HashMap<&str, &str> = HashMap::new();
91
92        map.insert("appid", &self.inner.app_id);
93        map.insert("secret", &self.inner.secret);
94        map.insert("js_code", code);
95        map.insert("grant_type", "authorization_code");
96
97        let response = self
98            .inner
99            .client
100            .get(constants::AUTHENTICATION_END_POINT)
101            .query(&map)
102            .send()
103            .await?;
104
105        debug!("authentication response: {:#?}", response);
106
107        if response.status().is_success() {
108            let response = response.json::<Response<CredentialBuilder>>().await?;
109
110            let credential = response.extract()?.build();
111
112            debug!("credential: {:#?}", credential);
113
114            Ok(credential)
115        } else {
116            Err(InternalServer(response.text().await?))
117        }
118    }
119
120    /// ```ignore
121    /// use wechat_minapp::Client;
122    ///
123    /// #[tokio::main]
124    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
125    ///     let app_id = "your app id";
126    ///     let secret = "your app secret";
127    ///     
128    ///     let client = Client::new(app_id, secret);
129    ///     let access_token = client.access_token().await?;
130    ///
131    ///     Ok(())
132    /// }
133    /// ```
134    ///
135    pub async fn access_token(&self) -> Result<String> {
136        // 第一次检查:快速路径
137        {
138            let guard = self.access_token.read().await;
139            if !is_token_expired(&guard) {
140                return Ok(guard.access_token.clone());
141            }
142        }
143
144        // 使用CAS竞争刷新权
145        if self
146            .refreshing
147            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
148            .is_ok()
149        {
150            // 获得刷新权
151            match self.refresh_access_token().await {
152                Ok(token) => {
153                    self.refreshing.store(false, Ordering::Release);
154                    self.notify.notify_waiters();
155                    Ok(token)
156                }
157                Err(e) => {
158                    self.refreshing.store(false, Ordering::Release);
159                    self.notify.notify_waiters();
160                    Err(e)
161                }
162            }
163        } else {
164            // 等待其他线程刷新完成
165            self.notify.notified().await;
166            // 刷新完成后重新读取
167            let guard = self.access_token.read().await;
168            Ok(guard.access_token.clone())
169        }
170    }
171
172    async fn refresh_access_token(&self) -> Result<String> {
173        let mut guard = self.access_token.write().await;
174
175        if !is_token_expired(&guard) {
176            debug!("token already refreshed by another thread");
177            return Ok(guard.access_token.clone());
178        }
179
180        debug!("performing network request to refresh token");
181
182        let builder = get_access_token(
183            self.inner.client.clone(),
184            &self.inner.app_id,
185            &self.inner.secret,
186        )
187        .await?;
188
189        guard.access_token = builder.access_token.clone();
190        guard.expired_at = builder.expired_at;
191
192        debug!("fresh access token: {:#?}", guard);
193
194        Ok(guard.access_token.clone())
195    }
196
197    /// ```ignore
198    /// use wechat_minapp::Client;
199    ///
200    /// #[tokio::main]
201    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
202    ///     let app_id = "your app id";
203    ///     let secret = "your app secret";
204    ///     
205    ///     let client = Client::new(app_id, secret);
206    ///     let stable_access_token = client.stable_access_token(Some(true)).await?;
207    ///
208    ///     Ok(())
209    /// }
210    /// ```
211    ///
212    ///
213    pub async fn stable_access_token(
214        &self,
215        force_refresh: impl Into<Option<bool>> + Clone + Send,
216    ) -> Result<String> {
217        // 第一次检查:快速路径
218        {
219            let guard = self.access_token.read().await;
220            if !is_token_expired(&guard) {
221                return Ok(guard.access_token.clone());
222            }
223        }
224
225        // 使用CAS竞争刷新权
226        if self
227            .refreshing
228            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
229            .is_ok()
230        {
231            // 获得刷新权
232            match self.refresh_stable_access_token(force_refresh).await {
233                Ok(token) => {
234                    self.refreshing.store(false, Ordering::Release);
235                    self.notify.notify_waiters();
236                    Ok(token)
237                }
238                Err(e) => {
239                    self.refreshing.store(false, Ordering::Release);
240                    self.notify.notify_waiters();
241                    Err(e)
242                }
243            }
244        } else {
245            // 等待其他线程刷新完成
246            self.notify.notified().await;
247            // 刷新完成后重新读取
248            let guard = self.access_token.read().await;
249            Ok(guard.access_token.clone())
250        }
251    }
252
253    async fn refresh_stable_access_token(
254        &self,
255        force_refresh: impl Into<Option<bool>> + Clone + Send,
256    ) -> Result<String> {
257        // 1. Acquire the write lock. This blocks if another thread won CAS but is refreshing.
258        let mut guard = self.access_token.write().await;
259
260        // 2. Double-check expiration under the write lock (CRITICAL)
261        // If another CAS-winner refreshed the token while we were waiting for the write lock,
262        // we return the new token without performing a new network call.
263        if !is_token_expired(&guard) {
264            // Token is now fresh, return it
265            debug!("token already refreshed by another thread");
266            return Ok(guard.access_token.clone());
267        }
268
269        // 3. Perform the network request since the token is still stale
270        debug!("performing network request to refresh token");
271
272        let builder = get_stable_access_token(
273            self.inner.client.clone(),
274            &self.inner.app_id,
275            &self.inner.secret,
276            force_refresh,
277        )
278        .await?;
279
280        // 4. Update the token
281        guard.access_token = builder.access_token.clone();
282        guard.expired_at = builder.expired_at;
283
284        debug!("fresh access token: {:#?}", guard);
285
286        // Return the newly fetched token (cloned here for consistency)
287        Ok(guard.access_token.clone())
288    }
289}
290
291#[derive(Debug)]
292struct ClientInner {
293    app_id: String,
294    secret: String,
295    client: reqwest::Client,
296}
297
298fn is_token_expired(token: &AccessToken) -> bool {
299    // 添加安全边界,提前刷新
300    let now = Utc::now();
301    token.expired_at.signed_duration_since(now) < Duration::minutes(5)
302}