wechat_minapp/
client.rs

1use crate::{
2    Result,
3    access_token::{AccessToken, get_access_token, get_stable_access_token},
4    constants,
5    credential::{Credential, CredentialBuilder},
6    error::Error::InternalServer,
7    response::Response,
8};
9use chrono::{Duration, Utc};
10use std::{
11    collections::HashMap,
12    sync::{
13        Arc,
14        atomic::{AtomicBool, Ordering},
15    },
16};
17use tokio::sync::{Notify, RwLock};
18use tracing::{debug, instrument};
19
20/// 存储微信小程序的 appid 和 secret
21#[derive(Debug, Clone)]
22pub struct Client {
23    inner: Arc<ClientInner>,
24    access_token: Arc<RwLock<AccessToken>>,
25    refreshing: Arc<AtomicBool>,
26    notify: Arc<Notify>,
27}
28
29impl Client {
30    /// ```ignore
31    /// use wechat_minapp::Client;
32    ///
33    /// #[tokio::main]
34    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
35    ///     let app_id = "your app id";
36    ///     let secret = "your app secret";
37    ///     
38    ///     let client = Client::new(app_id, secret);
39    ///
40    ///     Ok(())
41    /// }
42    /// ```
43    pub fn new(app_id: &str, secret: &str) -> Self {
44        let client = reqwest::Client::new();
45
46        Self {
47            inner: Arc::new(ClientInner {
48                app_id: app_id.into(),
49                secret: secret.into(),
50                client,
51            }),
52            access_token: Arc::new(RwLock::new(AccessToken {
53                access_token: "".to_string(),
54                expired_at: Utc::now(),
55                force_refresh: None,
56            })),
57            refreshing: Arc::new(AtomicBool::new(false)),
58            notify: Arc::new(Notify::new()),
59        }
60    }
61
62    pub(crate) fn request(&self) -> &reqwest::Client {
63        &self.inner.client
64    }
65
66    /// 登录凭证校验
67    /// https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/user-login/code2Session.html
68    /// ```rust
69    /// use axum::{extract::State, response::IntoResponse, Json};
70    /// use wechat_minapp::{client::Client, Result};
71    /// use serde::Deserialize;
72    ///
73    /// #[derive(Deserialize, Default)]
74    /// #[serde(default)]
75    /// pub(crate) struct Logger {
76    ///     code: String,
77    /// }
78    ///
79    /// pub(crate) async fn login(
80    ///     State(client): State<Client>,
81    ///     Json(logger): Json<Logger>,
82    /// ) -> Result<impl IntoResponse> {
83    ///    let credential = client.login(&logger.code).await?;
84    ///
85    ///     Ok(())
86    /// }
87    /// ```
88    #[instrument(skip(self, code))]
89    pub async fn login(&self, code: &str) -> Result<Credential> {
90        debug!("code: {}", code);
91
92        let mut map: HashMap<&str, &str> = HashMap::new();
93
94        map.insert("appid", &self.inner.app_id);
95        map.insert("secret", &self.inner.secret);
96        map.insert("js_code", code);
97        map.insert("grant_type", "authorization_code");
98
99        let response = self
100            .inner
101            .client
102            .get(constants::AUTHENTICATION_END_POINT)
103            .query(&map)
104            .send()
105            .await?;
106
107        debug!("authentication response: {:#?}", response);
108
109        if response.status().is_success() {
110            let response = response.json::<Response<CredentialBuilder>>().await?;
111
112            let credential = response.extract()?.build();
113
114            debug!("credential: {:#?}", credential);
115
116            Ok(credential)
117        } else {
118            Err(InternalServer(response.text().await?))
119        }
120    }
121
122    /// ```ignore
123    /// use wechat_minapp::Client;
124    ///
125    /// #[tokio::main]
126    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
127    ///     let app_id = "your app id";
128    ///     let secret = "your app secret";
129    ///     
130    ///     let client = Client::new(app_id, secret);
131    ///     let access_token = client.access_token().await?;
132    ///
133    ///     Ok(())
134    /// }
135    /// ```
136    ///
137    pub async fn access_token(&self) -> Result<String> {
138        // 第一次检查:快速路径
139        {
140            let guard = self.access_token.read().await;
141            if !is_token_expired(&guard) {
142                return Ok(guard.access_token.clone());
143            }
144        }
145
146        // 使用CAS竞争刷新权
147        if self
148            .refreshing
149            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
150            .is_ok()
151        {
152            // 获得刷新权
153            match self.refresh_access_token().await {
154                Ok(token) => {
155                    self.refreshing.store(false, Ordering::Release);
156                    self.notify.notify_waiters();
157                    Ok(token)
158                }
159                Err(e) => {
160                    self.refreshing.store(false, Ordering::Release);
161                    self.notify.notify_waiters();
162                    Err(e)
163                }
164            }
165        } else {
166            // 等待其他线程刷新完成
167            self.notify.notified().await;
168            // 刷新完成后重新读取
169            let guard = self.access_token.read().await;
170            Ok(guard.access_token.clone())
171        }
172    }
173
174    async fn refresh_access_token(&self) -> Result<String> {
175        let mut guard = self.access_token.write().await;
176
177        if !is_token_expired(&guard) {
178            debug!("token already refreshed by another thread");
179            return Ok(guard.access_token.clone());
180        }
181
182        debug!("performing network request to refresh token");
183
184        let builder = get_access_token(
185            self.inner.client.clone(),
186            &self.inner.app_id,
187            &self.inner.secret,
188        )
189        .await?;
190
191        guard.access_token = builder.access_token.clone();
192        guard.expired_at = builder.expired_at;
193
194        debug!("fresh access token: {:#?}", guard);
195
196        Ok(guard.access_token.clone())
197    }
198
199    /// ```ignore
200    /// use wechat_minapp::Client;
201    ///
202    /// #[tokio::main]
203    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
204    ///     let app_id = "your app id";
205    ///     let secret = "your app secret";
206    ///     
207    ///     let client = Client::new(app_id, secret);
208    ///     let stable_access_token = client.stable_access_token(Some(true)).await?;
209    ///
210    ///     Ok(())
211    /// }
212    /// ```
213    ///
214    ///
215    pub async fn stable_access_token(
216        &self,
217        force_refresh: impl Into<Option<bool>> + Clone + Send,
218    ) -> Result<String> {
219        // 第一次检查:快速路径
220        {
221            let guard = self.access_token.read().await;
222            if !is_token_expired(&guard) {
223                return Ok(guard.access_token.clone());
224            }
225        }
226
227        // 使用CAS竞争刷新权
228        if self
229            .refreshing
230            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
231            .is_ok()
232        {
233            // 获得刷新权
234            match self.refresh_stable_access_token(force_refresh).await {
235                Ok(token) => {
236                    self.refreshing.store(false, Ordering::Release);
237                    self.notify.notify_waiters();
238                    Ok(token)
239                }
240                Err(e) => {
241                    self.refreshing.store(false, Ordering::Release);
242                    self.notify.notify_waiters();
243                    Err(e)
244                }
245            }
246        } else {
247            // 等待其他线程刷新完成
248            self.notify.notified().await;
249            // 刷新完成后重新读取
250            let guard = self.access_token.read().await;
251            Ok(guard.access_token.clone())
252        }
253    }
254
255    async fn refresh_stable_access_token(
256        &self,
257        force_refresh: impl Into<Option<bool>> + Clone + Send,
258    ) -> Result<String> {
259        // 1. Acquire the write lock. This blocks if another thread won CAS but is refreshing.
260        let mut guard = self.access_token.write().await;
261
262        // 2. Double-check expiration under the write lock (CRITICAL)
263        // If another CAS-winner refreshed the token while we were waiting for the write lock,
264        // we return the new token without performing a new network call.
265        if !is_token_expired(&guard) {
266            // Token is now fresh, return it
267            debug!("token already refreshed by another thread");
268            return Ok(guard.access_token.clone());
269        }
270
271        // 3. Perform the network request since the token is still stale
272        debug!("performing network request to refresh token");
273
274        let builder = get_stable_access_token(
275            self.inner.client.clone(),
276            &self.inner.app_id,
277            &self.inner.secret,
278            force_refresh,
279        )
280        .await?;
281
282        // 4. Update the token
283        guard.access_token = builder.access_token.clone();
284        guard.expired_at = builder.expired_at;
285
286        debug!("fresh access token: {:#?}", guard);
287
288        // Return the newly fetched token (cloned here for consistency)
289        Ok(guard.access_token.clone())
290    }
291}
292
293#[derive(Debug)]
294struct ClientInner {
295    app_id: String,
296    secret: String,
297    client: reqwest::Client,
298}
299
300fn is_token_expired(token: &AccessToken) -> bool {
301    // 添加安全边界,提前刷新
302    let now = Utc::now();
303    token.expired_at.signed_duration_since(now) < Duration::minutes(5)
304}